# NOTE(review): removed non-Python residue ("Spaces: Running on Zero", repeated)
# left over from a web-page copy/paste; it was not part of the module source.
from typing import Optional
import os
import random
import yaml
import glob
from PIL import Image
import torch
from datasets import load_dataset, concatenate_datasets
from ..pipelines.omnigen2.pipeline_omnigen2 import OmniGen2ImageProcessor
class OmniGen2TestDataset(torch.utils.data.Dataset):
    """Test-time dataset for OmniGen2.

    Reads a YAML config listing annotation files (.json/.jsonl, directories of
    them, or nested .yml/.yaml configs), concatenates them into one HF dataset,
    and yields per-item dicts with the instruction, optional input images, and
    a target image size snapped to a multiple of ``img_scale_num``.
    """

    SYSTEM_PROMPT = "You are a helpful assistant that generates high-quality images based on user instructions."

    def __init__(
        self,
        config_path: str,
        tokenizer,
        use_chat_template: bool,
        max_pixels: Optional[int] = None,
        max_side_length: Optional[int] = None,
        img_scale_num: int = 16,
        align_res: bool = True,
    ):
        """
        Args:
            config_path: Path to a YAML config with a top-level ``data`` list;
                each entry has ``path`` and optional ``type``.
            tokenizer: Tokenizer used for ``apply_chat_template``.
            use_chat_template: If True, wrap instructions in a chat template.
            max_pixels: Optional cap on total output pixels; the target size is
                scaled down (never up) to respect it.
            max_side_length: Optional cap on the longer output side; also only
                ever scales down.
            img_scale_num: Target sizes are floored to multiples of this
                (typically the VAE scale factor).
            align_res: If True and exactly one input image is given, use that
                image's resolution as the target resolution.
        """
        self.max_pixels = max_pixels
        self.max_side_length = max_side_length
        self.img_scale_num = img_scale_num
        self.align_res = align_res

        # NOTE(review): yaml.load with FullLoader can construct arbitrary
        # Python objects; prefer yaml.safe_load if configs may be untrusted.
        with open(config_path, "r") as f:
            self.config = yaml.load(f, Loader=yaml.FullLoader)

        self.use_chat_template = use_chat_template
        self.image_processor = OmniGen2ImageProcessor(vae_scale_factor=img_scale_num, do_resize=True)

        self.data = self._collect_annotations(self.config)
        self.tokenizer = tokenizer

    def _collect_annotations(self, config):
        """Load every annotation source in ``config['data']`` and concatenate.

        Directories are globbed recursively for ``*.jsonl`` and ``*.json``;
        nested YAML configs are resolved recursively.

        Returns:
            A single concatenated ``datasets.Dataset`` (the ``train`` split).

        Raises:
            NotImplementedError: For unsupported file extensions.
        """
        json_datasets = []
        for data in config['data']:
            data_path, data_type = data['path'], data.get("type", "default")
            if os.path.isdir(data_path):
                jsonl_files = (
                    list(glob.glob(os.path.join(data_path, "**/*.jsonl"), recursive=True))
                    + list(glob.glob(os.path.join(data_path, "**/*.json"), recursive=True))
                )
                json_dataset = load_dataset('json', data_files=jsonl_files, cache_dir=None)['train']
            else:
                data_ext = os.path.splitext(data_path)[-1]
                if data_ext in [".json", ".jsonl"]:
                    json_dataset = load_dataset('json', data_files=data_path, cache_dir=None)['train']
                elif data_ext in [".yml", ".yaml"]:
                    # Nested config: recurse and merge its datasets.
                    with open(data_path, "r") as f:
                        sub_config = yaml.load(f, Loader=yaml.FullLoader)
                    json_dataset = self._collect_annotations(sub_config)
                else:
                    raise NotImplementedError(
                        f'Unknown data file extension: "{data_ext}". '
                        f"Currently, .json, .jsonl .yml .yaml are supported. "
                        "If you are using a supported format, please set the file extension so that the proper parsing "
                        "routine can be called."
                    )
            json_datasets.append(json_dataset)
        return concatenate_datasets(json_datasets)

    def apply_chat_template(self, instruction, system_prompt):
        """Wrap ``instruction`` in the tokenizer's chat template when enabled.

        Returns the instruction unchanged when ``use_chat_template`` is False.
        """
        if self.use_chat_template:
            prompt = [
                {
                    "role": "system",
                    "content": system_prompt,
                },
                {"role": "user", "content": instruction},
            ]
            instruction = self.tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=False)
        return instruction

    def process_item(self, data_item):
        """Convert one raw annotation record into a model-ready dict.

        Opens any referenced input images, resolves the target resolution
        (optionally aligned to a single input image), downscales it to respect
        ``max_pixels`` / ``max_side_length``, and floors both sides to
        multiples of ``img_scale_num``.
        """
        assert data_item['instruction'] is not None

        if 'input_images' in data_item and data_item['input_images'] is not None:
            input_images_path = data_item['input_images']
            input_images = [Image.open(p).convert("RGB") for p in input_images_path]
        else:
            input_images_path, input_images = None, None

        # With exactly one input image and align_res on, mirror its resolution.
        if input_images is not None and len(input_images) == 1 and self.align_res:
            target_img_size = (input_images[0].width, input_images[0].height)
        else:
            target_img_size = data_item["target_img_size"]

        w, h = target_img_size
        # BUGFIX: the original unconditionally divided by max_pixels, raising
        # TypeError when max_pixels was left at its documented default (None),
        # and silently ignored max_side_length. Both limits are now optional
        # and only ever shrink the target (ratio capped at 1).
        ratio = 1.0
        if self.max_pixels is not None:
            ratio = min(ratio, (self.max_pixels / (w * h)) ** 0.5)
        if self.max_side_length is not None:
            ratio = min(ratio, self.max_side_length / max(w, h))
        target_img_size = (
            int(w * ratio) // self.img_scale_num * self.img_scale_num,
            int(h * ratio) // self.img_scale_num * self.img_scale_num,
        )

        return {
            'task_type': data_item['task_type'],
            'instruction': data_item['instruction'],
            'input_images_path': input_images_path,
            'input_images': input_images,
            'target_img_size': target_img_size,
        }

    def __getitem__(self, index):
        data_item = self.data[index]
        return self.process_item(data_item)

    def __len__(self):
        return len(self.data)
