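# Quick visual test for the AiToolkit dataloader: walks a dataset for one or
# more epochs and either displays each batch on screen or saves it as numbered
# PNGs. Example invocation (script name and paths are illustrative):
#   python test_dataloader.py /path/to/images --epochs 2 --output_path ./preview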
import time

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import sys
import os
import cv2
import random
from transformers import CLIPImageProcessor
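
# make the repo root importable so the `toolkit` package resolves when this
# script lives one level below it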
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torchvision.transforms.functional
from toolkit.image_utils import save_tensors, show_img, show_tensors
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.data_loader import AiToolkitDataset, get_dataloader_from_datasets, \
    trigger_dataloader_setup_epoch
from toolkit.config_modules import DatasetConfig
import argparse
from tqdm import tqdm

parser = argparse.ArgumentParser()
# nargs='?' lets the positional default actually apply when the argument is omitted
parser.add_argument('dataset_folder', type=str, nargs='?', default='input')
parser.add_argument('--epochs', type=int, default=1)
parser.add_argument('--num_frames', type=int, default=1)
parser.add_argument('--output_path', type=str, default=None)
args = parser.parse_args()

if args.output_path is not None:
    args.output_path = os.path.abspath(args.output_path)
    os.makedirs(args.output_path, exist_ok=True)

dataset_folder = args.dataset_folder
resolution = 512
bucket_tolerance = 64
batch_size = 1

clip_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch16")
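

# Minimal stand-ins for a full StableDiffusion object: the dataloader setup
# here only touches sd.adapter.clip_image_processor, so everything else is
# omitted.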
class FakeAdapter:
    def __init__(self):
        self.clip_image_processor = clip_processor


## make fake sd
class FakeSD:
    def __init__(self):
        self.adapter = FakeAdapter()

dataset_config = DatasetConfig(
    dataset_path=dataset_folder,
    # clip_image_path=dataset_folder,
    # square_crop=True,
    resolution=resolution,
    # caption_ext='json',
    default_caption='default',
    # clip_image_path='/mnt/Datasets2/regs/yetibear_xl_v14/random_aspect/',
    buckets=True,
    bucket_tolerance=bucket_tolerance,
    shrink_video_to_frames=True,
    num_frames=args.num_frames,
    # poi='person',
    # shuffle_augmentations=True,
    # augmentations=[
    #     {
    #         'method': 'Posterize',
    #         'num_bits': [(0, 4), (0, 4), (0, 4)],
    #         'p': 1.0
    #     },
    # ]
)
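
# With buckets=True the loader groups samples into aspect-ratio buckets around
# `resolution` so every batch shares a single size; bucket dimensions appear to
# be snapped to multiples of bucket_tolerance.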
dataloader: DataLoader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size, sd=FakeSD())

# run through each epoch and check sizes
idx = 0
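
# batch.tensor is (batch, channels, height, width) for images and
# (batch, frames, channels, height, width) for video clips; judging by the
# commented-out normalization below, pixel values are in [-1, 1].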
for epoch in range(args.epochs):
    for batch in tqdm(dataloader):
        batch: 'DataLoaderBatchDTO'
        img_batch = batch.tensor
        if len(img_batch.shape) == 5:
            batch_size, frames, channels, height, width = img_batch.shape
        else:
            batch_size, channels, height, width = img_batch.shape
            frames = 1

        # img_batch = color_block_imgs(img_batch, neg1_1=True)
        # chunks = torch.chunk(img_batch, batch_size, dim=0)
        # # put them so they are side by side
        # big_img = torch.cat(chunks, dim=3)
        # big_img = big_img.squeeze(0)
        #
        # control_chunks = torch.chunk(batch.clip_image_tensor, batch_size, dim=0)
        # big_control_img = torch.cat(control_chunks, dim=3)
        # big_control_img = big_control_img.squeeze(0) * 2 - 1
        #
        # # resize control image (Resize takes (height, width))
        # big_control_img = torchvision.transforms.Resize((height, width))(big_control_img)
        #
        # big_img = torch.cat([big_img, big_control_img], dim=2)
        #
        # min_val = big_img.min()
        # max_val = big_img.max()
        #
        # big_img = (big_img / 2 + 0.5).clamp(0, 1)

        big_img = img_batch
        # big_img = big_img.clamp(-1, 1)

        if args.output_path is not None:
            save_tensors(big_img, os.path.join(args.output_path, f'{idx}.png'))
        else:
            show_tensors(big_img)

        # convert to image
        # img = transforms.ToPILImage()(big_img)
        # show_img(img)

        time.sleep(0.2)
        idx += 1
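
    # between epochs, rebuild the dataset setup so bucket assignments and
    # shuffling are refreshed (this appears to be what
    # trigger_dataloader_setup_epoch is for)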
    # if not last epoch
    if epoch < args.epochs - 1:
        trigger_dataloader_setup_epoch(dataloader)

cv2.destroyAllWindows()
print('done')