from typing import Optional, Tuple
from einops import rearrange
from torch import nn
import torch.distributed as dist
import torch

from pipeline import SelfForcingTrainingPipeline
from utils.loss import get_denoising_loss
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper


class BaseModel(nn.Module):
    def __init__(self, args, device):
        super().__init__()
        self._initialize_models(args, device)
        self.device = device
        self.args = args
        self.dtype = torch.bfloat16 if args.mixed_precision else torch.float32

        if hasattr(args, "denoising_step_list"):
            self.denoising_step_list = torch.tensor(args.denoising_step_list, dtype=torch.long)
            if args.warp_denoising_step:
                timesteps = torch.cat((self.scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32)))
                self.denoising_step_list = timesteps[1000 - self.denoising_step_list]
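                # Illustrative example (values assumed, not taken from any config): if the
                # scheduler's (possibly shifted) timesteps run from 1000 down toward 0 and
                # denoising_step_list is [1000, 750, 500, 250], indexing with 1000 - step
                # replaces each nominal step with the scheduler's actual timestep value at
                # that position, so training steps line up with the inference schedule.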

    def _initialize_models(self, args, device):
        self.real_model_name = getattr(args, "real_name", "Wan2.1-T2V-1.3B")
        self.fake_model_name = getattr(args, "fake_name", "Wan2.1-T2V-1.3B")

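        # Three diffusion wrappers: a trainable causal generator, a frozen "real" score
        # model, and a trainable "fake" score model; the text encoder and VAE stay frozen.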
        self.generator = WanDiffusionWrapper(**getattr(args, "model_kwargs", {}), is_causal=True)
        self.generator.model.requires_grad_(True)

        self.real_score = WanDiffusionWrapper(model_name=self.real_model_name, is_causal=False)
        self.real_score.model.requires_grad_(False)

        self.fake_score = WanDiffusionWrapper(model_name=self.fake_model_name, is_causal=False)
        self.fake_score.model.requires_grad_(True)

        self.text_encoder = WanTextEncoder()
        self.text_encoder.requires_grad_(False)

        self.vae = WanVAEWrapper()
        self.vae.requires_grad_(False)

        self.scheduler = self.generator.get_scheduler()
        self.scheduler.timesteps = self.scheduler.timesteps.to(device)

    def _get_timestep(
        self,
        min_timestep: int,
        max_timestep: int,
        batch_size: int,
        num_frame: int,
        num_frame_per_block: int,
        uniform_timestep: bool = False
    ) -> torch.Tensor:
| """ | |
| Randomly generate a timestep tensor based on the generator's task type. It uniformly samples a timestep | |
| from the range [min_timestep, max_timestep], and returns a tensor of shape [batch_size, num_frame]. | |
| - If uniform_timestep, it will use the same timestep for all frames. | |
| - If not uniform_timestep, it will use a different timestep for each block. | |
| """ | |
        if uniform_timestep:
            timestep = torch.randint(
                min_timestep,
                max_timestep,
                [batch_size, 1],
                device=self.device,
                dtype=torch.long
            ).repeat(1, num_frame)
            return timestep
        else:
            timestep = torch.randint(
                min_timestep,
                max_timestep,
                [batch_size, num_frame],
                device=self.device,
                dtype=torch.long
            )
            # make the noise level the same within every block
            if self.independent_first_frame:
                # the first frame is always kept the same
                timestep_from_second = timestep[:, 1:]
                timestep_from_second = timestep_from_second.reshape(
                    timestep_from_second.shape[0], -1, num_frame_per_block)
                timestep_from_second[:, :, 1:] = timestep_from_second[:, :, 0:1]
                timestep_from_second = timestep_from_second.reshape(
                    timestep_from_second.shape[0], -1)
                timestep = torch.cat([timestep[:, 0:1], timestep_from_second], dim=1)
            else:
                timestep = timestep.reshape(
                    timestep.shape[0], -1, num_frame_per_block)
                timestep[:, :, 1:] = timestep[:, :, 0:1]
                timestep = timestep.reshape(timestep.shape[0], -1)
            return timestep


class SelfForcingModel(BaseModel):
    def __init__(self, args, device):
        super().__init__(args, device)
        self.denoising_loss_func = get_denoising_loss(args.denoising_loss_type)()
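        # Note: attributes referenced below but not defined in this file -- e.g.
        # self.num_training_frames, self.num_frame_per_block, self.independent_first_frame
        # and self.inference_pipeline -- are expected to be set by a concrete subclass
        # or the training script before _run_generator is called.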

    def _run_generator(
        self,
        image_or_video_shape,
        conditional_dict: dict,
        initial_latent: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], int, int]:
| """ | |
| Optionally simulate the generator's input from noise using backward simulation | |
| and then run the generator for one-step. | |
| Input: | |
| - image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W]. | |
| - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings). | |
| - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings). | |
| - clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used. | |
| - initial_latent: a tensor containing the initial latents [B, F, C, H, W]. | |
| Output: | |
| - pred_image: a tensor with shape [B, F, C, H, W]. | |
| - denoised_timestep: an integer | |
| """ | |
        # Step 1: Sample noise and backward simulate the generator's input
        assert getattr(self.args, "backward_simulation", True), "Backward simulation needs to be enabled"
        if initial_latent is not None:
            conditional_dict["initial_latent"] = initial_latent
        if self.args.i2v:
            noise_shape = [image_or_video_shape[0], image_or_video_shape[1] - 1, *image_or_video_shape[2:]]
        else:
            noise_shape = image_or_video_shape.copy()
        # During training, the number of generated frames is uniformly sampled from
        # [21, self.num_training_frames] while remaining a multiple of self.num_frame_per_block
        min_num_frames = 20 if self.args.independent_first_frame else 21
        max_num_frames = self.num_training_frames - 1 if self.args.independent_first_frame else self.num_training_frames
        assert max_num_frames % self.num_frame_per_block == 0
        assert min_num_frames % self.num_frame_per_block == 0
        max_num_blocks = max_num_frames // self.num_frame_per_block
        min_num_blocks = min_num_frames // self.num_frame_per_block
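        # Illustrative example (hypothetical config): with num_frame_per_block=3,
        # independent_first_frame=False and num_training_frames=27, the rollout length
        # is drawn from {21, 24, 27} latent frames and is identical on every rank.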
        # Sample the rollout length in blocks, then broadcast from rank 0 so all
        # processes generate the same number of frames
        num_generated_blocks = torch.randint(min_num_blocks, max_num_blocks + 1, (1,), device=self.device)
        dist.broadcast(num_generated_blocks, src=0)
        num_generated_blocks = num_generated_blocks.item()
        num_generated_frames = num_generated_blocks * self.num_frame_per_block
        if self.args.independent_first_frame and initial_latent is None:
            num_generated_frames += 1
            min_num_frames += 1
        noise_shape[1] = num_generated_frames
        pred_image_or_video, denoised_timestep_from, denoised_timestep_to = self._consistency_backward_simulation(
            noise=torch.randn(noise_shape,
                              device=self.device, dtype=self.dtype),
            **conditional_dict,
        )
        # Keep only the last 21 latent frames for the losses below
        if pred_image_or_video.shape[1] > 21:
            with torch.no_grad():
                # Decode everything except the last 20 latent frames back to pixels
                latent_to_decode = pred_image_or_video[:, :-20, ...]
                pixels = self.vae.decode_to_pixel(latent_to_decode)
                # Re-encode the last decoded pixel frame into a single image latent,
                # which becomes the first latent of the 21-frame window
                frame = pixels[:, -1:, ...].to(self.dtype)
                frame = rearrange(frame, "b t c h w -> b c t h w")
                image_latent = self.vae.encode_to_latent(frame).to(self.dtype)
                pred_image_or_video_last_21 = torch.cat([image_latent, pred_image_or_video[:, -20:, ...]], dim=1)
        else:
            pred_image_or_video_last_21 = pred_image_or_video

        if num_generated_frames != min_num_frames:
            # Currently, we do not use gradient for the first chunk, since it contains image latents
            gradient_mask = torch.ones_like(pred_image_or_video_last_21, dtype=torch.bool)
            if self.args.independent_first_frame:
                gradient_mask[:, :1] = False
            else:
                gradient_mask[:, :self.num_frame_per_block] = False
        else:
            gradient_mask = None

        pred_image_or_video_last_21 = pred_image_or_video_last_21.to(self.dtype)
        return pred_image_or_video_last_21, gradient_mask, denoised_timestep_from, denoised_timestep_to

    def _consistency_backward_simulation(
        self,
        noise: torch.Tensor,
        **conditional_dict: dict
    ) -> torch.Tensor:
| """ | |
| Simulate the generator's input from noise to avoid training/inference mismatch. | |
| See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details. | |
| Here we use the consistency sampler (https://arxiv.org/abs/2303.01469) | |
| Input: | |
| - noise: a tensor sampled from N(0, 1) with shape [B, F, C, H, W] where the number of frame is 1 for images. | |
| - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings). | |
| Output: | |
| - output: a tensor with shape [B, T, F, C, H, W]. | |
| T is the total number of timesteps. output[0] is a pure noise and output[i] and i>0 | |
| represents the x0 prediction at each timestep. | |
| """ | |
        if self.inference_pipeline is None:
            self._initialize_inference_pipeline()

        return self.inference_pipeline.inference_with_trajectory(
            noise=noise, **conditional_dict
        )

    def _initialize_inference_pipeline(self):
        """
        Lazily initialize the inference pipeline on the first backward-simulation run.
        The inference code is encapsulated in a model-dependent pipeline defined outside this class,
        and the FSDP-wrapped modules are passed in to save memory.
        """
        self.inference_pipeline = SelfForcingTrainingPipeline(
            denoising_step_list=self.denoising_step_list,
            scheduler=self.scheduler,
            generator=self.generator,
            num_frame_per_block=self.num_frame_per_block,
            independent_first_frame=self.args.independent_first_frame,
            same_step_across_blocks=self.args.same_step_across_blocks,
            last_step_only=self.args.last_step_only,
            num_max_frames=self.num_training_frames,
            context_noise=self.args.context_noise
        )
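

# Minimal usage sketch (illustrative only; the subclass name, latent shape, and the way
# conditional_dict is built are assumptions, a concrete subclass must provide attributes
# such as num_training_frames and num_frame_per_block, and torch.distributed must be
# initialized since _run_generator broadcasts the sampled rollout length):
#
#   model = MyDistillationModel(args, device=torch.device("cuda"))
#   conditional_dict = ...  # text embeddings produced by model.text_encoder
#   latents, grad_mask, t_from, t_to = model._run_generator(
#       image_or_video_shape=[1, 21, 16, 60, 104],  # assumed [B, F, C, H, W]
#       conditional_dict=conditional_dict,
#   )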