Spaces:
Runtime error
Runtime error
| from pathlib import Path | |
| from PIL import Image | |
| import torchvision | |
| import random | |
| from torch.utils.data import Dataset, DataLoader | |
| from functools import partial | |
| from multiprocessing import cpu_count | |
| from datasets import load_dataset | |
| import cv2 | |
| import numpy as np | |
| import torch | |
class PNGDataset(Dataset):
    """Image/caption dataset for text-to-image (optionally ControlNet) training.

    Each sample is an image plus the caption stored under ``prompt_key`` in the
    image's metadata (``PIL.Image.info``). Images come either from the
    ``*.{file_extension}`` files in ``data_dir`` or from the ``"train"`` split
    of ``datasets.load_dataset(data_dir)``, whose rows are read via ``"image"``.

    Args:
        data_dir: Local directory, or dataset name when ``from_hf_hub`` is True.
        tokenizer: Used by :meth:`collate_fn`; called as
            ``tokenizer(text, padding="max_length", truncation=True)`` and must
            return an object exposing ``input_ids``.
        from_hf_hub: Load with ``load_dataset`` instead of globbing ``data_dir``.
        ucg: Probability of replacing a sample's caption with ``""``.
        resolution: Size passed to ``torchvision.transforms.Resize``.
        prompt_key: Metadata key holding the caption.
        cond_key: Output dict key for the caption (token ids after collation).
        target_key: Output dict key for the normalized image tensor.
        controlnet_hint_key: ``None`` or ``"canny"``; when ``"canny"`` a Canny
            edge map is added to each sample under that key.
        file_extension: Extension globbed in ``data_dir``.

    Raises:
        ValueError: If ``controlnet_hint_key`` is not ``None`` or ``"canny"``.
    """

    _SUPPORTED_HINTS = (None, "canny")

    def __init__(
        self,
        data_dir,
        tokenizer,
        from_hf_hub=False,
        ucg=0.10,
        resolution=(512, 512),
        prompt_key="tags",
        cond_key="cond",
        target_key="image",
        controlnet_hint_key=None,
        file_extension="png",
    ):
        super().__init__()
        # collate_fn stacks a hint for any non-None key, but __getitem__ can only
        # produce a Canny hint; fail early rather than with a KeyError mid-epoch.
        if controlnet_hint_key not in self._SUPPORTED_HINTS:
            raise ValueError(
                f"Unsupported controlnet_hint_key {controlnet_hint_key!r}; "
                f"expected one of {self._SUPPORTED_HINTS}"
            )
        # Explicit assignments (instead of vars(self).update(locals())) avoid
        # storing a self.self reference cycle.
        self.data_dir = data_dir
        self.tokenizer = tokenizer
        self.from_hf_hub = from_hf_hub
        self.ucg = ucg
        self.resolution = resolution
        self.prompt_key = prompt_key
        self.cond_key = cond_key
        self.target_key = target_key
        self.controlnet_hint_key = controlnet_hint_key
        self.file_extension = file_extension

        if from_hf_hub:
            self.img_paths = load_dataset(data_dir)["train"]
        else:
            # Glob order is filesystem-dependent; sort so idx -> file is stable.
            self.img_paths = sorted(Path(data_dir).glob(f"*.{file_extension}"))

        self.flip_transform = torchvision.transforms.RandomHorizontalFlip(p=0.5)
        self.transforms = torchvision.transforms.Compose(
            [
                torchvision.transforms.Resize(resolution),
                torchvision.transforms.ToTensor(),
            ]
        )
        # Maps ToTensor's [0, 1] range to [-1, 1].
        self.normalize = torchvision.transforms.Normalize([0.5], [0.5])

    def process_canny(self, image):
        """Return a 3-channel PIL image of Canny edges (thresholds 100/200)."""
        # code from https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/controlnet
        image = np.array(image)
        low_threshold, high_threshold = (100, 200)
        image = cv2.Canny(image, low_threshold, high_threshold)
        # Replicate the single edge channel so the hint has 3 channels.
        image = image[:, :, None]
        image = np.concatenate([image, image, image], axis=2)
        canny_image = Image.fromarray(image)
        return canny_image

    def __len__(self):
        """Return the number of images (with or without captions)."""
        return len(self.img_paths)

    def _load_sample(self, idx):
        """Load item ``idx`` as an RGB image plus a copy of its metadata dict."""
        if self.from_hf_hub:
            raw = self.img_paths[idx]["image"]
            return raw.convert("RGB"), dict(raw.info)
        # Image.open is lazy and keeps the file open; the context manager closes
        # it once convert() has decoded the pixels into a new image.
        with Image.open(self.img_paths[idx]) as raw:
            return raw.convert("RGB"), dict(raw.info)

    def __getitem__(self, idx):
        """Return ``{target_key: image tensor, cond_key: caption[, hint]}``.

        Images lacking ``prompt_key`` are skipped in favour of the next index,
        wrapping around to the start of the dataset.

        Raises:
            IndexError: If ``idx`` itself is out of range.
            RuntimeError: If no image in the dataset has ``prompt_key``.
        """
        n_items = len(self)
        cur = idx
        # Load the caller's idx first so an out-of-range idx still raises
        # IndexError like a list (sequence iteration relies on this).
        image, info = self._load_sample(cur)
        attempts = 1
        while self.prompt_key not in info:
            print(f"Image {cur} lacks {self.prompt_key}, skipping to next image")
            if attempts >= n_items:
                raise RuntimeError(
                    f"No image in the dataset has a '{self.prompt_key}' entry"
                )
            # Parenthesized so the skip actually wraps past the last index.
            cur = (cur + 1) % n_items
            image, info = self._load_sample(cur)
            attempts += 1

        if random.random() < self.ucg:
            tags = ""
        else:
            tags = info[self.prompt_key]

        # randomly flip image here so input image to canny has matching flip
        image = self.flip_transform(image)
        target = self.normalize(self.transforms(image))
        output_dict = {self.target_key: target, self.cond_key: tags}
        if self.controlnet_hint_key == "canny":
            # Hint is resized to tensor but not normalized (stays in [0, 1]).
            canny_image = self.transforms(self.process_canny(image))
            output_dict[self.controlnet_hint_key] = canny_image
        return output_dict

    def collate_fn(self, samples):
        """Tokenize captions and stack image (and hint) tensors into a batch."""
        prompts = torch.tensor(
            [
                self.tokenizer(
                    sample[self.cond_key],
                    padding="max_length",
                    truncation=True,
                ).input_ids
                for sample in samples
            ]
        )
        images = torch.stack(
            [sample[self.target_key] for sample in samples]
        ).contiguous()
        batch = {
            self.cond_key: prompts,
            self.target_key: images,
        }
        if self.controlnet_hint_key is not None:
            hint = torch.stack(
                [sample[self.controlnet_hint_key] for sample in samples]
            ).contiguous()
            batch[self.controlnet_hint_key] = hint
        return batch
