# -*- encoding: utf-8 -*-
'''
@File    :   pretrain_cogvideo.py
@Time    :   2021/10/06 00:58:32
@Author  :   Wenyi Hong
@Contact :   [email protected]
'''
# here put the import lib
import os
import sys
import math
import random
import torch
import argparse
import numpy as np

from icetk import icetk as tokenizer
tokenizer.add_special_tokens(['<start_of_image>', '<start_of_english>', '<start_of_chinese>'])

from models.cogvideo_model import CogVideoModel
from SwissArmyTransformer import mpu, get_args
from SwissArmyTransformer.training.deepspeed_training import training_main
from SwissArmyTransformer.data_utils import BinaryDataset
def get_masks_and_position_ids_video(data, attention_mask_totxt=None, args=None):
    # Extract batch size and sequence length.
    batch_size, seq_length = data.size()
    assert attention_mask_totxt is not None
    layout = args.layout
    assert seq_length == layout[-1]
    # Number of left-pads per sample: text slots not covered by the text mask.
    n_pads = layout[0] - attention_mask_totxt.sum(dim=-1).long()
    frame_len = layout[1] - layout[0]
    position_ids = torch.zeros(batch_size, layout[2], dtype=torch.long,
                               device=data.device)
    for i in range(batch_size):
        # Text tokens count up from 0 (pads keep position 0) ...
        torch.arange(layout[0] - n_pads[i], out=position_ids[i, n_pads[i]:layout[0]],
                     dtype=torch.long, device=data.device)
        # ... while frame tokens use a separate range starting at offset 512.
        torch.arange(512, 512 + layout[2] - layout[0],
                     out=position_ids[i, layout[0]:], dtype=torch.long, device=data.device)
    return position_ids
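
# A minimal sketch of the position-id scheme above (illustration only; never
# called during training), assuming layout = [64, 464, 2064] and one sample
# with 4 leading pads: text tokens get positions 0..59 and the 2000 frame
# tokens get positions 512..2511.
def _sketch_position_ids():
    args = argparse.Namespace(layout=[64, 464, 2064])
    data = torch.zeros(1, 2064, dtype=torch.long)
    # Text mask: 4 pads followed by 60 real text tokens.
    mask = torch.cat([torch.zeros(4), torch.ones(60)]).unsqueeze(0)
    pos = get_masks_and_position_ids_video(data, attention_mask_totxt=mask, args=args)
    assert pos[0, 4:64].tolist() == list(range(60))         # text positions
    assert pos[0, 64:].tolist() == list(range(512, 2512))   # frame positions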
def get_batch(data_iterator, args, timers):
    # Items and their type.
    keys = ['text', 'loss_mask', 'attention_mask_totxt']
    datatype = torch.int64

    # Broadcast data.
    timers('data loader').start()
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    timers('data loader').stop()
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens_ = data_b['text'].long()
    loss_mask = data_b['loss_mask'].float()
    attention_mask_totxt = data_b['attention_mask_totxt'].float()

    # Shift by one so that input position t predicts token t+1.
    labels = tokens_[:, 1:].clone().contiguous()
    loss_mask = loss_mask[:, 1:].contiguous()
    tokens = tokens_[:, :-1].clone().contiguous()

    # Insert a <start_of_image> separator before each frame in the shifted input.
    for idx in range(args.layout[0], args.layout[2], 400):
        tokens[:, idx] = tokenizer['<start_of_image>']

    # Get the masks and position ids.
    position_ids = get_masks_and_position_ids_video(
        tokens,
        attention_mask_totxt=attention_mask_totxt,
        args=args
    )
    attention_mask_totxt = attention_mask_totxt.unsqueeze(1).unsqueeze(1)

    # Convert to fp16 if needed.
    if args.fp16:
        attention_mask_totxt = attention_mask_totxt.half()
    return tokens, labels, loss_mask, attention_mask_totxt, position_ids
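
# Sketch of the batching above (assuming layout = [64, 464, 2064]): each
# sample in tokens_ holds 2065 ids, and the shift aligns inputs with targets:
#
#   tokens_: [<pad>..., text, <soi>, f0_0..f0_399, f1_0..f1_399, ..., f4_399]
#   tokens = tokens_[:, :-1]        labels = tokens_[:, 1:]
#
# The loop overwrites tokens[:, 64], tokens[:, 464], ..., tokens[:, 1664]
# with <start_of_image>, so at every frame boundary the model reads the
# separator and is asked to predict the first token of the next frame.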
def forward_step(data_iterator, model, args, timers):
    """Forward step."""
    # Get the batch.
    timers('batch generator').start()
    tokens, labels, loss_mask, attention_mask_totxt, position_ids = get_batch(
        data_iterator, args, timers)
    timers('batch generator').stop()

    # Forward model.
    logits, *mems = model(tokens, position_ids, attention_mask_totxt)

    # ======= hyper params ======= #
    perframe_len = 400
    text_len = 64
    frame_num = 5

    # Cross-entropy over the image sub-vocabulary, on frame positions only.
    logits_img_tokens = logits[:, text_len:, :tokenizer.num_image_tokens].float().contiguous()
    losses = mpu.vocab_parallel_cross_entropy(logits_img_tokens, labels[:, text_len:])

    # Scale by the loss mask and average over unmasked positions.
    loss_mask = loss_mask[:, text_len:].reshape(-1)
    losses_1d = losses.reshape(-1) * loss_mask
    loss = torch.sum(losses_1d) / loss_mask.sum()

    # ===================== Log partial losses ===================== #
    log_loss_dict = {}
    bs = losses.shape[0]
    if args.cogvideo_stage == 1:
        for i in range(frame_num):
            log_loss_dict[f'AR_f{i}_loss'] = losses[:, i*perframe_len:(i+1)*perframe_len].contiguous().reshape(-1).detach().sum() / max((perframe_len*bs), 1)
    else:
        for i in range(1, frame_num-1):
            log_loss_dict[f'ITP_f{i}_loss'] = losses[:, i*perframe_len:(i+1)*perframe_len].contiguous().reshape(-1).detach().sum() / max((perframe_len*bs), 1)
    # ===================== END OF BLOCK ===================== #
    return loss, log_loss_dict
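
# Loss bookkeeping above (assuming layout = [64, 464, 2064]): logits are
# restricted to the image sub-vocabulary (icetk's first num_image_tokens ids,
# per the slicing above) and to positions after the 64 text slots, so losses
# has shape (batch, 2000) = (batch, 5 frames x 400 tokens). loss_mask then
# zeroes conditioning frames for the training loss, while log_loss_dict
# reports unmasked per-frame averages, e.g. 'AR_f0_loss'..'AR_f4_loss' in
# stage 1.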
def create_dataset_function(path, args):
    dataset_layout = [64, 464, 2064]
    input_layout = [64, 464, 2064]
    # frame_num = 6
    # frame_interval = 2 # DEBUG!!!

    def process_fn(row):
        row = row.astype(np.int64)
        text = row[:dataset_layout[0]]
        frames = row[dataset_layout[0]:]

        if text[0] == tokenizer['<pad>']:
            text = text[1:]  # due to our way of data processing

        if args.cogvideo_stage == 1:
            text, loss_mask, frames = make_text_video_generation(text, frames)
        else:
            text, loss_mask, frames = mask_video_frame_interpolation(text, frames)

        # Left-pad the text back to input_layout[0] slots and re-insert the
        # <start_of_image> separator before the frame tokens.
        n_pad = input_layout[0] - len(text)
        parts = [
            np.array([tokenizer['<pad>']] * n_pad, dtype=np.int64),
            text,
            np.array([tokenizer['<start_of_image>']], dtype=np.int64),
            frames,
        ]
        ret = np.concatenate(parts, axis=0)
        attention_mask_totxt = np.array([0] * n_pad + [1] * (input_layout[0]-n_pad))
        return {'text': ret,
                'loss_mask': loss_mask,
                'attention_mask_totxt': attention_mask_totxt,
                }
    return BinaryDataset(path, process_fn, length_per_sample=dataset_layout[-1])
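
# Sketch of one processed sample (dataset_layout = [64, 464, 2064]): a raw row
# is 2064 int tokens (64 text slots + 2000 frame tokens); after trimming pads,
# left-padding back to 64, and inserting the separator, process_fn returns
#   'text'                 : (2065,) = pads/text (64) + <soi> (1) + frames (2000)
#   'loss_mask'            : (2065,) per-token weights, shifted later in get_batch
#   'attention_mask_totxt' : (64,)   0 over pads, 1 over real text tokens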
def make_text_video_generation(text, frames):
    input_layout = [64, 464, 2064]
    text = text[text != tokenizer['<pad>']][:input_layout[0]]  # dataset format: 1.0秒<n>{text}<pad><pad> ...
    # Aligned with the input; get_batch later shifts this mask left by one.
    loss_mask = np.array([0] * (input_layout[1]+1) + [1] * (input_layout[2] - input_layout[1]))
    return text, loss_mask, frames
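
# Stage-1 (text-to-video) mask above: the 465 zeros cover the text slots, the
# <start_of_image> separator, and frame 0 (given as conditioning); the 1600
# ones cover frames 1..4, which are generated autoregressively:
#   loss_mask = [0]*465 + [1]*1600        # length 2065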
def mask_video_frame_interpolation(text, frames):
    input_layout = [64, 464, 2064]
    frame_len = input_layout[1] - input_layout[0]
    # text format: <pad> 1.0秒 <n> {text} <pad> <pad>
    text = text[text != tokenizer['<pad>']][:input_layout[0]]
    # Aligned with the input; get_batch later shifts this mask left by one.
    loss_mask = np.array([0] * (input_layout[1]+1)
                         + [1] * frame_len
                         + [0] * frame_len
                         + [1] * frame_len
                         + [0] * frame_len)
    return text, loss_mask, frames
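
# Stage-2 (frame interpolation) mask above: frames 0, 2 and 4 are keyframes
# given as conditioning, frames 1 and 3 are the interpolation targets:
#   loss_mask = [0]*465 + [1]*400 + [0]*400 + [1]*400 + [0]*400   # length 2065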
if __name__ == '__main__':
    py_parser = argparse.ArgumentParser(add_help=False)
    py_parser.add_argument('--txt-loss-scale', type=float, default=1)
    CogVideoModel.add_model_specific_args(py_parser)

    known, args_list = py_parser.parse_known_args()
    args = get_args(args_list)
    args = argparse.Namespace(**vars(args), **vars(known))
    args.layout = [int(x) for x in args.layout.split(',')]

    training_main(args, model_cls=CogVideoModel, forward_step_function=forward_step,
                  create_dataset_function=create_dataset_function)
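
# A hypothetical launch for reference (flag values are examples; deepspeed and
# SwissArmyTransformer flags, including --cogvideo-stage as registered by
# CogVideoModel.add_model_specific_args, are assumptions and mostly elided):
#   python pretrain_cogvideo.py --layout 64,464,2064 --cogvideo-stage 1 \
#       --txt-loss-scale 1 ...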

