from typing import List

import torch

from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper

class BidirectionalInferencePipeline(torch.nn.Module):
    def __init__(
        self,
        args,
        device,
        generator=None,
        text_encoder=None,
        vae=None
    ):
        super().__init__()
        # Step 1: Initialize all models
        self.generator = WanDiffusionWrapper(
            **getattr(args, "model_kwargs", {}), is_causal=False) if generator is None else generator
        self.text_encoder = WanTextEncoder() if text_encoder is None else text_encoder
        self.vae = WanVAEWrapper() if vae is None else vae
        # Step 2: Initialize all bidirectional Wan hyperparameters
        self.scheduler = self.generator.get_scheduler()
        self.denoising_step_list = torch.tensor(
            args.denoising_step_list, dtype=torch.long, device=device)
        if self.denoising_step_list[-1] == 0:
            self.denoising_step_list = self.denoising_step_list[:-1]  # remove the zero timestep for inference

        if args.warp_denoising_step:
            timesteps = torch.cat((self.scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32)))
            self.denoising_step_list = timesteps[1000 - self.denoising_step_list]
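        # Note (added): the remapping above assumes the scheduler exposes 1000
        # timesteps in descending order, so an integer entry k in
        # `denoising_step_list` selects the scheduler value at index 1000 - k
        # (e.g. k = 1000 picks timesteps[0], the largest timestep, and the
        # appended 0 covers k = 0).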

    def inference(self, noise: torch.Tensor, text_prompts: List[str]) -> torch.Tensor:
        """
        Perform inference on the given noise and text prompts.

        Inputs:
            noise (torch.Tensor): The input noise tensor of shape
                (batch_size, num_frames, num_channels, height, width).
            text_prompts (List[str]): The list of text prompts.

        Outputs:
            video (torch.Tensor): The generated video tensor of shape
                (batch_size, num_frames, num_channels, height, width).
                It is normalized to be in the range [0, 1].
        """
        conditional_dict = self.text_encoder(
            text_prompts=text_prompts
        )

        # Start from pure noise and run few-step sampling: at each timestep in
        # the list the generator predicts the clean video, which is re-noised
        # to the next (smaller) timestep before the following denoising step.
        noisy_image_or_video = noise
        for index, current_timestep in enumerate(self.denoising_step_list):
            _, pred_image_or_video = self.generator(
                noisy_image_or_video=noisy_image_or_video,
                conditional_dict=conditional_dict,
                timestep=torch.ones(
                    noise.shape[:2], dtype=torch.long, device=noise.device) * current_timestep
            )  # [B, F, C, H, W]

            if index < len(self.denoising_step_list) - 1:
                next_timestep = self.denoising_step_list[index + 1] * torch.ones(
                    noise.shape[:2], dtype=torch.long, device=noise.device)
                noisy_image_or_video = self.scheduler.add_noise(
                    pred_image_or_video.flatten(0, 1),
                    torch.randn_like(pred_image_or_video.flatten(0, 1)),
                    next_timestep.flatten(0, 1)
                ).unflatten(0, noise.shape[:2])

        # Decode the final clean latents to pixels and map from [-1, 1] to [0, 1].
        video = self.vae.decode_to_pixel(pred_image_or_video)
        video = (video * 0.5 + 0.5).clamp(0, 1)
        return video
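

# ---------------------------------------------------------------------------
# Minimal usage sketch (added for illustration, not part of the original
# pipeline). It assumes the Wan wrappers above can be constructed with their
# defaults and that `args` only needs the fields read in __init__; the
# denoising schedule and latent shape below are placeholders rather than the
# project's shipped configuration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args = SimpleNamespace(
        denoising_step_list=[1000, 750, 500, 250, 0],  # placeholder few-step schedule
        warp_denoising_step=False,
        model_kwargs={},
    )

    pipeline = BidirectionalInferencePipeline(args, device=device).to(device)

    # Latent-space noise: (batch, frames, channels, height, width); shape is illustrative.
    noise = torch.randn(1, 21, 16, 60, 104, device=device)
    video = pipeline.inference(noise, text_prompts=["a cat playing the piano"])
    print(video.shape, video.min().item(), video.max().item())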