# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# //     http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.
from typing import List, Optional, Tuple, Union

import torch
from einops import rearrange
from omegaconf import DictConfig, ListConfig
from torch import Tensor

from common.config import create_object
from common.decorators import log_on_entry, log_runtime
from common.diffusion import (
    classifier_free_guidance_dispatcher,
    create_sampler_from_config,
    create_sampling_timesteps_from_config,
    create_schedule_from_config,
)
from common.distributed import (
    get_device,
    get_global_rank,
)
from common.distributed.meta_init_utils import (
    meta_non_persistent_buffer_init_fn,
)
# from common.fs import download
from models.dit_v2 import na


class VideoDiffusionInfer():
    def __init__(self, config: DictConfig):
        self.config = config

    def get_condition(self, latent: Tensor, latent_blur: Tensor, task: str) -> Tensor:
        t, h, w, c = latent.shape
        cond = torch.zeros([t, h, w, c + 1], device=latent.device, dtype=latent.dtype)
        if task == "t2v" or t == 1:
            # t2i or t2v generation.
            if task == "sr":
                cond[:, ..., :-1] = latent_blur[:]
                cond[:, ..., -1:] = 1.0
            return cond
        if task == "i2v":
            # i2v generation.
            cond[:1, ..., :-1] = latent[:1]
            cond[:1, ..., -1:] = 1.0
            return cond
        if task == "v2v":
            # v2v frame extension.
            cond[:2, ..., :-1] = latent[:2]
            cond[:2, ..., -1:] = 1.0
            return cond
        if task == "sr":
            # sr generation.
            cond[:, ..., :-1] = latent_blur[:]
            cond[:, ..., -1:] = 1.0
            return cond
        raise NotImplementedError
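
    # Note on the layout produced by get_condition (derived from the code above): the
    # returned tensor has c + 1 channels, where the first c channels carry the conditioning
    # latent (first frame for i2v, first two frames for v2v, the blurred latent for sr) and
    # the last channel is a 0/1 mask marking which frames are conditioned. For pure
    # t2i/t2v generation the condition stays all-zero.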

    def configure_dit_model(self, device="cpu", checkpoint=None):
        # Load dit checkpoint.
        # For fast init & resume,
        #   when training from scratch, rank0 init DiT on cpu, then sync to other ranks with FSDP.
        #   otherwise, all ranks init DiT on meta device, then load_state_dict with assign=True.
        if self.config.dit.get("init_with_meta_device", False):
            init_device = "cpu" if get_global_rank() == 0 and checkpoint is None else "meta"
        else:
            init_device = "cpu"

        # Create dit model.
        with torch.device(init_device):
            self.dit = create_object(self.config.dit.model)
        self.dit.set_gradient_checkpointing(self.config.dit.gradient_checkpoint)

        if checkpoint:
            state = torch.load(checkpoint, map_location="cpu", mmap=True)
            loading_info = self.dit.load_state_dict(state, strict=True, assign=True)
            print(f"Loading pretrained ckpt from {checkpoint}")
            print(f"Loading info: {loading_info}")
            self.dit = meta_non_persistent_buffer_init_fn(self.dit)

        if device in [get_device(), "cuda"]:
            self.dit.to(get_device())

        # Print model size.
        num_params = sum(p.numel() for p in self.dit.parameters() if p.requires_grad)
        print(f"DiT trainable parameters: {num_params:,}")

    def configure_vae_model(self):
        # Create vae model.
        dtype = getattr(torch, self.config.vae.dtype)
        self.vae = create_object(self.config.vae.model)
        self.vae.requires_grad_(False).eval()
        self.vae.to(device=get_device(), dtype=dtype)

        # Load vae checkpoint.
        state = torch.load(
            self.config.vae.checkpoint, map_location=get_device(), mmap=True
        )
        self.vae.load_state_dict(state)

        # Set causal slicing.
        if hasattr(self.vae, "set_causal_slicing") and hasattr(self.config.vae, "slicing"):
            self.vae.set_causal_slicing(**self.config.vae.slicing)
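
    # VAE config keys read in this file (collected from the code; values and defaults come
    # from the project config): vae.model, vae.checkpoint, vae.dtype, vae.slicing,
    # vae.scaling_factor, vae.shifting_factor (default 0.0), vae.use_sample (default True),
    # vae.grouping, plus vae.model.temporal_downsample_factor and
    # vae.model.spatial_downsample_factor used by timestep_transform below.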

    # ------------------------------ Diffusion ------------------------------ #

    def configure_diffusion(self):
        self.schedule = create_schedule_from_config(
            config=self.config.diffusion.schedule,
            device=get_device(),
        )
        self.sampling_timesteps = create_sampling_timesteps_from_config(
            config=self.config.diffusion.timesteps.sampling,
            schedule=self.schedule,
            device=get_device(),
        )
        self.sampler = create_sampler_from_config(
            config=self.config.diffusion.sampler,
            schedule=self.schedule,
            timesteps=self.sampling_timesteps,
        )
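
    # Diffusion-related config keys consumed in this file (from the code, not an exhaustive
    # schema): diffusion.schedule, diffusion.timesteps.sampling, diffusion.timesteps.transform,
    # diffusion.sampler, diffusion.cfg.scale, diffusion.cfg.rescale, and diffusion.cfg.partial
    # (defaulting to 1).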

    # -------------------------------- Helper ------------------------------- #

    def vae_encode(self, samples: List[Tensor]) -> List[Tensor]:
        use_sample = self.config.vae.get("use_sample", True)
        latents = []
        if len(samples) > 0:
            device = get_device()
            dtype = getattr(torch, self.config.vae.dtype)
            scale = self.config.vae.scaling_factor
            shift = self.config.vae.get("shifting_factor", 0.0)

            if isinstance(scale, ListConfig):
                scale = torch.tensor(scale, device=device, dtype=dtype)
            if isinstance(shift, ListConfig):
                shift = torch.tensor(shift, device=device, dtype=dtype)

            # Group samples of the same shape to batches if enabled.
            if self.config.vae.grouping:
                batches, indices = na.pack(samples)
            else:
                batches = [sample.unsqueeze(0) for sample in samples]

            # Vae process by each group.
            for sample in batches:
                sample = sample.to(device, dtype)
                if hasattr(self.vae, "preprocess"):
                    sample = self.vae.preprocess(sample)
                if use_sample:
                    latent = self.vae.encode(sample).latent
                else:
                    # Deterministic vae encode, only used for i2v inference (optionally).
                    latent = self.vae.encode(sample).posterior.mode().squeeze(2)
                latent = latent.unsqueeze(2) if latent.ndim == 4 else latent
                latent = rearrange(latent, "b c ... -> b ... c")
                latent = (latent - shift) * scale
                latents.append(latent)

            # Ungroup back to individual latent with the original order.
            if self.config.vae.grouping:
                latents = na.unpack(latents, indices)
            else:
                latents = [latent.squeeze(0) for latent in latents]
        return latents

    def vae_decode(self, latents: List[Tensor]) -> List[Tensor]:
        samples = []
        if len(latents) > 0:
            device = get_device()
            dtype = getattr(torch, self.config.vae.dtype)
            scale = self.config.vae.scaling_factor
            shift = self.config.vae.get("shifting_factor", 0.0)

            if isinstance(scale, ListConfig):
                scale = torch.tensor(scale, device=device, dtype=dtype)
            if isinstance(shift, ListConfig):
                shift = torch.tensor(shift, device=device, dtype=dtype)

            # Group latents of the same shape to batches if enabled.
            if self.config.vae.grouping:
                latents, indices = na.pack(latents)
            else:
                latents = [latent.unsqueeze(0) for latent in latents]

            # Vae process by each group.
            for latent in latents:
                latent = latent.to(device, dtype)
                latent = latent / scale + shift
                latent = rearrange(latent, "b ... c -> b c ...")
                latent = latent.squeeze(2)
                sample = self.vae.decode(latent).sample
                if hasattr(self.vae, "postprocess"):
                    sample = self.vae.postprocess(sample)
                samples.append(sample)

            # Ungroup back to individual sample with the original order.
            if self.config.vae.grouping:
                samples = na.unpack(samples, indices)
            else:
                samples = [sample.squeeze(0) for sample in samples]
        return samples
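
    # Normalization note (derived from the two methods above): vae_encode maps raw VAE
    # latents into model space with (latent - shift) * scale, and vae_decode inverts it
    # with latent / scale + shift before calling the VAE decoder, so both rely on the
    # same scaling_factor / shifting_factor config values.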

    def timestep_transform(self, timesteps: Tensor, latents_shapes: Tensor):
        # Skip if not needed.
        if not self.config.diffusion.timesteps.get("transform", False):
            return timesteps

        # Compute resolution.
        vt = self.config.vae.model.get("temporal_downsample_factor", 4)
        vs = self.config.vae.model.get("spatial_downsample_factor", 8)
        frames = (latents_shapes[:, 0] - 1) * vt + 1
        heights = latents_shapes[:, 1] * vs
        widths = latents_shapes[:, 2] * vs

        # Compute shift factor.
        def get_lin_function(x1, y1, x2, y2):
            m = (y2 - y1) / (x2 - x1)
            b = y1 - m * x1
            return lambda x: m * x + b

        img_shift_fn = get_lin_function(x1=256 * 256, y1=1.0, x2=1024 * 1024, y2=3.2)
        vid_shift_fn = get_lin_function(x1=256 * 256 * 37, y1=1.0, x2=1280 * 720 * 145, y2=5.0)
        shift = torch.where(
            frames > 1,
            vid_shift_fn(heights * widths * frames),
            img_shift_fn(heights * widths),
        )

        # Shift timesteps.
        timesteps = timesteps / self.schedule.T
        timesteps = shift * timesteps / (1 + (shift - 1) * timesteps)
        timesteps = timesteps * self.schedule.T
        return timesteps
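
    # Shift math (restating the code above): timesteps are normalized to [0, 1] by
    # schedule.T, then remapped as t' = s * t / (1 + (s - 1) * t), where s is interpolated
    # linearly in pixel count (image anchors 256x256 -> 1.0 and 1024x1024 -> 3.2; video
    # anchors 256x256x37 -> 1.0 and 1280x720x145 -> 5.0). Since s >= 1, larger resolutions
    # push the sampled timesteps toward the high-noise end of the schedule.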

    def inference(
        self,
        noises: List[Tensor],
        conditions: List[Tensor],
        texts_pos: Union[List[str], List[Tensor], List[Tuple[Tensor]]],
        texts_neg: Union[List[str], List[Tensor], List[Tuple[Tensor]]],
        cfg_scale: Optional[float] = None,
        dit_offload: bool = False,
    ) -> List[Tensor]:
        assert len(noises) == len(conditions) == len(texts_pos) == len(texts_neg)
        batch_size = len(noises)

        # Return if empty.
        if batch_size == 0:
            return []

        # Set cfg scale.
        if cfg_scale is None:
            cfg_scale = self.config.diffusion.cfg.scale

        # Text embeddings.
        assert type(texts_pos[0]) is type(texts_neg[0])
        if isinstance(texts_pos[0], str):
            text_pos_embeds, text_pos_shapes = self.text_encode(texts_pos)
            text_neg_embeds, text_neg_shapes = self.text_encode(texts_neg)
        elif isinstance(texts_pos[0], tuple):
            text_pos_embeds, text_pos_shapes = [], []
            text_neg_embeds, text_neg_shapes = [], []
            for pos in zip(*texts_pos):
                emb, shape = na.flatten(pos)
                text_pos_embeds.append(emb)
                text_pos_shapes.append(shape)
            for neg in zip(*texts_neg):
                emb, shape = na.flatten(neg)
                text_neg_embeds.append(emb)
                text_neg_shapes.append(shape)
        else:
            text_pos_embeds, text_pos_shapes = na.flatten(texts_pos)
            text_neg_embeds, text_neg_shapes = na.flatten(texts_neg)

        # Flatten.
        latents, latents_shapes = na.flatten(noises)
        latents_cond, _ = na.flatten(conditions)

        # Enter eval mode.
        was_training = self.dit.training
        self.dit.eval()

        # Sampling.
        latents = self.sampler.sample(
            x=latents,
            f=lambda args: classifier_free_guidance_dispatcher(
                pos=lambda: self.dit(
                    vid=torch.cat([args.x_t, latents_cond], dim=-1),
                    txt=text_pos_embeds,
                    vid_shape=latents_shapes,
                    txt_shape=text_pos_shapes,
                    timestep=args.t.repeat(batch_size),
                ).vid_sample,
                neg=lambda: self.dit(
                    vid=torch.cat([args.x_t, latents_cond], dim=-1),
                    txt=text_neg_embeds,
                    vid_shape=latents_shapes,
                    txt_shape=text_neg_shapes,
                    timestep=args.t.repeat(batch_size),
                ).vid_sample,
                scale=(
                    cfg_scale
                    if (args.i + 1) / len(self.sampler.timesteps)
                    <= self.config.diffusion.cfg.get("partial", 1)
                    else 1.0
                ),
                rescale=self.config.diffusion.cfg.rescale,
            ),
        )

        # Exit eval mode.
        self.dit.train(was_training)

        # Unflatten.
        latents = na.unflatten(latents, latents_shapes)

        if dit_offload:
            self.dit.to("cpu")

        # Vae decode.
        self.vae.to(get_device())
        samples = self.vae_decode(latents)

        if dit_offload:
            self.dit.to(get_device())

        return samples
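

# ---------------------------------------------------------------------------- #
# Minimal usage sketch. Everything below is illustrative only: the config path,
# checkpoint path, latent shape, and text-embedding width are assumptions made
# for this example, not values defined by this file. Prompts are passed as
# precomputed embedding tensors; raw strings would instead require a
# text_encode helper on the runner.
# ---------------------------------------------------------------------------- #
if __name__ == "__main__":
    from omegaconf import OmegaConf

    config = OmegaConf.load("configs/infer.yaml")  # assumed config location
    runner = VideoDiffusionInfer(config)
    runner.configure_dit_model(device="cuda", checkpoint="checkpoints/dit.pth")  # assumed path
    runner.configure_vae_model()
    runner.configure_diffusion()

    device = get_device()
    # Assumed latent layout (t, h, w, c); real latents would come from vae_encode.
    noise = torch.randn(5, 45, 80, 16, device=device)
    first_frame_latent = torch.randn_like(noise)
    condition = runner.get_condition(first_frame_latent, latent_blur=None, task="i2v")

    # Assumed text-embedding shape; real embeddings come from the text encoder.
    text_pos = [torch.randn(77, 4096, device=device)]
    text_neg = [torch.randn(77, 4096, device=device)]

    samples = runner.inference(
        noises=[noise],
        conditions=[condition],
        texts_pos=text_pos,
        texts_neg=text_neg,
        cfg_scale=6.5,
    )
    print(samples[0].shape)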