import gc
import logging
import os
import time
from collections import defaultdict

import torch
import torch.distributed as dist
import wandb
from omegaconf import OmegaConf

from model import ODERegression
from utils.dataset import ODERegressionLMDBDataset, cycle
from utils.distributed import barrier, fsdp_wrap, fsdp_state_dict, launch_distributed_job
from utils.misc import set_seed


class Trainer:
    def __init__(self, config):
        self.config = config
        self.step = 0

        # Step 1: Initialize the distributed training environment (rank, seed, dtype, logging, etc.)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        launch_distributed_job()
        global_rank = dist.get_rank()
        self.world_size = dist.get_world_size()

        self.dtype = torch.bfloat16 if config.mixed_precision else torch.float32
        self.device = torch.cuda.current_device()
        self.is_main_process = global_rank == 0
        self.disable_wandb = config.disable_wandb

        # use a random seed for the training
        if config.seed == 0:
            random_seed = torch.randint(0, 10000000, (1,), device=self.device)
            dist.broadcast(random_seed, src=0)
            config.seed = random_seed.item()
        set_seed(config.seed + global_rank)
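        # NOTE: the randomly drawn seed is broadcast from rank 0 so every process starts
        # from the same base value; the `+ global_rank` offset then gives each rank its
        # own reproducible stream (e.g. for data shuffling).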
        if self.is_main_process and not self.disable_wandb:
            wandb.login(host=config.wandb_host, key=config.wandb_key)
            wandb.init(
                config=OmegaConf.to_container(config, resolve=True),
                name=config.config_name,
                mode="online",
                entity=config.wandb_entity,
                project=config.wandb_project,
                dir=config.wandb_save_dir
            )

        self.output_path = config.logdir
        # Step 2: Initialize the model and optimizer
        assert config.distribution_loss == "ode", "Only ODE loss is supported for ODE training"
        self.model = ODERegression(config, device=self.device)

        self.model.generator = fsdp_wrap(
            self.model.generator,
            sharding_strategy=config.sharding_strategy,
            mixed_precision=config.mixed_precision,
            wrap_strategy=config.generator_fsdp_wrap_strategy
        )
        self.model.text_encoder = fsdp_wrap(
            self.model.text_encoder,
            sharding_strategy=config.sharding_strategy,
            mixed_precision=config.mixed_precision,
            wrap_strategy=config.text_encoder_fsdp_wrap_strategy,
            cpu_offload=getattr(config, "text_encoder_cpu_offload", False)
        )
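        # `fsdp_wrap` is a project-local helper (utils/distributed.py). Based on how it is
        # called here, it is assumed to wrap the module in PyTorch FSDP with the configured
        # sharding strategy, an optional bf16 mixed-precision policy, a per-submodule
        # auto-wrap policy chosen by `wrap_strategy`, and optional CPU offload; see the
        # helper itself for the authoritative behavior.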
        if not config.no_visualize or config.load_raw_video:
            self.model.vae = self.model.vae.to(
                device=self.device, dtype=torch.bfloat16 if config.mixed_precision else torch.float32)

        self.generator_optimizer = torch.optim.AdamW(
            [param for param in self.model.generator.parameters()
             if param.requires_grad],
            lr=config.lr,
            betas=(config.beta1, config.beta2),
            weight_decay=config.weight_decay
        )
        # Step 3: Initialize the dataloader
        dataset = ODERegressionLMDBDataset(
            config.data_path, max_pair=getattr(config, "max_pair", int(1e8)))
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset, shuffle=True, drop_last=True)
        dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=config.batch_size, sampler=sampler, num_workers=8)

        total_batch_size = getattr(config, "total_batch_size", None)
        if total_batch_size is not None:
            assert total_batch_size == config.batch_size * self.world_size, \
                "Gradient accumulation is not supported for ODE training"

        self.dataloader = cycle(dataloader)
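        # `cycle` comes from utils.dataset; it is assumed to behave like the usual
        # infinite-iterator helper, roughly:
        #
        #     def cycle(dl):
        #         while True:
        #             for batch in dl:
        #                 yield batch
        #
        # so train_one_step can keep calling next(self.dataloader) without any epoch
        # bookkeeping. (Sketch only; see utils/dataset.py for the real helper.)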
        ##############################################################################################################
        # Step 4: (If resuming) Load a pretrained generator state dict.
        # Only the generator weights are restored here; the optimizer, lr scheduler, and
        # EMA states are not part of this checkpoint.
        if getattr(config, "generator_ckpt", False):
            print(f"Loading pretrained generator from {config.generator_ckpt}")
            state_dict = torch.load(config.generator_ckpt, map_location="cpu")['generator']
            self.model.generator.load_state_dict(
                state_dict, strict=True
            )
        ##############################################################################################################

        self.max_grad_norm = 10.0
        self.previous_time = None

    def save(self):
        print("Start gathering distributed model states...")
        generator_state_dict = fsdp_state_dict(
            self.model.generator)
        state_dict = {
            "generator": generator_state_dict
        }

        if self.is_main_process:
            os.makedirs(os.path.join(self.output_path,
                                     f"checkpoint_model_{self.step:06d}"), exist_ok=True)
            torch.save(state_dict, os.path.join(self.output_path,
                                                f"checkpoint_model_{self.step:06d}", "model.pt"))
            print("Model saved to", os.path.join(self.output_path,
                                                 f"checkpoint_model_{self.step:06d}", "model.pt"))

    def train_one_step(self):
        VISUALIZE = self.step % 100 == 0
        self.model.eval()  # prevent any randomness (e.g. dropout)

        # Step 1: Get the next batch of text prompts
        batch = next(self.dataloader)
        text_prompts = batch["prompts"]
        ode_latent = batch["ode_latent"].to(
            device=self.device, dtype=self.dtype)

        # Step 2: Extract the conditional infos
        with torch.no_grad():
            conditional_dict = self.model.text_encoder(
                text_prompts=text_prompts)

        # Step 3: Train the generator
        generator_loss, log_dict = self.model.generator_loss(
            ode_latent=ode_latent,
            conditional_dict=conditional_dict
        )
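        # Judging from the usage below, log_dict carries per-sample diagnostics from
        # generator_loss: "unnormalized_loss" and "timestep" (one scalar per sample, used
        # for the bucketed loss statistics) plus "input" and "output" latents that are
        # decoded for visualization every 100 steps.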
        unnormalized_loss = log_dict["unnormalized_loss"]
        timestep = log_dict["timestep"]
        if self.world_size > 1:
            gathered_unnormalized_loss = torch.zeros(
                [self.world_size, *unnormalized_loss.shape],
                dtype=unnormalized_loss.dtype, device=self.device)
            gathered_timestep = torch.zeros(
                [self.world_size, *timestep.shape],
                dtype=timestep.dtype, device=self.device)
            dist.all_gather_into_tensor(
                gathered_unnormalized_loss, unnormalized_loss)
            dist.all_gather_into_tensor(gathered_timestep, timestep)
        else:
            gathered_unnormalized_loss = unnormalized_loss
            gathered_timestep = timestep

        # Bucket the per-sample losses into 250-wide timestep bins, using the gathered
        # tensors so the logged breakdown covers samples from every rank.
        loss_breakdown = defaultdict(list)
        stats = {}
        flat_timestep = gathered_timestep.flatten()
        flat_loss = gathered_unnormalized_loss.flatten()
        for index, t in enumerate(flat_timestep):
            loss_breakdown[str(int(t.item()) // 250 * 250)].append(
                flat_loss[index].item())
        for key_t in loss_breakdown.keys():
            stats["loss_at_time_" + key_t] = sum(loss_breakdown[key_t]) / \
                len(loss_breakdown[key_t])
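        # Example of the bucketing arithmetic: a sample drawn at timestep 812 falls into
        # bucket int(812) // 250 * 250 = 750, so its loss is averaged into
        # stats["loss_at_time_750"].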
        self.generator_optimizer.zero_grad()
        generator_loss.backward()
        # clip_grad_norm_ is called on the FSDP-wrapped generator, so clipping is done by
        # FSDP's sharding-aware implementation and returns the total gradient norm.
        generator_grad_norm = self.model.generator.clip_grad_norm_(
            self.max_grad_norm)
        self.generator_optimizer.step()
        # Step 4: Visualization
        if VISUALIZE and not self.config.no_visualize and not self.config.disable_wandb and self.is_main_process:
            # Decode the input, output, and ground-truth latents to pixel space
            input_latent = log_dict["input"]
            output_latent = log_dict["output"]
            ground_truth = ode_latent[:, -1]

            input_video = self.model.vae.decode_to_pixel(input_latent)
            output_video = self.model.vae.decode_to_pixel(output_latent)
            ground_truth_video = self.model.vae.decode_to_pixel(ground_truth)

            # Map pixel values from [-1, 1] to [0, 255] for logging
            input_video = 255.0 * (input_video.cpu().numpy() * 0.5 + 0.5)
            output_video = 255.0 * (output_video.cpu().numpy() * 0.5 + 0.5)
            ground_truth_video = 255.0 * (ground_truth_video.cpu().numpy() * 0.5 + 0.5)

            wandb.log({
                "input": wandb.Video(input_video, caption="Input", fps=16, format="mp4"),
                "output": wandb.Video(output_video, caption="Output", fps=16, format="mp4"),
                "ground_truth": wandb.Video(ground_truth_video, caption="Ground Truth", fps=16, format="mp4"),
            }, step=self.step)
        # Step 5: Logging
        if self.is_main_process and not self.disable_wandb:
            wandb_loss_dict = {
                "generator_loss": generator_loss.item(),
                "generator_grad_norm": generator_grad_norm.item(),
                **stats
            }
            wandb.log(wandb_loss_dict, step=self.step)

        if self.step % self.config.gc_interval == 0:
            if dist.get_rank() == 0:
                logging.info("DistGarbageCollector: Running GC.")
            gc.collect()

    def train(self):
        while True:
            self.train_one_step()
            if (not self.config.no_save) and self.step % self.config.log_iters == 0:
                self.save()
                torch.cuda.empty_cache()
            barrier()

            if self.is_main_process:
                current_time = time.time()
                if self.previous_time is None:
                    self.previous_time = current_time
                else:
                    if not self.disable_wandb:
                        wandb.log({"per iteration time": current_time - self.previous_time}, step=self.step)
                    self.previous_time = current_time

            self.step += 1
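

# Usage sketch (assumption -- the launch script and config layout are not part of this
# file): the trainer expects an OmegaConf config providing the fields referenced above
# (data_path, batch_size, lr, sharding_strategy, wandb_*, ...) and a torchrun-style
# distributed environment, roughly:
#
#     config = OmegaConf.load("configs/ode_regression.yaml")   # hypothetical path
#     trainer = Trainer(config)
#     trainer.train()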