import torch
import json
import os
import random

from omegaconf import OmegaConf
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image

from .sampler_utils import get_packed_batch_sampler


def resize_arr(pil_image, height, width):
    # PIL expects (width, height); bicubic resampling keeps resized images sharp.
    return pil_image.resize((width, height), resample=Image.Resampling.BICUBIC)


class T2IDatasetMS(Dataset):
    """Multi-scale text-to-image dataset driven by per-line JSON metadata.

    Each JSONL record carries the keys this class reads: 'image_file'
    (path relative to root_dir), 'bucket' (a name ending in "<H>x<W>"),
    and 'captions' (a list of strings).
    """

    def __init__(self, root_dir, packed_json, jsonl_dir) -> None:
        super().__init__()
        self.root_dir = root_dir
        # Pre-packed batch layout, consumed by get_packed_batch_sampler.
        with open(packed_json, 'r') as fp:
            self.packed_dataset = json.load(fp)
        # One JSON metadata record per line.
        with open(jsonl_dir, 'r') as fp:
            self.dataset = [json.loads(line) for line in fp]

    def __len__(self):
        return len(self.dataset)

    def get_one_data(self, data_meta):
        data_item = dict()
        image_file = os.path.join(self.root_dir, data_meta['image_file'])
        image = Image.open(image_file).convert("RGB")
        # The bucket name ends in "<height>x<width>"; snap both sides down
        # to the nearest multiple of 32.
        resolutions = data_meta['bucket'].split('-')[-1].split('x')
        height = (int(resolutions[0]) // 32) * 32
        width = (int(resolutions[1]) // 32) * 32
        transform = transforms.Compose([
            transforms.Lambda(lambda pil_image: resize_arr(pil_image, height, width)),
            transforms.ToTensor(),
            # Map [0, 1] to [-1, 1] per channel.
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ])
        data_item['image'] = transform(image)
        # Pick one caption at random; escape non-ASCII as literal sequences.
        data_item['caption'] = random.choice(data_meta['captions']).encode('unicode-escape').decode('utf-8')
        return data_item

    def __getitem__(self, index):
        data_meta = self.dataset[index]
        try:
            data_item = self.get_one_data(data_meta)
        except Exception as e:
            # Missing or corrupt images are skipped; the collate_fn below
            # filters out the resulting None entries.
            print(f"Warning: failed to load {data_meta['image_file']}: {e}", flush=True)
            data_item = None
        return data_item
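

# Example of the on-disk metadata this dataset assumes (illustrative only;
# the bucket-name prefix and the packed_json layout are assumptions, not
# fixed by this file):
#
#   metadata.jsonl -- one record per line:
#     {"image_file": "shard0/000001.jpg",
#      "bucket": "aspect-512x384",
#      "captions": ["a photo of ...", "an alternate caption ..."]}
#
#   packed_json -- loaded verbatim into self.packed_dataset and handed to
#   get_packed_batch_sampler, so its structure is whatever that sampler expects.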


def bucket_collate_fn(batch):
    # Drop items that failed to load; everything in a packed batch shares one
    # bucket resolution, so the surviving images stack into a single tensor.
    caption = []
    image = []
    for data in batch:
        if data is None:
            continue
        caption.append(data['caption'])
        image.append(data['image'])
    if not image:
        raise RuntimeError("bucket_collate_fn: every sample in the batch failed to load")
    image = torch.stack(image)
    return dict(image=image, caption=caption)


class T2ILoader:
    def __init__(self, data_config):
        super().__init__()
        self.batch_size = data_config.dataloader.batch_size
        self.num_workers = data_config.dataloader.num_workers
        self.data_type = data_config.data_type
        if self.data_type == 'image_ms':
            self.train_dataset = T2IDatasetMS(**OmegaConf.to_container(data_config.dataset))
        else:
            raise NotImplementedError(f"Unsupported data_type: {self.data_type}")
        self.test_dataset = None
        self.val_dataset = None

    def train_len(self):
        return len(self.train_dataset)

    def train_dataloader(self, rank, world_size, global_batch_size, max_steps, resume_steps, seed):
        # The packed batch sampler yields one pre-built list of dataset indices
        # per step, so the effective batch size comes from the packing;
        # global_batch_size is accepted for interface parity but unused here.
        batch_sampler = get_packed_batch_sampler(
            self.train_dataset.packed_dataset, rank, world_size, max_steps, resume_steps, seed
        )
        return DataLoader(
            self.train_dataset,
            batch_sampler=batch_sampler,
            collate_fn=bucket_collate_fn,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def test_dataloader(self):
        return None

    def val_dataloader(self):
        return None
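

# Minimal usage sketch (not part of the original module). The config keys
# mirror what T2ILoader reads above; all paths and values are hypothetical,
# and running it requires the real data files plus sampler_utils.
if __name__ == "__main__":
    cfg = OmegaConf.create({
        'data_type': 'image_ms',
        'dataloader': {'batch_size': 8, 'num_workers': 4},
        'dataset': {
            'root_dir': '/data/images',                   # hypothetical path
            'packed_json': '/data/packed_batches.json',   # hypothetical path
            'jsonl_dir': '/data/metadata.jsonl',          # hypothetical path
        },
    })
    loader = T2ILoader(cfg)
    train_loader = loader.train_dataloader(
        rank=0, world_size=1, global_batch_size=8,
        max_steps=1000, resume_steps=0, seed=0,
    )
    batch = next(iter(train_loader))
    print(batch['image'].shape, len(batch['caption']))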