class OmniGen2Collator():
    """Batch collator: tokenizes instructions and regroups per-item fields
    into per-batch lists.

    Intended to pair with ``OmniGen2TestDataset``; optional fields are read
    defensively since that dataset does not emit output-image keys.
    """

    def __init__(self, tokenizer, max_token_len):
        """
        Args:
            tokenizer: Callable HF-style tokenizer producing ``input_ids`` and
                ``attention_mask``.
            max_token_len: Truncation length for the tokenized instructions.
        """
        self.tokenizer = tokenizer
        self.max_token_len = max_token_len

    def __call__(self, batch):
        """Collate a list of item dicts into one batch dict.

        Returns a dict with tokenized text (``text_ids``/``text_mask``) and
        per-item lists for the remaining fields.
        """
        task_type = [data['task_type'] for data in batch]
        instruction = [data['instruction'] for data in batch]
        input_images_path = [data['input_images_path'] for data in batch]
        input_images = [data['input_images'] for data in batch]
        # BUGFIX: these keys are not produced by OmniGen2TestDataset.process_item,
        # so direct indexing raised KeyError; default them to None instead.
        output_image = [data.get('output_image') for data in batch]
        output_image_path = [data.get('output_image_path') for data in batch]
        # Pass through the resolved target size (previously dropped).
        target_img_size = [data.get('target_img_size') for data in batch]

        text_inputs = self.tokenizer(
            instruction,
            padding="longest",
            max_length=self.max_token_len,
            truncation=True,
            return_tensors="pt",
        )

        return {
            "task_type": task_type,
            "text_ids": text_inputs.input_ids,
            "text_mask": text_inputs.attention_mask,
            "input_images": input_images,
            "input_images_path": input_images_path,
            "output_image": output_image,
            "output_image_path": output_image_path,
            "target_img_size": target_img_size,
        }