# Scraped page header (not part of the script): "Spaces: Runtime error".
import argparse
import json
import os

import torch

# Let CUDA matmuls and cuDNN use TensorFloat-32 math. These flags are set
# immediately after `import torch`, before the remaining imports run.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
import numpy as np

# Project-local helpers.
from utils.distributed import init_distributed_mode
from language.t5 import T5Embedder
# Maps each --caption-key choice to its position in a record's 'text' list.
CAPTION_KEY = {
    'blip': 0,
    'llava': 1,
    'llava_first': 2,
}


#################################################################################
#                                    Dataset                                    #
#################################################################################

class CustomDataset(Dataset):
    """Flat list of ``(caption, code_dir, line_idx)`` items read from JSONL shards.

    Every non-blank line of each selected ``.jsonl`` file is parsed as JSON and
    its caption is taken from ``record['text'][CAPTION_KEY[caption_key]]``.

    Args:
        lst_dir: Directory containing the ``.jsonl`` shard files.
        start: First index (inclusive) into ``sorted(os.listdir(lst_dir))``.
        end: Last index (inclusive) into the same sorted listing. The range is
            applied to *all* directory entries before non-``.jsonl`` entries are
            skipped, so unrelated files in the directory shift the indices.
        caption_key: A key of ``CAPTION_KEY`` selecting which caption to use.
        trunc_caption: If True, keep only the text before the first ``'.'``.

    Each item is ``(caption, code_dir, line_idx)``. Here ``code_dir`` is the
    shard file name up to its first ``'.'``, and ``line_idx`` is the 0-based
    physical line number inside that shard.

    Raises:
        KeyError: if ``caption_key`` is not a key of ``CAPTION_KEY``, or a
            record has no ``'text'`` entry.
    """

    def __init__(self, lst_dir, start, end, caption_key, trunc_caption=False):
        # Resolve once rather than per line; an unknown key fails immediately.
        text_idx = CAPTION_KEY[caption_key]
        img_path_list = []
        for lst_name in sorted(os.listdir(lst_dir))[start:end + 1]:
            if not lst_name.endswith('.jsonl'):
                continue
            # Output sub-directory name: the file name up to its first '.'.
            # Taken from lst_name itself (not by splitting the joined path on
            # '/'), so it is also correct where os.sep is not '/'.
            code_dir = lst_name.split('.')[0]
            file_path = os.path.join(lst_dir, lst_name)
            # JSON text is UTF-8; don't depend on the platform's locale default.
            with open(file_path, 'r', encoding='utf-8') as file:
                for line_idx, line in enumerate(file):
                    # Tolerate blank lines (e.g. a trailing newline) instead of
                    # crashing in json.loads. enumerate still counts them, so
                    # line_idx remains the physical line number.
                    if not line.strip():
                        continue
                    data = json.loads(line)
                    caption = data['text'][text_idx]
                    if trunc_caption:
                        caption = caption.split('.')[0]
                    img_path_list.append((caption, code_dir, line_idx))
        self.img_path_list = img_path_list

    def __len__(self):
        return len(self.img_path_list)

    def __getitem__(self, index):
        """Return the ``(caption, code_dir, line_idx)`` tuple at ``index``."""
        caption, code_dir, code_name = self.img_path_list[index]
        return caption, code_dir, code_name
#################################################################################
#                          T5 Feature Extraction Loop                           #
#################################################################################

def main(args):
    """Extract T5 caption embeddings and save one ``.npy`` file per caption.

    Each distributed rank processes its (unshuffled) ``DistributedSampler``
    share of ``CustomDataset``. It writes
    ``<args.t5_path>/<code_dir>/<line_idx>.npy``, which holds the first
    ``emb_masks.sum()`` token positions of the caption embedding as float32.
    The leading batch axis is kept in the saved array. Slicing by the mask
    sum assumes right padding, i.e. that valid tokens form a prefix (TODO
    confirm against T5Embedder).

    With ``drop_last=False`` the sampler pads ranks to equal length by
    repeating samples, so a few captions may be embedded twice and overwrite
    the same file.

    Args:
        args: Parsed command-line namespace (see the ``__main__`` block).

    Raises:
        RuntimeError: if CUDA is not available.
        FileNotFoundError: if ``args.t5_model_path`` does not exist.
    """
    # Explicit checks instead of `assert`, which is removed under `python -O`.
    if not torch.cuda.is_available():
        raise RuntimeError("T5 feature extraction requires at least one GPU.")

    # Setup DDP. init_distributed_mode must initialise the default process
    # group, since dist.get_rank() is called right after it.
    init_distributed_mode(args)
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = rank % torch.cuda.device_count()
    seed = args.global_seed * world_size + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={world_size}.")

    # Setup the output folder. Every rank later creates its sub-folders with
    # os.makedirs(..., exist_ok=True), which also creates missing parents.
    if rank == 0:
        os.makedirs(args.t5_path, exist_ok=True)

    # Setup data:
    print("Dataset is preparing...")
    dataset = CustomDataset(
        args.data_path, args.data_start, args.data_end, args.caption_key, args.trunc_caption
    )
    sampler = DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=False,
        seed=args.global_seed,
    )
    loader = DataLoader(
        dataset,
        batch_size=1,  # important! the slicing and `[0]`/`.item()` below assume one caption per batch
        shuffle=False,
        sampler=sampler,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
    )
    print(f"Dataset contains {len(dataset):,} captions")

    precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
    if not os.path.exists(args.t5_model_path):
        raise FileNotFoundError(f"T5 checkpoint directory not found: {args.t5_model_path}")
    t5_model = T5Embedder(
        device=device,
        local_cache=True,
        cache_dir=args.t5_model_path,
        dir_or_name=args.t5_model_type,
        torch_dtype=precision,
    )

    # Feature extraction never needs autograd; avoid building graphs.
    with torch.no_grad():
        for caption, code_dir, code_name in loader:
            # Default collation with batch_size=1 gives: caption and code_dir
            # as 1-element lists of str, and code_name as a 1-element tensor.
            caption_embs, emb_masks = t5_model.get_text_embeddings(caption)
            valid_caption_embs = caption_embs[:, :emb_masks.sum()]
            # Cast to float32 first: NumPy has no bfloat16 dtype.
            x = valid_caption_embs.to(torch.float32).detach().cpu().numpy()
            out_dir = os.path.join(args.t5_path, code_dir[0])
            os.makedirs(out_dir, exist_ok=True)
            np.save(os.path.join(out_dir, '{}.npy'.format(code_name.item())), x)
            print(code_name.item())
    dist.destroy_process_group()
if __name__ == "__main__":
    cli = argparse.ArgumentParser()

    # Input caption shards and output feature directory.
    cli.add_argument("--data-path", type=str, required=True)
    cli.add_argument("--t5-path", type=str, required=True)

    # Inclusive index range into the sorted listing of --data-path.
    cli.add_argument("--data-start", type=int, required=True)
    cli.add_argument("--data-end", type=int, required=True)

    # Which caption to embed, and whether to cut it at the first '.'.
    cli.add_argument("--caption-key", type=str, default='blip', choices=list(CAPTION_KEY))
    cli.add_argument("--trunc-caption", action='store_true')

    # T5 checkpoint location, variant and compute precision.
    cli.add_argument("--t5-model-path", type=str, default='./pretrained_models/t5-ckpt')
    cli.add_argument("--t5-model-type", type=str, default='flan-t5-xl')
    cli.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])

    # Seeding and data-loading workers.
    cli.add_argument("--global-seed", type=int, default=0)
    cli.add_argument("--num-workers", type=int, default=24)

    args = cli.parse_args()
    main(args)