from typing import Tuple

import torch
import torch.nn.functional as F

from model.base import BaseModel
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper


class ODERegression(BaseModel):
    def __init__(self, args, device):
        """
        Initialize the ODERegression module.
        This class is self-contained and computes the generator loss in its
        forward pass given precomputed ODE solution pairs.
        It supports the ODE regression loss for both causal and bidirectional models.
        See Sec. 4.3 of CausVid (https://arxiv.org/abs/2412.07772) for details.
        """
        super().__init__(args, device)

        # Step 1: Initialize all models
        self.generator = WanDiffusionWrapper(**getattr(args, "model_kwargs", {}), is_causal=True)
        self.generator.model.requires_grad_(True)

        if getattr(args, "generator_ckpt", False):
            print(f"Loading pretrained generator from {args.generator_ckpt}")
            state_dict = torch.load(args.generator_ckpt, map_location="cpu")["generator"]
            self.generator.load_state_dict(state_dict, strict=True)

        self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
        if self.num_frame_per_block > 1:
            self.generator.model.num_frame_per_block = self.num_frame_per_block

        self.independent_first_frame = getattr(args, "independent_first_frame", False)
        if self.independent_first_frame:
            self.generator.model.independent_first_frame = True

        if args.gradient_checkpointing:
            self.generator.enable_gradient_checkpointing()

        # Step 2: Initialize all hyperparameters
        self.timestep_shift = getattr(args, "timestep_shift", 1.0)

    def _initialize_models(self, args):
        self.generator = WanDiffusionWrapper(**getattr(args, "model_kwargs", {}), is_causal=True)
        self.generator.model.requires_grad_(True)

        self.text_encoder = WanTextEncoder()
        self.text_encoder.requires_grad_(False)

        self.vae = WanVAEWrapper()
        self.vae.requires_grad_(False)

    def _prepare_generator_input(self, ode_latent: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Given a tensor containing the whole ODE sampling trajectories,
        randomly choose an intermediate denoising step for each frame and return
        the selected latents together with the corresponding timesteps.
        Input:
            - ode_latent: a tensor containing the whole ODE sampling trajectories
              [batch_size, num_denoising_steps, num_frames, num_channels, height, width].
        Output:
            - noisy_input: a tensor containing the selected latents [batch_size, num_frames, num_channels, height, width].
            - timestep: a tensor containing the corresponding timesteps [batch_size, num_frames].
        """
        batch_size, num_denoising_steps, num_frames, num_channels, height, width = ode_latent.shape

        # Step 1: Randomly choose a timestep for each frame
        index = self._get_timestep(
            0,
            len(self.denoising_step_list),
            batch_size,
            num_frames,
            self.num_frame_per_block,
            uniform_timestep=False
        )
        if self.args.i2v:
            # For image-to-video, the first frame is the conditioning image, so it
            # always takes the final (cleanest) point of the trajectory.
            index[:, 0] = len(self.denoising_step_list) - 1
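        # Step 2: For each frame, gather the trajectory entry at its sampled
        # denoising index along dim=1 (the denoising-step axis), then drop the
        # resulting singleton axis.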
        noisy_input = torch.gather(
            ode_latent, dim=1,
            index=index.reshape(batch_size, 1, num_frames, 1, 1, 1).expand(
                -1, -1, -1, num_channels, height, width).to(self.device)
        ).squeeze(1)

        timestep = self.denoising_step_list[index].to(self.device)

        # if self.extra_noise_step > 0:
        #     random_timestep = torch.randint(0, self.extra_noise_step, [
        #         batch_size, num_frames], device=self.device, dtype=torch.long)
        #     perturbed_noisy_input = self.scheduler.add_noise(
        #         noisy_input.flatten(0, 1),
        #         torch.randn_like(noisy_input.flatten(0, 1)),
        #         random_timestep.flatten(0, 1)
        #     ).detach().unflatten(0, (batch_size, num_frames)).type_as(noisy_input)
        #     noisy_input[timestep == 0] = perturbed_noisy_input[timestep == 0]

        return noisy_input, timestep

    def generator_loss(self, ode_latent: torch.Tensor, conditional_dict: dict) -> Tuple[torch.Tensor, dict]:
        """
        Generate images/videos from noisy latents and compute the ODE regression loss.
        Input:
            - ode_latent: a tensor containing the ODE latents
              [batch_size, num_denoising_steps, num_frames, num_channels, height, width].
              They are ordered from the most noisy to the clean latents.
            - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
        Output:
            - loss: a scalar tensor representing the generator loss.
            - log_dict: a dictionary containing additional information such as a per-timestep loss breakdown.
        """
        # Step 1: Run generator on noisy latents
        target_latent = ode_latent[:, -1]
        noisy_input, timestep = self._prepare_generator_input(
            ode_latent=ode_latent)

        _, pred_image_or_video = self.generator(
            noisy_image_or_video=noisy_input,
            conditional_dict=conditional_dict,
            timestep=timestep
        )

        # Step 2: Compute the regression loss
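        # Frames whose sampled timestep is 0 already hold the clean latent and
        # carry no learning signal, so they are masked out of the loss.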
        mask = timestep != 0
        loss = F.mse_loss(
            pred_image_or_video[mask], target_latent[mask], reduction="mean")

        log_dict = {
            "unnormalized_loss": F.mse_loss(pred_image_or_video, target_latent, reduction="none").mean(dim=[1, 2, 3, 4]).detach(),
            "timestep": timestep.float().mean(dim=1).detach(),
            "input": noisy_input.detach(),
            "output": pred_image_or_video.detach(),
        }

        return loss, log_dict
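

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): roughly how a
# training step might drive ODERegression. `args`, `device`, `dataloader`, and
# the way `conditional_dict` is built are placeholders; the exact conditioning
# interface depends on utils.wan_wrapper and the surrounding training script.
#
#   model = ODERegression(args, device)
#   optimizer = torch.optim.AdamW(model.generator.parameters(), lr=1e-5)
#   for batch in dataloader:
#       # [batch_size, num_denoising_steps, num_frames, channels, height, width]
#       ode_latent = batch["ode_latent"].to(device)
#       conditional_dict = batch["conditional_dict"]  # e.g. precomputed text embeddings
#       loss, log_dict = model.generator_loss(ode_latent, conditional_dict)
#       optimizer.zero_grad()
#       loss.backward()
#       optimizer.step()
# ---------------------------------------------------------------------------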