from typing import List

import torch
from tqdm import tqdm

from wan.utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
                                  get_sampling_sigmas, retrieve_timesteps)
from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper


class BidirectionalDiffusionInferencePipeline(torch.nn.Module):
    def __init__(
        self,
        args,
        device,
        generator=None,
        text_encoder=None,
        vae=None
    ):
        super().__init__()
        # Step 1: Initialize all models
        self.generator = WanDiffusionWrapper(
            **getattr(args, "model_kwargs", {}), is_causal=False) if generator is None else generator
        self.text_encoder = WanTextEncoder() if text_encoder is None else text_encoder
        self.vae = WanVAEWrapper() if vae is None else vae

        # Step 2: Initialize scheduler
        self.num_train_timesteps = args.num_train_timestep
        self.sampling_steps = 50
        self.sample_solver = 'unipc'
        self.shift = 8.0
        self.args = args

    def inference(
        self,
        noise: torch.Tensor,
        text_prompts: List[str],
        return_latents=False
    ) -> torch.Tensor:
        """
        Perform inference on the given noise and text prompts.

        Inputs:
            noise (torch.Tensor): The input latent noise tensor of shape
                (batch_size, num_frames, num_channels, height, width).
            text_prompts (List[str]): The list of text prompts.
            return_latents (bool): If True, also return the final denoised
                latents alongside the decoded video.

        Outputs:
            video (torch.Tensor): The generated video tensor of shape
                (batch_size, num_frames, num_channels, height, width),
                normalized to the range [0, 1].
        """
        # Encode the positive and negative prompts for classifier-free guidance.
        conditional_dict = self.text_encoder(
            text_prompts=text_prompts
        )
        unconditional_dict = self.text_encoder(
            text_prompts=[self.args.negative_prompt] * len(text_prompts)
        )

        latents = noise
        sample_scheduler = self._initialize_sample_scheduler(noise)

        for _, t in enumerate(tqdm(sample_scheduler.timesteps)):
            latent_model_input = latents
            # Broadcast the scalar timestep to one entry per latent frame
            # (shape [batch_size, num_frames]), rather than hardcoding the
            # frame count.
            timestep = t * torch.ones(
                [latents.shape[0], latents.shape[1]],
                device=noise.device, dtype=torch.float32)

            flow_pred_cond, _ = self.generator(latent_model_input, conditional_dict, timestep)
            flow_pred_uncond, _ = self.generator(latent_model_input, unconditional_dict, timestep)

            # Classifier-free guidance: extrapolate from the unconditional
            # prediction toward the conditional one by guidance_scale.
            flow_pred = flow_pred_uncond + self.args.guidance_scale * (
                flow_pred_cond - flow_pred_uncond)

            temp_x0 = sample_scheduler.step(
                flow_pred.unsqueeze(0),
                t,
                latents.unsqueeze(0),
                return_dict=False)[0]
            latents = temp_x0.squeeze(0)

        x0 = latents
        video = self.vae.decode_to_pixel(x0)
        # Map decoded pixels from [-1, 1] to [0, 1].
        video = (video * 0.5 + 0.5).clamp(0, 1)

        del sample_scheduler

        if return_latents:
            return video, latents
        else:
            return video

    def _initialize_sample_scheduler(self, noise):
        if self.sample_solver == 'unipc':
            sample_scheduler = FlowUniPCMultistepScheduler(
                num_train_timesteps=self.num_train_timesteps,
                shift=1,
                use_dynamic_shifting=False)
            # The flow shift is applied when the timestep grid is built.
            sample_scheduler.set_timesteps(
                self.sampling_steps, device=noise.device, shift=self.shift)
            self.timesteps = sample_scheduler.timesteps
        elif self.sample_solver == 'dpm++':
            sample_scheduler = FlowDPMSolverMultistepScheduler(
                num_train_timesteps=self.num_train_timesteps,
                shift=1,
                use_dynamic_shifting=False)
            # For DPM++ the shift enters through the sigma schedule instead.
            sampling_sigmas = get_sampling_sigmas(self.sampling_steps, self.shift)
            self.timesteps, _ = retrieve_timesteps(
                sample_scheduler,
                device=noise.device,
                sigmas=sampling_sigmas)
        else:
            raise NotImplementedError("Unsupported solver.")
        return sample_scheduler
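

if __name__ == "__main__":
    # Usage sketch (not part of the original pipeline): shows how the class
    # above might be driven end to end. The `args` fields mirror what the
    # code reads (num_train_timestep, negative_prompt, guidance_scale,
    # model_kwargs); the concrete values and the latent geometry below are
    # assumptions, and running this requires the repo's Wan checkpoints.
    from types import SimpleNamespace

    args = SimpleNamespace(
        num_train_timestep=1000,   # assumed training horizon
        negative_prompt="",        # assumed empty negative prompt
        guidance_scale=5.0,        # assumed CFG scale
        model_kwargs={},
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pipeline = BidirectionalDiffusionInferencePipeline(args, device=device).to(device)

    # Latent-space noise: (batch, num_frames, num_channels, height, width).
    # 21 frames x 16 channels at 60x104 is an assumed Wan latent shape.
    noise = torch.randn(1, 21, 16, 60, 104, device=device)
    video = pipeline.inference(noise, text_prompts=["a cat playing piano"])
    print(video.shape)  # pixel-space video, values in [0, 1]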