class PNGDataModule:
    """Factory for DataLoaders over :class:`PNGDataset` directories.

    Args:
        batch_size: Samples per batch.
        num_workers: Loader worker processes; defaults to ``cpu_count() // 2``.
        persistent_workers: Keep workers alive between epochs. Forced to
            ``False`` when the resolved ``num_workers`` is 0, because
            ``DataLoader`` rejects ``persistent_workers=True`` without workers.
        **kwargs: Forwarded to :class:`PNGDataset` (e.g. ``tokenizer``).
    """

    def __init__(
        self,
        batch_size=1,
        num_workers=None,
        persistent_workers=True,
        **kwargs,  # passed to dataset class
    ):
        super().__init__()
        if num_workers is None:
            # Evaluates to 0 on a single-CPU host.
            num_workers = cpu_count() // 2
        # DataLoader raises ValueError for persistent_workers=True with 0 workers.
        persistent_workers = bool(persistent_workers) and num_workers > 0

        # Store the values actually used, not the raw (possibly None) arguments.
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.persistent_workers = persistent_workers
        self.kwargs = kwargs

        self.ds_wrapper = partial(PNGDataset, **kwargs)
        self.dl_wrapper = partial(
            DataLoader,
            batch_size=batch_size,
            num_workers=num_workers,
            persistent_workers=persistent_workers,
        )

    def get_dataloader(self, data_dir, shuffle=False):
        """Build a PNGDataset for ``data_dir`` and wrap it in a DataLoader."""
        dataset = self.ds_wrapper(data_dir=data_dir)
        # The dataset's own collate_fn tokenizes captions and stacks tensors.
        return self.dl_wrapper(
            dataset, shuffle=shuffle, collate_fn=dataset.collate_fn
        )