Spaces:
Sleeping
Sleeping
""" | |
Convert the original dataset layout into TSV shards. The source jpg files are resized so that
their shorter side is 512 pixels, re-encoded as PNG, and stored as base64 strings.
Each row contains the columns: input, edit, seed0_0, seed0_1, ..., seedN_0, seedN_1.
Each TSV file holds up to 1000 samples by default (see --max-image).
""" | |
import argparse | |
import base64 | |
import io | |
import json | |
import os | |
from multiprocessing import Process | |
from PIL import Image | |
from tqdm import tqdm | |
def _resize_shorter_side_to_512(im):
    """Return a copy of ``im`` scaled so that its shorter side is 512 pixels."""
    width, height = im.size
    if width < height:
        new_size = (512, int(512 / width * height))
    else:
        new_size = (int(512 / height * width), 512)
    return im.resize(new_size)


def _encode_png_base64(im):
    """Serialize ``im`` as PNG and return the bytes as a base64 text string."""
    buffer = io.BytesIO()
    im.save(buffer, format='PNG')
    return base64.b64encode(buffer.getvalue()).decode('utf-8')


def _clean_field(text):
    """Strip characters that would break the one-row-per-line TSV layout."""
    # '\r' is removed as well: readers using universal newlines treat a bare
    # carriage return as a line break, which would split the row.
    return text.replace('\t', '').replace('\n', '').replace('\r', '').strip()


def save_tsv(args, i, sub_seeds_list):
    """Write one TSV shard containing the samples in ``sub_seeds_list``.

    Each output row is ``input<TAB>edit<TAB>img...`` where the images are the
    ``{seed}_0.jpg`` / ``{seed}_1.jpg`` files of every seed of the sample,
    resized (shorter side 512), re-encoded as PNG and base64-encoded.

    Args:
        args: Parsed command-line namespace; ``data_dir`` and ``output_dir``
            are read.
        i: Shard index; the output file is named ``{i:04d}.tsv`` inside
            ``args.output_dir``.
        sub_seeds_list: Iterable of ``(name, seeds)`` pairs, where ``name`` is
            the sample sub-directory under ``args.data_dir`` and ``seeds`` is
            the list of seeds whose image pairs are exported.
    """
    shard_path = os.path.join(args.output_dir, f'{str(i).zfill(4)}.tsv')
    # Explicit UTF-8 so prompts with non-ASCII characters do not depend on the
    # locale's default encoding.
    with open(shard_path, 'w', encoding='utf-8') as f:
        for name, seeds in tqdm(sub_seeds_list, desc=f'processing {i}th tsv file', leave=False):
            sample_dir = os.path.join(args.data_dir, name)

            # load prompt (context manager closes the handle deterministically)
            with open(os.path.join(sample_dir, 'prompt.json'), encoding='utf-8') as prompt_file:
                prompt = json.load(prompt_file)

            # load, resize and encode the images one at a time so that only a
            # single decoded image is held in memory at once
            encoded_images = []
            for seed in seeds:
                for j in range(2):
                    image_path = os.path.join(sample_dir, f'{seed}_{j}.jpg')
                    with Image.open(image_path) as raw:
                        rgb = raw.convert('RGB')
                    resized = _resize_shorter_side_to_512(rgb)
                    encoded_images.append(_encode_png_base64(resized))

            # write to tsv file
            row = [
                _clean_field(prompt['input']),
                _clean_field(prompt['edit']),
                *encoded_images,
            ]
            f.write('\t'.join(row) + '\n')
if __name__ == '__main__':
    # Command-line entry point: split the seed list into shards and convert
    # each shard to a TSV file in its own worker process.
    parser = argparse.ArgumentParser()
    parser.add_argument('--data-dir', type=str, default='/path/to/clip-filtered-dataset/')
    parser.add_argument('--output-dir', type=str, default='/path/to/output-dir/')
    parser.add_argument('--max-image', type=int, default=1000)
    parser.add_argument('--num-process', type=int, default=64)
    args = parser.parse_args()

    # A zero chunk size would make range() raise an opaque ValueError below,
    # and a negative one would silently produce no output.
    if args.max_image < 1:
        parser.error('--max-image must be a positive integer')

    # Workers open their shard files directly, so the directory must exist.
    os.makedirs(args.output_dir, exist_ok=True)

    # load seeds
    with open(os.path.join(args.data_dir, 'seeds.json')) as seeds_file:
        seeds_list = json.load(seeds_file)

    # split seeds into chunks of at most --max-image samples per tsv file
    seeds_list = [seeds_list[i:i + args.max_image] for i in range(0, len(seeds_list), args.max_image)]

    # save tsv files, running at most --num-process workers per batch
    failed_shards = []
    processes = []

    def _join_batch(batch):
        """Wait for every worker in ``batch`` and record shards that failed."""
        for shard_index, proc in batch:
            proc.join()
            if proc.exitcode != 0:
                failed_shards.append(shard_index)

    for i, sub_seeds_list in enumerate(seeds_list):
        p = Process(target=save_tsv, args=(args, i, sub_seeds_list))
        p.start()
        processes.append((i, p))
        if len(processes) == args.num_process:
            _join_batch(processes)
            processes = []
    _join_batch(processes)

    # Surface worker crashes instead of reporting success unconditionally.
    if failed_shards:
        raise SystemExit(f'failed to write tsv shards: {failed_shards}')
    print('done.')