# TiM / tim / data / t2i_data.py
# Julien Blanchon
# Clean Space repo (code only, checkpoints in model repo)
# d0e893e
import torch
import csv
import json
import os
import random
import ast
import numpy as np
from omegaconf import OmegaConf
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from tqdm import tqdm
from safetensors.torch import save_file, load_file
from .sampler_utils import get_train_sampler, get_packed_batch_sampler
def resize_arr(pil_image, height, width):
    """Return ``pil_image`` resized to the given size with bicubic resampling.

    Args:
        pil_image: Object exposing a PIL-style ``resize`` method.
        height: Target height in pixels.
        width: Target width in pixels.

    Returns:
        The value returned by ``pil_image.resize``.
    """
    # ``resize`` takes its size as (width, height), the reverse of this
    # function's own argument order.
    target_size = (width, height)
    return pil_image.resize(target_size, resample=Image.Resampling.BICUBIC)
class T2IDatasetMS(Dataset):
    """Multi-resolution text-to-image dataset driven by a JSON-lines metadata file.

    Each metadata record is consumed by ``get_one_data`` and must provide:

    * ``image_file``: image path, joined onto ``root_dir``;
    * ``bucket``: string whose last ``-``-separated field is
      ``"<height>x<width>"``;
    * ``captions``: non-empty sequence of caption strings.
    """

    def __init__(self, root_dir, packed_json, jsonl_dir) -> None:
        """Load the packed index and the per-sample metadata.

        Args:
            root_dir: Directory that ``image_file`` entries are joined onto.
            packed_json: Path to a JSON file; its parsed content is stored
                unchanged in ``self.packed_dataset``.
            jsonl_dir: Path to a JSON-lines *file* (despite the name), one
                metadata record per line. Blank lines are ignored.
        """
        super().__init__()
        self.root_dir = root_dir
        with open(packed_json, 'r', encoding='utf-8') as fp:
            self.packed_dataset = json.load(fp)
        with open(jsonl_dir, 'r', encoding='utf-8') as fp:
            # Skip whitespace-only lines (e.g. a trailing newline at EOF),
            # which ``json.loads`` would otherwise reject.
            self.dataset = [json.loads(line) for line in fp if line.strip()]

    def __len__(self):
        """Return the number of metadata records."""
        return len(self.dataset)

    def get_one_data(self, data_meta):
        """Load and preprocess the sample described by ``data_meta``.

        Args:
            data_meta: One metadata record (see the class docstring).

        Returns:
            dict with ``image`` (output of ``ToTensor`` followed by
            ``Normalize`` with mean 0.5 and std 0.5 per channel) and
            ``caption`` (one randomly chosen caption).

        Raises:
            Whatever the underlying file, parsing or lookup operations raise
            (e.g. ``OSError`` for an unreadable image, ``KeyError`` for a
            missing metadata field).
        """
        image_file = os.path.join(self.root_dir, data_meta['image_file'])
        # ``convert`` returns a new, fully loaded image, so the source file
        # handle can be released right away instead of waiting for GC.
        with Image.open(image_file) as raw_image:
            image = raw_image.convert("RGB")

        # The bucket name ends in "<height>x<width>"; both sides are rounded
        # down to a multiple of 32.
        resolutions = data_meta['bucket'].split('-')[-1].split('x')
        height = int(resolutions[0]) // 32 * 32
        width = int(resolutions[1]) // 32 * 32

        transform = transforms.Compose([
            transforms.Lambda(lambda pil_image: resize_arr(pil_image, height, width)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ])

        # The unicode-escape round trip turns non-ASCII characters (and
        # backslashes/control characters) into ASCII escape sequences.
        caption = random.choice(data_meta['captions']).encode('unicode-escape').decode('utf-8')
        return {'image': transform(image), 'caption': caption}

    def __getitem__(self, index):
        """Return the processed sample at ``index``, or ``None`` if loading fails.

        Failures are reported on stdout and mapped to ``None`` so that one
        bad sample does not abort iteration; ``bucket_collate_fn`` in this
        module drops ``None`` entries.
        """
        data_meta = self.dataset[index]
        try:
            return self.get_one_data(data_meta)
        except Exception as err:
            # ``Exception`` (not a bare ``except``) lets KeyboardInterrupt and
            # SystemExit propagate. The real cause is reported because the
            # failure is not necessarily a missing file.
            if isinstance(data_meta, dict):
                image_file = data_meta.get('image_file', '<unknown>')
            else:
                image_file = '<unknown>'
            print(f"Warning: failed to load {image_file}: {err!r}", flush=True)
            return None
def bucket_collate_fn(batch):
    """Collate samples into one batch, dropping entries that are ``None``.

    Args:
        batch: Iterable of sample dicts with ``image`` and ``caption`` keys.
            ``None`` entries are skipped.

    Returns:
        dict with ``image`` (the remaining images combined with
        ``torch.stack``) and ``caption`` (list of the remaining captions, in
        batch order).

    Note:
        If every entry is ``None`` the image list is empty and
        ``torch.stack`` raises.
    """
    # Identity check: ``== None`` would dispatch to the sample's ``__eq__``.
    valid = [data for data in batch if data is not None]
    caption = [data['caption'] for data in valid]
    image = torch.stack([data['image'] for data in valid])
    return dict(image=image, caption=caption)
class T2ILoader():
    """Builds the training ``DataLoader`` for the text-to-image dataset.

    Only ``data_type == 'image_ms'`` is supported. No test or validation
    datasets are created.
    """

    def __init__(self, data_config):
        """Create the training dataset from ``data_config``.

        Args:
            data_config: Config object exposing ``dataloader.batch_size``,
                ``dataloader.num_workers``, ``data_type`` and ``dataset``.
                ``dataset`` is converted with ``OmegaConf.to_container`` and
                passed as keyword arguments to ``T2IDatasetMS``.

        Raises:
            NotImplementedError: If ``data_type`` is not ``'image_ms'``.
        """
        super().__init__()
        self.batch_size = data_config.dataloader.batch_size
        self.num_workers = data_config.dataloader.num_workers
        self.data_type = data_config.data_type
        if self.data_type != 'image_ms':
            # NotImplementedError subclasses RuntimeError, so handlers written
            # for the previous bare ``raise`` (which produced a RuntimeError)
            # still match, but the message now names the offending value.
            raise NotImplementedError(
                f"Unsupported data_type: {self.data_type!r} (expected 'image_ms')"
            )
        self.train_dataset = T2IDatasetMS(**OmegaConf.to_container(data_config.dataset))
        self.test_dataset = None
        self.val_dataset = None

    def train_len(self):
        """Return the number of samples in the training dataset."""
        return len(self.train_dataset)

    def train_dataloader(self, rank, world_size, global_batch_size, max_steps, resume_steps, seed):
        """Return the training ``DataLoader``.

        Batches come from ``get_packed_batch_sampler``, which receives the
        dataset's ``packed_dataset`` together with ``rank``, ``world_size``,
        ``max_steps``, ``resume_steps`` and ``seed``. ``global_batch_size`` is
        accepted but not used here, and ``self.batch_size`` is not passed to
        the ``DataLoader`` because a batch sampler is supplied.
        """
        batch_sampler = get_packed_batch_sampler(
            self.train_dataset.packed_dataset, rank, world_size, max_steps, resume_steps, seed
        )
        return DataLoader(
            self.train_dataset,
            batch_sampler=batch_sampler,
            collate_fn=bucket_collate_fn,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def test_dataloader(self):
        """Return ``None``; no test split is provided."""
        return None

    def val_dataloader(self):
        """Return ``None``; no validation split is provided."""
        return None