Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
class LatentRebatch:
    """Regroup a list of LATENT dicts into batches of at most ``batch_size`` samples.

    Inputs arrive as lists (``INPUT_IS_LIST``) and the output is a list
    (``OUTPUT_IS_LIST``). Consecutive latents with the same spatial size are
    concatenated and re-sliced; a change in spatial size forces the pending
    batch to be emitted first, so an output batch never mixes resolutions.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "latents": ("LATENT",),
                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
                              }}

    RETURN_TYPES = ("LATENT",)
    INPUT_IS_LIST = True
    OUTPUT_IS_LIST = (True, )

    FUNCTION = "rebatch"

    CATEGORY = "latent/batch"

    @staticmethod
    def get_batch(latents, list_ind, offset):
        '''Prepare a (samples, mask, batch_indices) triple from ``latents[list_ind]``.

        Args:
            latents: list of LATENT dicts; each has a 'samples' tensor and may
                carry 'noise_mask' and 'batch_index'.
            list_ind: index of the entry to read.
            offset: number of samples consumed from earlier entries; used to
                synthesize batch indices when 'batch_index' is absent.

        Returns:
            (samples, mask, batch_inds). ``mask`` has shape (N, 1, H*8, W*8),
            where N is samples.shape[0] and H, W are the last two dims of samples.
        '''
        entry = latents[list_ind]
        samples = entry['samples']
        # shape[2] and shape[3] are indexed below, so samples is assumed 4-D.
        shape = samples.shape
        if 'noise_mask' in entry:
            mask = entry['noise_mask']
        else:
            mask = torch.ones((shape[0], 1, shape[2]*8, shape[3]*8), device='cpu')
        # Normalize to (n, 1, h, w) so interpolate / repeat / cat see a consistent rank.
        mask = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
        target = (shape[-2] * 8, shape[-1] * 8)
        if tuple(mask.shape[-2:]) != target:
            # interpolate returns a new tensor; it must be assigned back.
            mask = torch.nn.functional.interpolate(mask, size=target, mode="bilinear")
        if mask.shape[0] < shape[0]:
            mask = mask.repeat((shape[0] - 1) // mask.shape[0] + 1, 1, 1, 1)
        # Exactly one mask per sample, so later slicing keeps masks aligned with samples.
        mask = mask[:shape[0]]
        if 'batch_index' in entry:
            batch_inds = entry['batch_index']
        else:
            batch_inds = [x + offset for x in range(shape[0])]
        return samples, mask, batch_inds

    @staticmethod
    def get_slices(indexable, num, batch_size):
        '''Divide ``indexable`` into ``num`` slices of length ``batch_size``, plus a remainder.

        Returns:
            (slices, remainder): ``remainder`` is everything after the first
            ``num * batch_size`` elements, or None when nothing is left over.
        '''
        slices = [indexable[i*batch_size:(i+1)*batch_size] for i in range(num)]
        if num * batch_size < len(indexable):
            return slices, indexable[num * batch_size:]
        return slices, None

    @staticmethod
    def slice_batch(batch, num, batch_size):
        '''Apply ``get_slices`` to every component of a (samples, mask, indices) batch.

        Returns:
            Two-element list: [tuple of per-component slice lists,
            tuple of per-component remainders (None entries when no remainder)].
        '''
        result = [LatentRebatch.get_slices(x, num, batch_size) for x in batch]
        return list(zip(*result))

    @staticmethod
    def cat_batch(batch1, batch2):
        '''Concatenate two (samples, mask, indices) batches component-wise.

        Tensors are joined with torch.cat; other components (index lists) with ``+``.
        When ``batch1`` is empty (its samples entry is None), ``batch2`` is returned.
        '''
        if batch1[0] is None:
            return batch2
        return [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2
                for b1, b2 in zip(batch1, batch2)]

    @staticmethod
    def _as_latent(sliced, index):
        '''Build a LATENT dict from slice number ``index`` of a ``slice_batch`` result.'''
        return {'samples': sliced[0][index], 'noise_mask': sliced[1][index], 'batch_index': sliced[2][index]}

    def rebatch(self, latents, batch_size):
        '''Regroup ``latents`` into batches of at most ``batch_size`` samples.

        Args:
            latents: list of LATENT dicts.
            batch_size: list whose first element is the target batch size.

        Returns:
            1-tuple holding the list of rebatched LATENT dicts. 'noise_mask' is
            removed from any output batch whose mask is all ones.
        '''
        batch_size = batch_size[0]

        output_list = []
        current_batch = (None, None, None)
        processed = 0

        for list_ind in range(len(latents)):
            next_batch = self.get_batch(latents, list_ind, processed)
            processed += len(next_batch[2])
            if current_batch[0] is None:
                current_batch = next_batch
            elif next_batch[0].shape[-2:] != current_batch[0].shape[-2:]:
                # Spatial sizes differ, so the batches cannot be concatenated: emit the
                # pending batch. It never exceeds batch_size here, so one slice holds it.
                sliced, _ = self.slice_batch(current_batch, 1, batch_size)
                output_list.append(self._as_latent(sliced, 0))
                current_batch = next_batch
            else:
                current_batch = self.cat_batch(current_batch, next_batch)

            # Emit full batches; keep the leftover (smaller than batch_size) pending.
            if current_batch[0].shape[0] > batch_size:
                num = current_batch[0].shape[0] // batch_size
                sliced, remainder = self.slice_batch(current_batch, num, batch_size)
                output_list.extend(self._as_latent(sliced, j) for j in range(num))
                current_batch = remainder

        if current_batch[0] is not None:
            sliced, _ = self.slice_batch(current_batch, 1, batch_size)
            output_list.append(self._as_latent(sliced, 0))

        # An all-ones mask is treated as "no mask", so drop it.
        for latent in output_list:
            if bool((latent['noise_mask'] == 1.0).all()):
                del latent['noise_mask']
        return (output_list,)
class ImageRebatch:
    """Regroup a list of IMAGE batches into batches of at most ``batch_size`` frames.

    Frames keep their original order across input batches. Each output batch
    is built with torch.cat, so frames grouped together must share a shape.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "images": ("IMAGE",),
                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
                              }}

    RETURN_TYPES = ("IMAGE",)
    INPUT_IS_LIST = True
    OUTPUT_IS_LIST = (True, )

    FUNCTION = "rebatch"

    CATEGORY = "image/batch"

    def rebatch(self, images, batch_size):
        '''Flatten ``images`` into single frames and regroup them.

        Args:
            images: list of image tensors, each batched along dim 0.
            batch_size: list whose first element is the target batch size.

        Returns:
            1-tuple holding the list of re-batched tensors; the last one may be
            smaller than batch_size.
        '''
        batch_size = batch_size[0]

        # Split every input batch into 1-frame tensors (dim 0 kept, size 1).
        all_images = []
        for img in images:
            all_images.extend(img.split(1))

        output_list = [torch.cat(all_images[i:i + batch_size], dim=0)
                       for i in range(0, len(all_images), batch_size)]
        return (output_list,)
# Node type name -> implementing class (both classes are defined in this module).
NODE_CLASS_MAPPINGS = dict(
    RebatchLatents=LatentRebatch,
    RebatchImages=ImageRebatch,
)

# Node type name -> human-readable title, keyed by the same names as above.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    RebatchLatents="Rebatch Latents",
    RebatchImages="Rebatch Images",
)