# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# //     http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.

from enum import Enum
from typing import Optional

import numpy as np
import torch
from diffusers.models.normalization import RMSNorm
from einops import rearrange
from torch import Tensor, nn

from common.logger import get_logger

logger = get_logger(__name__)


class MemoryState(Enum):
    """
    State[DISABLED]: No memory bank is used.
    State[INITIALIZING]: The model is handling the first clip and
        needs to reset / initialize the memory bank.
    State[ACTIVE]: The memory bank already holds data from previous clips.
    """

    DISABLED = 0
    INITIALIZING = 1
    ACTIVE = 2
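

# Illustrative sketch (assumption, not part of the original module): how a
# causal layer might branch on MemoryState between clips. `cache` stands in
# for a hypothetical per-layer memory tensor maintained by the caller.
def _demo_memory_state(x: Tensor, state: MemoryState, cache: Optional[Tensor]) -> Tensor:
    if state == MemoryState.ACTIVE and cache is not None:
        # Later clips of a stream: prepend cached features for temporal continuity.
        x = torch.cat((cache.to(x), x), dim=2)
    # DISABLED / INITIALIZING: process the clip as-is (the bank is reset upstream).
    return x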


def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """
    Apply a normalization layer with the channel layout it expects:
    channel-last norms (LayerNorm / RMSNorm) see "... c", while channel-first
    norms (GroupNorm / BatchNorm) see "(b t) c h w" for video tensors.
    """
    if isinstance(norm_layer, (nn.LayerNorm, RMSNorm)):
        if x.ndim == 4:
            x = rearrange(x, "b c h w -> b h w c")
            x = norm_layer(x)
            x = rearrange(x, "b h w c -> b c h w")
            return x
        if x.ndim == 5:
            x = rearrange(x, "b c t h w -> b t h w c")
            x = norm_layer(x)
            x = rearrange(x, "b t h w c -> b c t h w")
            return x
    if isinstance(norm_layer, (nn.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)):
        if x.ndim <= 4:
            return norm_layer(x)
        if x.ndim == 5:
            t = x.size(2)
            x = rearrange(x, "b c t h w -> (b t) c h w")
            x = norm_layer(x)
            x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
            return x
    raise NotImplementedError(f"Unsupported norm layer: {type(norm_layer)}")
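

# Illustrative sketch (assumption, not part of the original module): the
# wrapper routes both image (4D) and video (5D) tensors through norms that
# expect different channel layouts, returning the original "b c ..." layout.
def _demo_causal_norm_wrapper():
    video = torch.randn(2, 8, 4, 16, 16)  # b c t h w
    assert causal_norm_wrapper(nn.LayerNorm(8), video).shape == video.shape
    assert causal_norm_wrapper(nn.GroupNorm(2, 8), video).shape == video.shape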


def remove_head(tensor: Tensor, times: int = 1) -> Tensor:
    """
    Remove the duplicated first-frame features in the up-sampling process,
    keeping a single copy of the first frame.
    """
    if times == 0:
        return tensor
    return torch.cat(tensors=(tensor[:, :, :1], tensor[:, :, times + 1 :]), dim=2)
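

# Illustrative sketch (assumption, not part of the original module): with
# times=2, remove_head keeps frame 0 and drops the next `times` frames,
# undoing the duplication that extend_head (below) introduces.
def _demo_remove_head():
    x = torch.arange(5.0).reshape(1, 1, 5, 1, 1)  # frames [0, 1, 2, 3, 4]
    y = remove_head(x, times=2)
    assert y[0, 0, :, 0, 0].tolist() == [0.0, 3.0, 4.0]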


def extend_head(
    tensor: Tensor, times: Optional[int] = 2, memory: Optional[Tensor] = None
) -> Tensor:
    """
    When memory is None:
        - Duplicate the first-frame features in the down-sampling process.
    When memory is not None:
        - Prepend the memory features to the input features to keep temporal consistency.
    """
    if times == 0:
        return tensor
    if memory is not None:
        return torch.cat((memory.to(tensor), tensor), dim=2)
    else:
        tile_repeat = np.ones(tensor.ndim).astype(int)
        tile_repeat[2] = times
        return torch.cat(tensors=(torch.tile(tensor[:, :, :1], list(tile_repeat)), tensor), dim=2)
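

# Illustrative sketch (assumption, not part of the original module):
# extend_head and remove_head are inverses for the same `times`, which keeps
# the first frame aligned across down- and up-sampling; the memory branch
# instead prepends cached frames from a previous clip.
def _demo_extend_head():
    x = torch.randn(1, 4, 3, 8, 8)  # b c t h w
    y = extend_head(x, times=2)     # frames [f0, f0, f0, f1, f2]
    assert y.size(2) == 5
    assert torch.equal(remove_head(y, times=2), x)
    mem = torch.randn(1, 4, 2, 8, 8)
    assert extend_head(x, memory=mem).size(2) == 5  # [m0, m1, f0, f1, f2]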


def inflate_weight(weight_2d: torch.Tensor, weight_3d: torch.Tensor, inflation_mode: str):
    """
    Inflate a 2D convolution weight tensor into a 3D one.
    Parameters:
        weight_2d: The weight tensor of the 2D conv to be inflated.
        weight_3d: The weight tensor of the 3D conv to be initialized.
        inflation_mode: "replicate" spreads the 2D kernel evenly across the
            temporal axis; "constant" zero-initializes the kernel and places
            the 2D weights at the last temporal slice (a causal initialization).
    """
    assert inflation_mode in ["constant", "replicate"]
    assert weight_3d.shape[:2] == weight_2d.shape[:2]
    with torch.no_grad():
        if inflation_mode == "replicate":
            depth = weight_3d.size(2)
            weight_3d.copy_(weight_2d.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth)
        else:
            weight_3d.fill_(0.0)
            weight_3d[:, :, -1].copy_(weight_2d)
    return weight_3d
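

# Illustrative sketch (assumption, not part of the original module):
# "replicate" averages the 2D kernel across time, so a temporally-constant
# input reproduces the original 2D response; "constant" puts the 2D kernel at
# the last (most recent) temporal slice and zeros elsewhere.
def _demo_inflate_weight():
    w2d = torch.randn(8, 4, 3, 3)     # (out, in, h, w)
    w3d = torch.empty(8, 4, 3, 3, 3)  # (out, in, t, h, w)
    inflate_weight(w2d, w3d, "replicate")
    assert torch.allclose(w3d.sum(dim=2), w2d)
    inflate_weight(w2d, w3d, "constant")
    assert torch.equal(w3d[:, :, -1], w2d)
    assert w3d[:, :, :-1].abs().sum().item() == 0.0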


def inflate_bias(bias_2d: torch.Tensor, bias_3d: torch.Tensor, inflation_mode: str):
    """
    Inflate a 2D convolution bias tensor into a 3D one.
    Parameters:
        bias_2d: The bias tensor of the 2D conv to be inflated.
        bias_3d: The bias tensor of the 3D conv to be initialized.
        inflation_mode: Placeholder to keep the signature aligned with `inflate_weight`.
    """
    assert bias_3d.shape == bias_2d.shape
    with torch.no_grad():
        bias_3d.copy_(bias_2d)
    return bias_3d


def modify_state_dict(layer, state_dict, prefix, inflate_weight_fn, inflate_bias_fn):
    """
    The main function to inflate 2D convolution parameters in a state dict to 3D.
    """
    weight_name = prefix + "weight"
    bias_name = prefix + "bias"
    if weight_name in state_dict:
        weight_2d = state_dict[weight_name]
        if weight_2d.dim() == 4:
            # Assuming the 2D weights are 4D tensors (out_channels, in_channels, h, w).
            weight_3d = inflate_weight_fn(
                weight_2d=weight_2d,
                weight_3d=layer.weight,
                inflation_mode=layer.inflation_mode,
            )
            state_dict[weight_name] = weight_3d
        else:
            # It's already a 3D state dict; do not inflate either weight or bias.
            return state_dict
    if bias_name in state_dict:
        bias_2d = state_dict[bias_name]
        if bias_2d.dim() == 1:
            # Assuming the 2D biases are 1D tensors (out_channels,).
            bias_3d = inflate_bias_fn(
                bias_2d=bias_2d,
                bias_3d=layer.bias,
                inflation_mode=layer.inflation_mode,
            )
            state_dict[bias_name] = bias_3d
    return state_dict
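

# Illustrative sketch (assumption, not shown in this file): modify_state_dict
# fits the shape of a torch `load_state_dict` pre-hook, so an inflated conv
# layer could rewrite incoming 2D checkpoints on the fly. The class below and
# its `inflation_mode` attribute are hypothetical; only the hook wiring is the
# point. `_register_load_state_dict_pre_hook` is a private PyTorch API.
class _InflatedConv3dSketch(nn.Conv3d):
    def __init__(self, *args, inflation_mode: str = "replicate", **kwargs):
        super().__init__(*args, **kwargs)
        self.inflation_mode = inflation_mode
        self._register_load_state_dict_pre_hook(self._inflate_hook)

    def _inflate_hook(
        self, state_dict, prefix, local_metadata, strict,
        missing_keys, unexpected_keys, error_msgs,
    ):
        # Replace 2D weights/biases in the checkpoint with inflated 3D ones.
        modify_state_dict(self, state_dict, prefix, inflate_weight, inflate_bias)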