# --------------------------------------------------------
# InstructDiffusion
# Based on instruct-pix2pix (https://github.com/timothybrooks/instruct-pix2pix)
# Modified by Binxin Yang ([email protected])
# --------------------------------------------------------
from __future__ import annotations

import os
import random
import copy
import json
import math
from pathlib import Path
from typing import Any

import numpy as np
import torch
import torchvision
from einops import rearrange
from PIL import Image
from torch.utils.data import Dataset

from dataset.seg.refcoco import REFER
class RefCOCODataset(Dataset):
    """RefCOCO referring-expression segmentation cast as an image-editing task.

    Each sample pairs an input image with a target image in which the
    referred region is painted a named color, plus a text instruction,
    in the format expected by InstructDiffusion training.
    """

    def __init__(
        self,
        path: str,
        split: str = "train",
        min_resize_res: int = 256,
        max_resize_res: int = 256,
        crop_res: int = 256,
        flip_prob: float = 0.0,
        transparency: float = 0.0,
        test: bool = False,
    ):
        assert split in ("train", "val", "test")
        self.path = path
        self.min_resize_res = min_resize_res
        self.max_resize_res = max_resize_res
        self.crop_res = crop_res
        self.flip_prob = flip_prob
        self.transparency = transparency
        self.test = test

        self.G_ref_dataset = REFER(data_root=path)
        self.IMAGE_DIR = os.path.join(path, "images/train2014")
        self.list_ref = self.G_ref_dataset.getRefIds(split=split)
        # Instruction templates, one per line, each containing "{color}" and
        # "{object}" placeholders (filled in __getitem__).
        seg_diverse_prompt_path = "dataset/prompt/prompt_seg.txt"
        with open(seg_diverse_prompt_path) as f:
            self.seg_diverse_prompt_list = [line.strip("\n") for line in f]

        # Color list: each usable line has at least four whitespace-separated
        # fields, with the color name first and an "R,G,B" triple fourth.
        color_list_file_path = "dataset/prompt/color_list_train_small.txt"
        self.color_list = []
        with open(color_list_file_path) as f:
            for line in f:
                line_split = line.strip("\n").split(" ")
                if len(line_split) > 1:
                    self.color_list.append(line_split[:4])
    def __len__(self) -> int:
        return len(self.list_ref)
    def _augmentation_new(self, image, label):
        # Random square crop along the longer side, applied identically to
        # image and mask so they stay aligned, followed by a square resize.
        h, w = label.shape
        if h > w:
            start_h = random.randint(0, h - w)
            image = image[start_h : start_h + w]
            label = label[start_h : start_h + w]
        elif h < w:
            start_w = random.randint(0, w - h)
            image = image[:, start_w : start_w + h]
            label = label[:, start_w : start_w + h]

        # LANCZOS for the image, NEAREST for the mask to keep labels discrete.
        image = Image.fromarray(image).resize(
            (self.min_resize_res, self.min_resize_res), resample=Image.Resampling.LANCZOS
        )
        image = np.asarray(image, dtype=np.uint8)
        label = Image.fromarray(label).resize(
            (self.min_resize_res, self.min_resize_res), resample=Image.Resampling.NEAREST
        )
        label = np.asarray(label, dtype=np.int64)
        return image, label
    def __getitem__(self, i: int) -> dict[str, Any]:
        ref_ids = self.list_ref[i]
        ref = self.G_ref_dataset.loadRefs(ref_ids)[0]

        # Build the instruction: a random referring expression for this
        # region, a random prompt template, and a random target color.
        sentences = random.choice(ref["sentences"])["sent"]
        prompt = random.choice(self.seg_diverse_prompt_list)
        color = random.choice(self.color_list)
        color_name = color[0]
        prompt = prompt.format(color=color_name.lower(), object=sentences.lower())
        R, G, B = (int(c) for c in color[3].split(","))

        # Load the image and its referring-expression mask, then crop and
        # resize both together.
        image_name = self.G_ref_dataset.loadImgs(ref["image_id"])[0]["file_name"]
        image_path = os.path.join(self.IMAGE_DIR, image_name)
        mask = self.G_ref_dataset.getMask(ref=ref)["mask"]
        image = Image.open(image_path).convert("RGB")
        image = np.asarray(image)
        image, mask = self._augmentation_new(image, mask)
        mask = mask == 1

        # Target image: blend the masked region toward the chosen color;
        # transparency=0 paints it solid, transparency=1 leaves it unchanged.
        image_0 = Image.fromarray(image)
        image_1 = copy.deepcopy(image)
        image_1[:, :, 0][mask] = self.transparency * image_1[:, :, 0][mask] + (1 - self.transparency) * R
        image_1[:, :, 1][mask] = self.transparency * image_1[:, :, 1][mask] + (1 - self.transparency) * G
        image_1[:, :, 2][mask] = self.transparency * image_1[:, :, 2][mask] + (1 - self.transparency) * B
        image_1 = Image.fromarray(image_1)

        # Resize both images to a random resolution, map to [-1, 1] CHW
        # tensors, then apply the same random crop and horizontal flip to
        # the pair by concatenating along the channel dimension.
        resize_res = torch.randint(self.min_resize_res, self.max_resize_res + 1, ()).item()
        image_0 = image_0.resize((resize_res, resize_res), Image.Resampling.LANCZOS)
        image_1 = image_1.resize((resize_res, resize_res), Image.Resampling.LANCZOS)
        image_0 = rearrange(2 * torch.tensor(np.array(image_0)).float() / 255 - 1, "h w c -> c h w")
        image_1 = rearrange(2 * torch.tensor(np.array(image_1)).float() / 255 - 1, "h w c -> c h w")
        crop = torchvision.transforms.RandomCrop(self.crop_res)
        flip = torchvision.transforms.RandomHorizontalFlip(float(self.flip_prob))
        image_0, image_1 = flip(crop(torch.cat((image_0, image_1)))).chunk(2)
        return dict(edited=image_1, edit=dict(c_concat=image_0, c_crossattn=prompt))
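

# Minimal usage sketch (not part of the original file): iterating the dataset
# with a standard PyTorch DataLoader. The data root below is a placeholder;
# REFER expects the usual RefCOCO layout with images under images/train2014.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = RefCOCODataset(
        path="data/refcoco",  # hypothetical data root
        split="train",
        min_resize_res=256,
        max_resize_res=256,
        crop_res=256,
        flip_prob=0.5,
        transparency=0.0,
    )
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
    batch = next(iter(loader))
    # "edited": target image with the referred region recolored, [-1, 1] CHW.
    # "edit.c_concat": source image; "edit.c_crossattn": instruction strings.
    print(batch["edited"].shape, batch["edit"]["c_concat"].shape)
    print(batch["edit"]["c_crossattn"][0])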