########################################
# python -m train
########################################
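# Trains the prior to map CLIP image embeddings of a set of preferred images
# to the CLIP embedding of a target image, using a weighted MSE +
# cosine-similarity loss. `data`, `model`, and `config` are local modules; a
# sketch of the `config` interface this script assumes appears at the bottom
# of this file.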

import logging

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from data import get_dataloader
from model import get_model_and_tokenizer, get_optimizer
import config

logging.basicConfig(level=logging.INFO)

def get_loss(model, input, target, tokenizer):
    with torch.no_grad():
        assert len(input.shape) == 5  # [batch, s, c, w, h]
        cuts = config.number_k_clip_embed
        assert input.shape[0] * input.shape[1] % cuts == 0, \
            'batch size * `k` preferred embeds must be divisible by cuts'
        input = input.view(cuts // 8, -1, 3, target.shape[-2], target.shape[-1])
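        # Embed the preference images chunk by chunk rather than in one large
        # batch (presumably to bound the CLIP model's peak memory use), then
        # reassemble the full sequence.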
        full_seq = []
        for chunk in input:
            # Here `tokenizer` is a CLIP image-embedding model, not a text tokenizer.
            embeds = tokenizer(chunk)['image_embeds']
            full_seq.append(embeds)
        input = torch.stack(full_seq)
        target = tokenizer(target)['image_embeds']
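        # Each target image yields a single CLIP embedding, i.e. [batch, embed_dim].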
        input = input.view(target.shape[0], -1, target.shape[-1])
        assert len(input.shape) == 3  # [batch, sequence, inner]

    with torch.cuda.amp.autocast(enabled=False):
        input = input.to(torch.float32)
        latent = torch.randn(input.shape[0], input.shape[-1], device=input.device)
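        # The latent is pure Gaussian noise, one vector per batch element; the
        # prior turns it into a predicted image embedding conditioned on `input`.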
        output = model(latent, input).predicted_image_embedding

        target = target.to(torch.float32)
        mse_loss = torch.nn.functional.mse_loss(target, output)
        assert len(target.shape) == 2 and len(output.shape) == 2
        cosine_loss = 1 - torch.nn.functional.cosine_similarity(output, target).mean()
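        # MSE anchors the raw embedding values, while the 0.2-weighted cosine
        # term additionally rewards directional alignment with the target.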
        loss = mse_loss + 0.2 * cosine_loss
        logging.info(f'MSE: {mse_loss.item()}, Cosine: {cosine_loss.item()}, Weighted Total: {loss.item()}')
    # TODO: wandb logging
    return loss


def main():
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)

    model, tokenizer = get_model_and_tokenizer(config.model_path, config.device, config.dtype)
    optimizer = get_optimizer(list(model.prior.parameters()), config.lr)
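    # Only the prior's parameters are trained; the CLIP embedding model used as
    # `tokenizer` stays frozen (its forward pass runs under torch.no_grad() above).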
    dataloader = get_dataloader(config.data_path, config.batch_size, config.num_workers,
                                model.prior_pipe.image_processor)

    for epoch in range(config.epochs):
        for ind, batch in tqdm(enumerate(dataloader)):
            if batch is None:
                continue
            input, target = batch
            input = input.to(config.device)
            target = target.to(config.device)

            if ind % 50 == 0:
                # NOTE: using autocast because our training model is also our val
                # model, so we don't want to cast it to full half precision.
                with torch.cuda.amp.autocast(enabled=True, dtype=config.dtype):
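                    # Qualitative check every 50 steps: run the prior's validation
                    # pass on a fixed set of example images.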
                    examples = [f'../generative_recommender/Blue_Tigers_space/{i}o.png'
                                for i in range(1, 9)]
                    model.do_validation([[Image.open('../' + j) for j in examples]])

            loss = get_loss(model, input, target, tokenizer)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if ind % 100 == 0:
                # TODO: add loading from a saved checkpoint path.
                model.prior.save_pretrained(f'{config.save_path}/last_epoch_ckpt')
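                # Note: the path is constant, so each save overwrites the
                # previous checkpoint rather than keeping per-step copies.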


if __name__ == '__main__':
    main()
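
# For reference, a minimal sketch of the `config` module this script expects,
# inferred from the attributes used above. The concrete values are illustrative
# assumptions only, not known-good settings:
#
#     # config.py
#     import torch
#
#     seed = 42
#     model_path = 'path/to/prior'      # passed to get_model_and_tokenizer
#     device = 'cuda'
#     dtype = torch.bfloat16            # autocast dtype used at validation time
#     lr = 1e-4
#     data_path = 'path/to/data'
#     batch_size = 8
#     num_workers = 4
#     epochs = 10
#     save_path = 'checkpoints'
#     number_k_clip_embed = 8           # must divide batch_size * k (see get_loss)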