# UMO_OmniGen2: omnigen2/dataset/omnigen2_test_dataset.py
# Author: cb1cyf — commit cf4796c ("fix: omnigen2")
from typing import Optional
import os
import random
import yaml
import glob
from PIL import Image
import torch
from datasets import load_dataset, concatenate_datasets
from ..pipelines.omnigen2.pipeline_omnigen2 import OmniGen2ImageProcessor
class OmniGen2TestDataset(torch.utils.data.Dataset):
    """Evaluation dataset yielding OmniGen2 generation requests.

    Annotation records are read from the sources listed under the ``data``
    key of a YAML config. Each item exposes the task type, the instruction,
    optional input images (loaded as RGB), and a target output size that
    respects the optional pixel / side-length limits and is floored to a
    multiple of ``img_scale_num``.
    """

    SYSTEM_PROMPT = "You are a helpful assistant that generates high-quality images based on user instructions."

    def __init__(
        self,
        config_path: str,
        tokenizer,
        use_chat_template: bool,
        max_pixels: Optional[int] = None,
        max_side_length: Optional[int] = None,
        img_scale_num: int = 16,
        align_res: bool = True
    ):
        """
        Args:
            config_path: Path to a YAML config whose ``data`` entry is a list
                of sources, each a mapping with a ``path`` key.
            tokenizer: Tokenizer used by :meth:`apply_chat_template`.
            use_chat_template: Whether :meth:`apply_chat_template` wraps the
                instruction in a system/user chat prompt.
            max_pixels: Optional upper bound on ``width * height`` of the
                target size; ``None`` disables this bound.
            max_side_length: Optional upper bound on the longer side of the
                target size; ``None`` disables this bound.
            img_scale_num: Target width/height are floored to a multiple of
                this value (also passed as the processor's VAE scale factor).
            align_res: When an item has exactly one input image, use that
                image's resolution as the target size.
        """
        self.max_pixels = max_pixels
        self.max_side_length = max_side_length
        self.img_scale_num = img_scale_num
        self.align_res = align_res

        # NOTE(review): FullLoader is fine for trusted local configs; prefer
        # yaml.safe_load if configs may come from untrusted sources.
        with open(config_path, "r", encoding="utf-8") as f:
            self.config = yaml.load(f, Loader=yaml.FullLoader)

        self.use_chat_template = use_chat_template
        # NOTE(review): not used by the methods of this class; presumably
        # accessed by callers — confirm before removing.
        self.image_processor = OmniGen2ImageProcessor(vae_scale_factor=img_scale_num, do_resize=True)

        self.data = self._collect_annotations(self.config)
        self.tokenizer = tokenizer

    def _collect_annotations(self, config):
        """Load and concatenate every annotation source in ``config['data']``.

        Each entry's ``path`` may be:
          * a directory — all ``*.jsonl`` then all ``*.json`` files beneath it
            (recursively) are loaded;
          * a ``.json`` / ``.jsonl`` file — loaded directly;
          * a ``.yml`` / ``.yaml`` file — a nested config, resolved recursively.

        Returns:
            One ``datasets.Dataset`` concatenating all sources in config order.

        Raises:
            FileNotFoundError: A directory entry contains no JSON/JSONL files.
            NotImplementedError: A file entry has an unsupported extension.
        """
        json_datasets = []
        for data in config['data']:
            data_path = data['path']
            if os.path.isdir(data_path):
                # Sort so sample order (and therefore item indices) does not
                # depend on the filesystem's directory listing order.
                jsonl_files = sorted(glob.glob(os.path.join(data_path, "**/*.jsonl"), recursive=True))
                jsonl_files += sorted(glob.glob(os.path.join(data_path, "**/*.json"), recursive=True))
                if not jsonl_files:
                    raise FileNotFoundError(
                        f'No .json or .jsonl files found under directory "{data_path}".'
                    )
                json_dataset = load_dataset('json', data_files=jsonl_files, cache_dir=None)['train']
            else:
                data_ext = os.path.splitext(data_path)[-1]
                if data_ext in [".json", ".jsonl"]:
                    json_dataset = load_dataset('json', data_files=data_path, cache_dir=None)['train']
                elif data_ext in [".yml", ".yaml"]:
                    with open(data_path, "r", encoding="utf-8") as f:
                        sub_config = yaml.load(f, Loader=yaml.FullLoader)
                    json_dataset = self._collect_annotations(sub_config)
                else:
                    raise NotImplementedError(
                        f'Unknown data file extension: "{data_ext}". '
                        f"Currently, .json, .jsonl .yml .yaml are supported. "
                        "If you are using a supported format, please set the file extension so that the proper parsing "
                        "routine can be called."
                    )
            json_datasets.append(json_dataset)
        return concatenate_datasets(json_datasets)

    def apply_chat_template(self, instruction, system_prompt):
        """Return ``instruction`` wrapped as a system/user chat prompt.

        When ``use_chat_template`` is set, the tokenizer's
        ``apply_chat_template`` renders the two messages as text (no
        tokenization, no generation prompt); otherwise ``instruction`` is
        returned unchanged.
        """
        if self.use_chat_template:
            prompt = [
                {
                    "role": "system",
                    "content": system_prompt,
                },
                {"role": "user", "content": instruction},
            ]
            instruction = self.tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=False)
        return instruction

    def _load_input_images(self, input_images_path):
        """Open each path as an RGB image, releasing the file handle."""
        input_images = []
        for input_image_path in input_images_path:
            # Image.open is lazy and holds the file open; convert() returns a
            # fully loaded copy, so the original can be closed right away.
            with Image.open(input_image_path) as img:
                input_images.append(img.convert("RGB"))
        return input_images

    def _compute_target_size(self, w, h):
        """Scale ``(w, h)`` down to satisfy the optional limits, then floor
        both sides to a multiple of ``img_scale_num`` (at least one multiple).
        Images are never upscaled."""
        ratio = 1.0
        if self.max_pixels is not None:
            ratio = min(ratio, (self.max_pixels / (w * h)) ** 0.5)
        if self.max_side_length is not None:
            ratio = min(ratio, self.max_side_length / max(w, h))
        step = self.img_scale_num
        # Clamp to one step so tiny inputs never yield a zero-sized target.
        return (
            max(step, int(w * ratio) // step * step),
            max(step, int(h * ratio) // step * step),
        )

    def process_item(self, data_item):
        """Convert one annotation record into a generation request.

        Args:
            data_item: Mapping with ``task_type``, ``instruction``, optionally
                ``input_images`` (list of image paths), and ``target_img_size``
                (``(width, height)``, used unless the size is taken from a
                single input image — see ``align_res``).

        Returns:
            Dict with ``task_type``, ``instruction``, ``input_images_path``,
            ``input_images`` (list of RGB PIL images or ``None``) and
            ``target_img_size`` as a ``(width, height)`` tuple.

        Raises:
            ValueError: ``instruction`` is ``None``.
        """
        if data_item['instruction'] is None:
            raise ValueError("data_item['instruction'] must not be None")

        input_images_path = data_item.get('input_images')
        input_images = None if input_images_path is None else self._load_input_images(input_images_path)

        if input_images is not None and len(input_images) == 1 and self.align_res:
            w, h = input_images[0].width, input_images[0].height
        else:
            w, h = data_item["target_img_size"]

        return {
            'task_type': data_item['task_type'],
            'instruction': data_item['instruction'],
            'input_images_path': input_images_path,
            'input_images': input_images,
            'target_img_size': self._compute_target_size(w, h),
        }

    def __getitem__(self, index):
        """Return the processed request for the record at ``index``."""
        return self.process_item(self.data[index])

    def __len__(self):
        """Return the number of annotation records."""
        return len(self.data)
class OmniGen2Collator:
    """Collate OmniGen2 dataset items into a batch with tokenized instructions."""

    def __init__(self, tokenizer, max_token_len):
        """
        Args:
            tokenizer: Callable tokenizer (presumably a Hugging Face tokenizer)
                that accepts ``padding``/``max_length``/``truncation``/
                ``return_tensors`` keyword arguments and returns an object
                with ``input_ids`` and ``attention_mask``.
            max_token_len: Maximum token length; longer instructions are
                truncated.
        """
        self.tokenizer = tokenizer
        self.max_token_len = max_token_len

    def __call__(self, batch):
        """Collate a list of item dicts.

        Returns:
            Dict with per-item lists ``task_type``, ``input_images``,
            ``input_images_path``, ``output_image`` and ``output_image_path``,
            plus ``text_ids`` / ``text_mask`` from tokenizing the instructions
            padded to the longest one in the batch.

        ``output_image`` / ``output_image_path`` are ``None`` for items that
        do not carry them (items built by ``OmniGen2TestDataset.process_item``
        have no output image), instead of raising ``KeyError``.
        """
        task_type = [data['task_type'] for data in batch]
        instruction = [data['instruction'] for data in batch]
        input_images_path = [data['input_images_path'] for data in batch]
        input_images = [data['input_images'] for data in batch]
        output_image = [data.get('output_image') for data in batch]
        output_image_path = [data.get('output_image_path') for data in batch]

        text_inputs = self.tokenizer(
            instruction,
            padding="longest",
            max_length=self.max_token_len,
            truncation=True,
            return_tensors="pt",
        )

        return {
            "task_type": task_type,
            "text_ids": text_inputs.input_ids,
            "text_mask": text_inputs.attention_mask,
            "input_images": input_images,
            "input_images_path": input_images_path,
            "output_image": output_image,
            "output_image_path": output_image_path,
        }