# Tzktz's picture
# Upload 7664 files
# 6fc683c verified
"""
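This script converts the original data format to TSV format. Each source jpg file is resized so that its
shorter side is 512 pixels, re-encoded as PNG, and stored as a base64 string.
Each row holds the `input` and `edit` fields of the sample's prompt.json followed by two images per seed, i.e. the
columns are: input, edit, seed0_0, seed0_1, ..., seedN_0, seedN_1
Each TSV file contains at most `--max-image` samples (1000 by default); the last file may contain fewer.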
"""
import argparse
import base64
import io
import json
import os
from multiprocessing import Process
from PIL import Image
from tqdm import tqdm
def _resize_shorter_side(im, target=512):
    """Return a copy of *im* scaled so that its shorter side equals *target* pixels."""
    width, height = im.size
    if width < height:
        new_size = (target, int(target / width * height))
    else:
        new_size = (int(target / height * width), target)
    return im.resize(new_size)


def _encode_png_base64(im):
    """Serialize *im* as PNG in memory and return the bytes as a base64 text string."""
    buffer = io.BytesIO()
    im.save(buffer, format='PNG')
    return base64.b64encode(buffer.getvalue()).decode('utf-8')


def _clean_field(text):
    """Drop tabs and newlines (they would break the TSV layout) and trim the ends."""
    return text.replace('\t', '').replace('\n', '').strip()


def save_tsv(args, i, sub_seeds_list):
    """Write one TSV file for the samples in *sub_seeds_list*.

    For every ``(name, seeds)`` pair, ``<data_dir>/<name>/prompt.json`` is read
    and, for each seed, the two images ``<seed>_0.jpg`` and ``<seed>_1.jpg`` are
    loaded, converted to RGB, resized so the shorter side is 512 pixels and
    base64-encoded as PNG. One row is written per sample:
    ``input <TAB> edit <TAB> image ... <TAB> image``.

    Args:
        args: Parsed command-line namespace; ``data_dir`` and ``output_dir``
            are read from it.
        i: Index of this file; the output is named ``<i zero-padded to 4>.tsv``.
        sub_seeds_list: Iterable of ``(name, seeds)`` pairs, where ``name`` is
            the sample directory under ``data_dir`` and ``seeds`` is an
            iterable of seed identifiers used in the image file names.
    """
    output_path = os.path.join(args.output_dir, f'{str(i).zfill(4)}.tsv')
    # Explicit UTF-8 so prompts with non-ASCII text do not depend on the locale.
    with open(output_path, 'w', encoding='utf-8') as f:
        for name, seeds in tqdm(sub_seeds_list, desc=f'processing {i}th tsv file', leave=False):
            sample_dir = os.path.join(args.data_dir, name)

            # load prompt (the handle is closed as soon as the JSON is parsed)
            with open(os.path.join(sample_dir, 'prompt.json'), encoding='utf-8') as prompt_file:
                prompt = json.load(prompt_file)

            # load, resize and encode the images one at a time so that only a
            # single decoded image is kept in memory
            encoded_images = []
            for seed in seeds:
                for j in range(2):
                    image_path = os.path.join(sample_dir, f'{seed}_{j}.jpg')
                    # convert() returns an independent copy, so the source
                    # file can be closed right away
                    with Image.open(image_path) as raw:
                        image = raw.convert('RGB')
                    encoded_images.append(_encode_png_base64(_resize_shorter_side(image)))

            # write to tsv file
            f.write('\t'.join([
                _clean_field(prompt['input']),
                _clean_field(prompt['edit']),
                *encoded_images,
            ]) + '\n')
if __name__ == '__main__':
    # Script entry point: split the sample list from seeds.json into chunks of
    # --max-image samples and write each chunk to its own TSV file, running at
    # most --num-process worker processes per batch.
    parser = argparse.ArgumentParser()
    parser.add_argument('--data-dir', type=str, default='/path/to/clip-filtered-dataset/')
    parser.add_argument('--output-dir', type=str, default='/path/to/output-dir/')
    parser.add_argument('--max-image', type=int, default=1000)
    parser.add_argument('--num-process', type=int, default=64)
    args = parser.parse_args()

    # A chunk size of 0 would make range() raise; a negative one would
    # silently produce no output at all.
    if args.max_image < 1:
        parser.error('--max-image must be a positive integer')

    # The workers open their output files directly, so the directory must exist.
    os.makedirs(args.output_dir, exist_ok=True)

    # load seeds
    with open(os.path.join(args.data_dir, 'seeds.json')) as seeds_file:
        seeds_list = json.load(seeds_file)

    # split seeds into chunks of --max-image samples, one chunk per tsv file
    seeds_list = [seeds_list[i:i + args.max_image] for i in range(0, len(seeds_list), args.max_image)]

    # save tsv files
    failed = []     # indices of tsv files whose worker exited abnormally
    processes = []  # (tsv index, process) pairs of the current batch
    for i, sub_seeds_list in enumerate(seeds_list):
        p = Process(target=save_tsv, args=(args, i, sub_seeds_list))
        p.start()
        processes.append((i, p))
        # once the batch is full, wait for all of it before starting more
        if len(processes) == args.num_process:
            for index, worker in processes:
                worker.join()
                if worker.exitcode != 0:
                    failed.append(index)
            processes = []

    # wait for the final, possibly partial, batch
    for index, worker in processes:
        worker.join()
        if worker.exitcode != 0:
            failed.append(index)

    if failed:
        # Report rather than abort: the other tsv files are still usable.
        print(f'warning: {len(failed)} tsv file(s) failed and may be incomplete: {failed}')
    print('done.')