# WIP, coming soon ish
import copy
import os
import sys
import weakref
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

import torch
import torch.nn.functional as F
import yaml
from tqdm import tqdm

from diffusers import (
    AutoencoderKLWan,
    FlowMatchEulerDiscreteScheduler,
    UniPCMultistepScheduler,
    WanPipeline,
    WanTransformer3DModel,
)
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
from diffusers.pipelines.wan.pipeline_wan import XLA_AVAILABLE
from optimum.quanto import freeze, qfloat8, QTensor, qint4
from transformers import AutoTokenizer, UMT5EncoderModel

from toolkit.accelerator import unwrap_model
from toolkit.basic import flush
from toolkit.config_modules import GenerateImageConfig, ModelArch, ModelConfig
from toolkit.dequantize import patch_dequantization_on_save
from toolkit.models.base_model import BaseModel
from toolkit.models.wan21.wan_lora_convert import convert_to_diffusers, convert_to_original
from toolkit.prompt_utils import PromptEmbeds
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
from toolkit.util.quantize import get_qtype, quantize

if XLA_AVAILABLE:
    # needed for xm.mark_step() in the denoising loop below
    import torch_xla.core.xla_model as xm


# scheduler config used for generation/sampling
scheduler_configUniPC = {
    "_class_name": "UniPCMultistepScheduler",
    "_diffusers_version": "0.33.0.dev0",
    "beta_end": 0.02,
    "beta_schedule": "linear",
    "beta_start": 0.0001,
    "disable_corrector": [],
    "dynamic_thresholding_ratio": 0.995,
    "final_sigmas_type": "zero",
    "flow_shift": 3.0,
    "lower_order_final": True,
    "num_train_timesteps": 1000,
    "predict_x0": True,
    "prediction_type": "flow_prediction",
    "rescale_betas_zero_snr": False,
    "sample_max_value": 1.0,
    "solver_order": 2,
    "solver_p": None,
    "solver_type": "bh2",
    "steps_offset": 0,
    "thresholding": False,
    "timestep_spacing": "linspace",
    "trained_betas": None,
    "use_beta_sigmas": False,
    "use_exponential_sigmas": False,
    "use_flow_sigmas": True,
    "use_karras_sigmas": False,
}

# scheduler config used for training
scheduler_config = {
    "num_train_timesteps": 1000,
    "shift": 3.0,
    "use_dynamic_shifting": False,
}
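
# Note: scheduler_configUniPC is only used when building the sampling pipeline in
# get_generation_pipeline(); training uses a CustomFlowMatchEulerDiscreteScheduler built
# from scheduler_config. Both use a flow shift of 3.0, which keeps the training and
# sampling sigma schedules aligned.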


class AggressiveWanUnloadPipeline(WanPipeline):
    def __init__(
        self,
        tokenizer: AutoTokenizer,
        text_encoder: UMT5EncoderModel,
        transformer: WanTransformer3DModel,
        vae: AutoencoderKLWan,
        scheduler: FlowMatchEulerDiscreteScheduler,
        device: torch.device = torch.device("cuda"),
    ):
        super().__init__(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            transformer=transformer,
            vae=vae,
            scheduler=scheduler,
        )
        self._exec_device = device

    @property
    def _execution_device(self):
        # override the base pipeline property so the execution device stays fixed
        # even while individual modules are swapped between CPU and GPU
        return self._exec_device
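
    # __call__ below mirrors WanPipeline.__call__ but shuffles modules to limit peak VRAM:
    # the VAE is pushed to CPU and the text encoder pulled onto the GPU for prompt encoding,
    # then the text encoder is pushed back to CPU while the transformer runs the denoising
    # loop, and finally the VAE is restored for decoding.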
    def __call__(
        self: WanPipeline,
        prompt: Union[str, List[str]] = None,
        negative_prompt: Union[str, List[str]] = None,
        height: int = 480,
        width: int = 832,
        num_frames: int = 81,
        num_inference_steps: int = 50,
        guidance_scale: float = 5.0,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "np",
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
    ):
        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        # unload the vae and bring the text encoder onto the execution device
        vae_device = self.vae.device
        transformer_device = self.transformer.device
        text_encoder_device = self.text_encoder.device
        device = self.transformer.device

        print("Unloading vae")
        self.vae.to("cpu")
        self.text_encoder.to(device)

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            negative_prompt,
            height,
            width,
            prompt_embeds,
            negative_prompt_embeds,
            callback_on_step_end_tensor_inputs,
        )

        self._guidance_scale = guidance_scale
        self._attention_kwargs = attention_kwargs
        self._current_timestep = None
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # 3. Encode input prompt
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt=prompt,
            negative_prompt=negative_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            num_videos_per_prompt=num_videos_per_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            max_sequence_length=max_sequence_length,
            device=device,
        )

        # unload the text encoder and bring the transformer onto the execution device
        print("Unloading text encoder")
        self.text_encoder.to("cpu")
        self.transformer.to(device)

        transformer_dtype = self.transformer.dtype
        prompt_embeds = prompt_embeds.to(device, transformer_dtype)
        if negative_prompt_embeds is not None:
            negative_prompt_embeds = negative_prompt_embeds.to(device, transformer_dtype)

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_channels_latents,
            height,
            width,
            num_frames,
            torch.float32,
            device,
            generator,
            latents,
        )

        # 6. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                self._current_timestep = t
                latent_model_input = latents.to(device, transformer_dtype)
                timestep = t.expand(latents.shape[0])

                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    attention_kwargs=attention_kwargs,
                    return_dict=False,
                )[0]

                if self.do_classifier_free_guidance:
                    noise_uncond = self.transformer(
                        hidden_states=latent_model_input,
                        timestep=timestep,
                        encoder_hidden_states=negative_prompt_embeds,
                        attention_kwargs=attention_kwargs,
                        return_dict=False,
                    )[0]
                    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop(
                        "negative_prompt_embeds", negative_prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        self._current_timestep = None

        # denoising is done; bring the vae back for decoding
        print("Loading Vae")
        self.vae.to(vae_device)

        if not output_type == "latent":
            latents = latents.to(self.vae.dtype)
            latents_mean = (
                torch.tensor(self.vae.config.latents_mean)
                .view(1, self.vae.config.z_dim, 1, 1, 1)
                .to(latents.device, latents.dtype)
            )
            latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
                latents.device, latents.dtype
            )
            latents = latents / latents_std + latents_mean
            video = self.vae.decode(latents, return_dict=False)[0]
            video = self.video_processor.postprocess_video(video, output_type=output_type)
        else:
            video = latents

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (video,)

        return WanPipelineOutput(frames=video)
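

# Wan 2.1 text-to-video wrapper for the toolkit's BaseModel interface: it handles
# loading/quantizing the WanTransformer3DModel, UMT5 text encoder, and Wan VAE, and
# exposes training (flow matching) and sampling (UniPC) entry points.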
class Wan21(BaseModel):
    arch = 'wan21'

    def __init__(
        self,
        device,
        model_config: ModelConfig,
        dtype='bf16',
        custom_pipeline=None,
        noise_scheduler=None,
        **kwargs
    ):
        super().__init__(device, model_config, dtype,
                         custom_pipeline, noise_scheduler, **kwargs)
        self.is_flow_matching = True
        self.is_transformer = True
        self.target_lora_modules = ['WanTransformer3DModel']
        # cache for holding noise
        self.effective_noise = None

    def get_bucket_divisibility(self):
        return 16

    # static method to get the training scheduler
    @staticmethod
    def get_train_scheduler():
        scheduler = CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
        return scheduler
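
    # load_model builds the components in a VRAM-friendly order: the transformer first
    # (optionally quantized block-by-block in low_vram mode), then the UMT5 text encoder,
    # then the VAE, before assembling a WanPipeline around them.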
    def load_model(self):
        dtype = self.torch_dtype
        model_path = self.model_config.name_or_path

        self.print_and_status_update("Loading Wan2.1 model")

        subfolder = 'transformer'
        transformer_path = model_path
        if os.path.exists(transformer_path):
            subfolder = None
            transformer_path = os.path.join(transformer_path, 'transformer')

        te_path = self.model_config.extras_name_or_path
        if os.path.exists(os.path.join(model_path, 'text_encoder')):
            te_path = model_path

        vae_path = self.model_config.extras_name_or_path
        if os.path.exists(os.path.join(model_path, 'vae')):
            vae_path = model_path

        self.print_and_status_update("Loading transformer")
        transformer = WanTransformer3DModel.from_pretrained(
            transformer_path,
            subfolder=subfolder,
            torch_dtype=dtype,
        ).to(dtype=dtype)

        if self.model_config.split_model_over_gpus:
            raise ValueError(
                "Splitting model over gpus is not supported for Wan2.1 models")

        if not self.model_config.low_vram:
            # quantize on the device
            transformer.to(self.quantize_device, dtype=dtype)
            flush()

        if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None:
            raise ValueError(
                "Assistant LoRA is not supported for Wan2.1 models currently")

        if self.model_config.lora_path is not None:
            raise ValueError(
                "Loading LoRA is not supported for Wan2.1 models currently")

        flush()

        if self.model_config.quantize:
            print("Quantizing Transformer")
            quantization_args = self.model_config.quantize_kwargs
            if 'exclude' not in quantization_args:
                quantization_args['exclude'] = []
            # patch the state dict method
            patch_dequantization_on_save(transformer)

            quantization_type = get_qtype(self.model_config.qtype)
            self.print_and_status_update("Quantizing transformer")

            if self.model_config.low_vram:
                print("Quantizing blocks")
                orig_exclude = copy.deepcopy(quantization_args['exclude'])
                # quantize each block
                for block in tqdm(transformer.blocks):
                    block.to(self.device_torch)
                    quantize(block, weights=quantization_type,
                             **quantization_args)
                    freeze(block)
                flush()

                print("Quantizing the rest")
                low_vram_exclude = copy.deepcopy(quantization_args['exclude'])
                low_vram_exclude.append('blocks.*')
                quantization_args['exclude'] = low_vram_exclude
                # quantize the rest
                transformer.to(self.device_torch)
                quantize(transformer, weights=quantization_type,
                         **quantization_args)
                quantization_args['exclude'] = orig_exclude
            else:
                # do it in one go
                quantize(transformer, weights=quantization_type,
                         **quantization_args)
            freeze(transformer)
            # move it to the cpu for now
            transformer.to("cpu")
        else:
            transformer.to(self.device_torch, dtype=dtype)

        flush()

        self.print_and_status_update("Loading UMT5EncoderModel")
        tokenizer = AutoTokenizer.from_pretrained(
            te_path, subfolder="tokenizer", torch_dtype=dtype)
        text_encoder = UMT5EncoderModel.from_pretrained(
            te_path, subfolder="text_encoder", torch_dtype=dtype).to(dtype=dtype)
        text_encoder.to(self.device_torch, dtype=dtype)
        flush()

        if self.model_config.quantize_te:
            self.print_and_status_update("Quantizing UMT5EncoderModel")
            quantize(text_encoder, weights=get_qtype(self.model_config.qtype))
            freeze(text_encoder)
            flush()

        if self.model_config.low_vram:
            print("Moving transformer back to GPU")
            # we can move it back to the gpu now
            transformer.to(self.device_torch)

        scheduler = Wan21.get_train_scheduler()

        self.print_and_status_update("Loading VAE")
        # todo, example does float 32? check if quality suffers
        vae = AutoencoderKLWan.from_pretrained(
            vae_path, subfolder="vae", torch_dtype=dtype).to(dtype=dtype)
        flush()

        self.print_and_status_update("Making pipe")
        pipe: WanPipeline = WanPipeline(
            scheduler=scheduler,
            text_encoder=None,
            tokenizer=tokenizer,
            vae=vae,
            transformer=None,
        )
        pipe.text_encoder = text_encoder
        pipe.transformer = transformer

        self.print_and_status_update("Preparing Model")

        text_encoder = pipe.text_encoder
        tokenizer = pipe.tokenizer

        pipe.transformer = pipe.transformer.to(self.device_torch)

        flush()
        text_encoder.to(self.device_torch)
        text_encoder.requires_grad_(False)
        text_encoder.eval()
        pipe.transformer = pipe.transformer.to(self.device_torch)
        flush()

        self.pipeline = pipe
        self.model = transformer
        self.vae = vae
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer
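
    # Sampling uses UniPCMultistepScheduler with flow sigmas; in low_vram mode the
    # AggressiveWanUnloadPipeline is used so only one large module sits on the GPU at a time.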
    def get_generation_pipeline(self):
        scheduler = UniPCMultistepScheduler(**scheduler_configUniPC)

        if self.model_config.low_vram:
            pipeline = AggressiveWanUnloadPipeline(
                vae=self.vae,
                transformer=self.model,
                text_encoder=self.text_encoder,
                tokenizer=self.tokenizer,
                scheduler=scheduler,
                device=self.device_torch,
            )
        else:
            pipeline = WanPipeline(
                vae=self.vae,
                transformer=self.unet,
                text_encoder=self.text_encoder,
                tokenizer=self.tokenizer,
                scheduler=scheduler,
            )

        pipeline = pipeline.to(self.device_torch)

        return pipeline

    def generate_single_image(
        self,
        pipeline: WanPipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        # re-enable the progress bar since video generation is slow
        pipeline.set_progress_bar_config(disable=False)
        pipeline = pipeline.to(self.device_torch)
        # todo, figure out how to do video
        output = pipeline(
            prompt_embeds=conditional_embeds.text_embeds.to(
                self.device_torch, dtype=self.torch_dtype),
            negative_prompt_embeds=unconditional_embeds.text_embeds.to(
                self.device_torch, dtype=self.torch_dtype),
            height=gen_config.height,
            width=gen_config.width,
            num_inference_steps=gen_config.num_inference_steps,
            guidance_scale=gen_config.guidance_scale,
            latents=gen_config.latents,
            num_frames=gen_config.num_frames,
            generator=generator,
            return_dict=False,
            output_type="pil",
            **extra
        )[0]

        # output shape = [batch, frames, channels, height, width]
        batch_item = output[0]  # list of PIL images, one per frame
        if gen_config.num_frames > 1:
            # return the full list of frames for video outputs
            return batch_item
        else:
            # single-frame generation: return just the first image
            img = batch_item[0]
            return img

    def get_noise_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: PromptEmbeds,
        **kwargs
    ):
        # Latent layout reference:
        # vae_scale_factor_spatial = 8
        # vae_scale_factor_temporal = 4
        # num_latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1
        # shape = (
        #     batch_size,
        #     num_channels_latents,  # 16
        #     num_latent_frames,     # e.g. 21 for 81 input frames
        #     int(height) // vae_scale_factor_spatial,
        #     int(width) // vae_scale_factor_spatial,
        # )
        noise_pred = self.model(
            hidden_states=latent_model_input,
            timestep=timestep,
            encoder_hidden_states=text_embeddings.text_embeds,
            return_dict=False,
            **kwargs
        )[0]
        return noise_pred

    def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
        if self.pipeline.text_encoder.device != self.device_torch:
            self.pipeline.text_encoder.to(self.device_torch)
        prompt_embeds, _ = self.pipeline.encode_prompt(
            prompt,
            do_classifier_free_guidance=False,
            max_sequence_length=512,
            device=self.device_torch,
            dtype=self.torch_dtype,
        )
        return PromptEmbeds(prompt_embeds)
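
    # encode_images normalizes inputs to the Wan VAE's expected (B, C, T, H, W) layout:
    # a single image (C, H, W) becomes a one-frame clip, and a frame stack (T, C, H, W)
    # is permuted to channels-first. The resulting latents are then normalized with the
    # per-channel latents_mean / latents_std from the VAE config.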
    def encode_images(
        self,
        image_list: List[torch.Tensor],
        device=None,
        dtype=None
    ):
        if device is None:
            device = self.vae_device_torch
        if dtype is None:
            dtype = self.vae_torch_dtype

        if self.vae.device.type == 'cpu':
            self.vae.to(device)
        self.vae.eval()
        self.vae.requires_grad_(False)

        image_list = [image.to(device, dtype=dtype) for image in image_list]

        # Normalize shapes
        norm_images = []
        for image in image_list:
            if image.ndim == 3:
                # (C, H, W) -> (C, 1, H, W)
                norm_images.append(image.unsqueeze(1))
            elif image.ndim == 4:
                # (T, C, H, W) -> (C, T, H, W)
                norm_images.append(image.permute(1, 0, 2, 3))
            else:
                raise ValueError(f"Invalid image shape: {image.shape}")

        # Stack to (B, C, T, H, W)
        images = torch.stack(norm_images)
        B, C, T, H, W = images.shape

        # Resize if needed so H and W are divisible by the VAE's spatial factor of 8
        if H % 8 != 0 or W % 8 != 0:
            target_h = H // 8 * 8
            target_w = W // 8 * 8
            images = images.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W)
            images = F.interpolate(images, size=(target_h, target_w), mode='bilinear', align_corners=False)
            images = images.view(B, T, C, target_h, target_w).permute(0, 2, 1, 3, 4)

        latents = self.vae.encode(images).latent_dist.sample()
        latents_mean = (
            torch.tensor(self.vae.config.latents_mean)
            .view(1, self.vae.config.z_dim, 1, 1, 1)
            .to(latents.device, latents.dtype)
        )
        latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
            latents.device, latents.dtype
        )
        latents = (latents - latents_mean) * latents_std
        return latents.to(device, dtype=dtype)

    def get_model_has_grad(self):
        return self.model.proj_out.weight.requires_grad

    def get_te_has_grad(self):
        return self.text_encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad

    def save_model(self, output_path, meta, save_dtype):
        # only save the transformer
        transformer: WanTransformer3DModel = unwrap_model(self.model)
        transformer.save_pretrained(
            save_directory=os.path.join(output_path, 'transformer'),
            safe_serialization=True,
        )

        meta_path = os.path.join(output_path, 'aitk_meta.yaml')
        with open(meta_path, 'w') as f:
            yaml.dump(meta, f)
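
    # Flow-matching loss target: the model predicts the velocity from the clean latents
    # toward the noise, so the target is (noise - latents) rather than the noise itself.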
    def get_loss_target(self, *args, **kwargs):
        noise = kwargs.get('noise')
        batch = kwargs.get('batch')
        if batch is None:
            raise ValueError("Batch is not provided")
        if noise is None:
            raise ValueError("Noise is not provided")
        return (noise - batch.latents).detach()
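
    # LoRA state dicts are saved in the original Wan key naming and converted back to the
    # diffusers naming on load (presumably for compatibility with other Wan LoRA tooling).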
    def convert_lora_weights_before_save(self, state_dict):
        return convert_to_original(state_dict)

    def convert_lora_weights_before_load(self, state_dict):
        return convert_to_diffusers(state_dict)

    def get_base_model_version(self):
        return "wan_2.1"
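

# Rough usage sketch (illustrative only; exact ModelConfig fields and the repo id are
# assumptions, not part of this file):
#   config = ModelConfig(name_or_path="Wan-AI/Wan2.1-T2V-1.3B-Diffusers")
#   model = Wan21(device="cuda", model_config=config)
#   model.load_model()
#   pipe = model.get_generation_pipeline()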