diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..4adb8457c7aa61071ad2b9539f92d653d7a89c79 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +results/* \ No newline at end of file diff --git a/__pycache__/attn_ctrl.cpython-310.pyc b/__pycache__/attn_ctrl.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28b0d4e1805e5802ac1f66c07efe54cc0fdd02e1 Binary files /dev/null and b/__pycache__/attn_ctrl.cpython-310.pyc differ diff --git a/__pycache__/inference.cpython-310.pyc b/__pycache__/inference.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c2e130911eb7e8d50b3ce892fd2c0af5dc48699 Binary files /dev/null and b/__pycache__/inference.cpython-310.pyc differ diff --git a/__pycache__/train.cpython-310.pyc b/__pycache__/train.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7937d9e16c2016afa7b03b1d0d459d1e224848d Binary files /dev/null and b/__pycache__/train.cpython-310.pyc differ diff --git a/app.py b/app.py index 0e3eba6a46de34bd76ae6a4e128ccdffb6d0c54f..0f26e174d0c1397cf81dd1d049b74c8ba97ed499 100644 --- a/app.py +++ b/app.py @@ -14,7 +14,7 @@ from inference import inference as inference_main def train_model(video, config): output_dir = 'results' os.makedirs(output_dir, exist_ok=True) - cur_save_dir = os.path.join(output_dir, str(len(os.listdir(output_dir))).zfill(2)) + cur_save_dir = os.path.join(output_dir, 'custom') config.dataset.single_video_path = video config.train.output_dir = cur_save_dir @@ -100,6 +100,12 @@ def update_preview_video(checkpoint_dir): if __name__ == "__main__": + + if os.path.exists('results/custom'): + os.system('rm -rf results/custom') + if os.path.exists('outputs'): + os.system('rm -rf outputs') + inject_motion_embeddings_combinations = ['down 1280','up 1280','down 640','up 640'] default_motion_embeddings_combinations = ['down 1280','up 1280'] diff --git a/assets/train/car_turn.mp4 b/assets/train/car_turn.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..20d5ad3bb23695c846a400bb07712d75694c07fd Binary files /dev/null and b/assets/train/car_turn.mp4 differ diff --git a/assets/train/dolly_zoom_out.mp4 b/assets/train/dolly_zoom_out.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..f77062ce40e626f7fc9b1719c1e87267be3564df Binary files /dev/null and b/assets/train/dolly_zoom_out.mp4 differ diff --git a/assets/train/orbit_shot.mp4 b/assets/train/orbit_shot.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..d99b9b57f66fc697de8b5352fc2a50c1d95b4625 Binary files /dev/null and b/assets/train/orbit_shot.mp4 differ diff --git a/assets/train/pan_up.mp4 b/assets/train/pan_up.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..a0dd11e35076ff21ed65adb76806d896671139f4 Binary files /dev/null and b/assets/train/pan_up.mp4 differ diff --git a/assets/train/run_up.mp4 b/assets/train/run_up.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..7f3c4cfaaa1cea6acdc913468b274474cb6a4f93 Binary files /dev/null and b/assets/train/run_up.mp4 differ diff --git a/assets/train/santa_dance.mp4 b/assets/train/santa_dance.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..d14e4609460e01aace2c1ea0a5f852a7b0c0490a Binary files /dev/null and b/assets/train/santa_dance.mp4 differ diff --git a/assets/train/train_ride.mp4 b/assets/train/train_ride.mp4 new file mode 100644 index 
0000000000000000000000000000000000000000..9f5dc3f74639a9976f85f738e48099983f37e19e Binary files /dev/null and b/assets/train/train_ride.mp4 differ diff --git a/assets/train/walk.mp4 b/assets/train/walk.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..2a4f6529dce2b46dc9e0e8360b75d1929575fbbf Binary files /dev/null and b/assets/train/walk.mp4 differ diff --git a/attn_ctrl.py b/attn_ctrl.py new file mode 100644 index 0000000000000000000000000000000000000000..d59be00e37044502986b87ac016ec62204bc9401 --- /dev/null +++ b/attn_ctrl.py @@ -0,0 +1,264 @@ +import abc + +LOW_RESOURCE = False +import torch +import cv2 +import torch +import os +import numpy as np +from collections import defaultdict +from functools import partial +from typing import Any, Dict, Optional + +def register_attention_control(unet, config=None): + + def BasicTransformerBlock_forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.FloatTensor: + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.norm_type == "ada_norm_zero": + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm1(hidden_states) + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif self.norm_type == "ada_norm_single": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + norm_hidden_states = norm_hidden_states.squeeze(1) + else: + raise ValueError("Incorrect norm used") + + # save the origin_hidden_states w/o pos_embed, for the use of motion v embedding + origin_hidden_states = None + if self.pos_embed is not None or hasattr(self.attn1,'vSpatial'): + origin_hidden_states = norm_hidden_states.clone() + if cross_attention_kwargs is None: + cross_attention_kwargs = {} + cross_attention_kwargs["origin_hidden_states"] = origin_hidden_states + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + + # 1. Retrieve lora scale. + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + # 2. 
Prepare GLIGEN inputs + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + if self.norm_type == "ada_norm_zero": + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.norm_type == "ada_norm_single": + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 2.5 GLIGEN Control + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + + # 3. Cross-Attention + if self.attn2 is not None: + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm2(hidden_states, timestep) + elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm2(hidden_states) + elif self.norm_type == "ada_norm_single": + # For PixArt norm2 isn't applied here: + # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 + norm_hidden_states = hidden_states + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) + else: + raise ValueError("Incorrect norm") + + if self.pos_embed is not None and self.norm_type != "ada_norm_single": + # save the origin_hidden_states + origin_hidden_states = norm_hidden_states.clone() + norm_hidden_states = self.pos_embed(norm_hidden_states) + cross_attention_kwargs["origin_hidden_states"] = origin_hidden_states + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + # delete the origin_hidden_states + if cross_attention_kwargs is not None and "origin_hidden_states" in cross_attention_kwargs: + cross_attention_kwargs.pop("origin_hidden_states") + + # 4. 
Feed-forward + # i2vgen doesn't have this norm 🤷‍♂️ + if self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif not self.norm_type == "ada_norm_single": + norm_hidden_states = self.norm3(hidden_states) + + if self.norm_type == "ada_norm_zero": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self.norm_type == "ada_norm_single": + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + ff_output = _chunked_feed_forward( + self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale + ) + else: + ff_output = self.ff(norm_hidden_states, scale=lora_scale) + + if self.norm_type == "ada_norm_zero": + ff_output = gate_mlp.unsqueeze(1) * ff_output + elif self.norm_type == "ada_norm_single": + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + + def temp_attn_forward(self, additional_info=None): + to_out = self.to_out + if type(to_out) is torch.nn.modules.container.ModuleList: + to_out = self.to_out[0] + else: + to_out = self.to_out + + def forward(hidden_states, encoder_hidden_states=None, attention_mask=None,temb=None,origin_hidden_states=None): + + residual = hidden_states + + if self.spatial_norm is not None: + hidden_states = self.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif self.norm_cross: + encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states) + + query = self.to_q(hidden_states) + key = self.to_k(encoder_hidden_states) + + # strategies to manipulate the motion value embedding + if additional_info is not None: + # empirically, in the inference stage of camera motion + # discarding the motion value embedding improves the text similarity of the generated video + if additional_info['removeMFromV']: + value = self.to_v(origin_hidden_states) + elif hasattr(self,'vSpatial'): + # during inference, the debiasing operation helps to generate more diverse videos + # refer to the 'Figure.3 Right' in the paper for more details + if additional_info['vSpatial_frameSubtraction']: + value = self.to_v(self.vSpatial.forward_frameSubtraction(origin_hidden_states)) + # during training, do not apply debias operation for motion learning + else: + value = self.to_v(self.vSpatial(origin_hidden_states)) + else: + value = self.to_v(origin_hidden_states) + else: + value = self.to_v(encoder_hidden_states) + + + query = self.head_to_batch_dim(query) + key = self.head_to_batch_dim(key) + value = self.head_to_batch_dim(value) + + attention_probs = self.get_attention_scores(query, key, attention_mask) + + hidden_states = torch.bmm(attention_probs, 
value) + hidden_states = self.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = to_out(hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if self.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / self.rescale_output_factor + + return hidden_states + return forward + + def register_recr(net_, count, name, config=None): + + if net_.__class__.__name__ == 'BasicTransformerBlock': + BasicTransformerBlock_forward_ = partial(BasicTransformerBlock_forward, net_) + net_.forward = BasicTransformerBlock_forward_ + + if net_.__class__.__name__ == 'Attention': + block_name = name.split('.attn')[0] + if config is not None and block_name in set([l.split('.attn')[0].split('.pos_embed')[0] for l in config.model.embedding_layers]): + additional_info = {} + additional_info['layer_name'] = name + additional_info['removeMFromV'] = config.strategy.get('removeMFromV', False) + additional_info['vSpatial_frameSubtraction'] = config.strategy.get('vSpatial_frameSubtraction', False) + net_.forward = temp_attn_forward(net_, additional_info) + print('register Motion V embedding at ', block_name) + return count + 1 + else: + return count + + elif hasattr(net_, 'children'): + for net_name, net__ in dict(net_.named_children()).items(): + count = register_recr(net__, count, name = name + '.' + net_name, config=config) + return count + + sub_nets = unet.named_children() + + for net in sub_nets: + register_recr(net[1], 0,name = net[0], config=config) + + + diff --git a/configs/config.yaml b/configs/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3e4e548731f7a4132933f47ae285d4d94b3328ca --- /dev/null +++ b/configs/config.yaml @@ -0,0 +1,67 @@ +model: + type: unet + pretrained_model_path: cerspense/zeroscope_v2_576w + motion_embeddings: + combinations: + - - down + - 1280 + - - up + - 1280 + # unet can be either 'videoCrafter2' or 'zeroscope_v2_576w', the former produces better video quality + unet: videoCrafter2 + +train: + output_dir: ./results + validation_steps: 2000 + checkpointing_steps: 50 + checkpointing_start: 200 + train_batch_size: 1 + max_train_steps: 400 + gradient_accumulation_steps: 1 + cache_latents: true + cached_latent_dir: null + logger_type: tensorboard + mixed_precision: fp16 + use_8bit_adam: false + resume_from_checkpoint: null + resume_step: null + +dataset: + type: + - single_video + single_video_path: ./assets/car-roundabout-24.mp4 + single_video_prompt: 'A car turnaround in a city street' + width: 576 + height: 320 + n_sample_frames: 24 + fps: 8 + sample_start_idx: 1 + frame_step: 1 + use_bucketing: false + use_caption: false + +loss: + type: BaseLoss + learning_rate: 0.02 + lr_scheduler: constant + lr_warmup_steps: 0 + +noise_init: + type: BlendInit + noise_prior: 0.5 + +val: + prompt: + - "A skateboard slides along a city lane" + negative_prompt: "" + sample_preview: true + width: 576 + height: 320 + num_frames: 24 + num_inference_steps: 30 + guidance_scale: 12.0 + seeds: [0] + +strategy: + vSpatial_frameSubtraction: false + removeMFromV: false \ No newline at end of file diff --git a/dataset/__init__.py b/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..299b4027b126cc3bbc7c2e5735206f6ea9bd0c74 --- /dev/null +++ b/dataset/__init__.py @@ -0,0 +1,5 @@ +from .cached_dataset import CachedDataset +from .image_dataset import ImageDataset +from .single_video_dataset import 
SingleVideoDataset +from .video_folder_dataset import VideoFolderDataset +from .video_json_dataset import VideoJsonDataset \ No newline at end of file diff --git a/dataset/__pycache__/__init__.cpython-310.pyc b/dataset/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80044644b2264681fe79ade0fff3b5367711e34c Binary files /dev/null and b/dataset/__pycache__/__init__.cpython-310.pyc differ diff --git a/dataset/__pycache__/cached_dataset.cpython-310.pyc b/dataset/__pycache__/cached_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9287528e2134fc623386997711b4130e96389ba Binary files /dev/null and b/dataset/__pycache__/cached_dataset.cpython-310.pyc differ diff --git a/dataset/__pycache__/image_dataset.cpython-310.pyc b/dataset/__pycache__/image_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d70055347c572313115dd737bf0c9e5e6910ff7 Binary files /dev/null and b/dataset/__pycache__/image_dataset.cpython-310.pyc differ diff --git a/dataset/__pycache__/single_video_dataset.cpython-310.pyc b/dataset/__pycache__/single_video_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24e40d5b02e3e8c4a0b8e6ad420a7052567753e4 Binary files /dev/null and b/dataset/__pycache__/single_video_dataset.cpython-310.pyc differ diff --git a/dataset/__pycache__/video_folder_dataset.cpython-310.pyc b/dataset/__pycache__/video_folder_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13ab42c21e53794e662fbc9a4f4f45b14d771f20 Binary files /dev/null and b/dataset/__pycache__/video_folder_dataset.cpython-310.pyc differ diff --git a/dataset/__pycache__/video_json_dataset.cpython-310.pyc b/dataset/__pycache__/video_json_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d183e25c7308eb5307e9e1468ff6ffb85eef469d Binary files /dev/null and b/dataset/__pycache__/video_json_dataset.cpython-310.pyc differ diff --git a/dataset/cached_dataset.py b/dataset/cached_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ac4d52cd7b8b2b53dc36f4545cdad24f8d5e9086 --- /dev/null +++ b/dataset/cached_dataset.py @@ -0,0 +1,17 @@ +from utils.dataset_utils import * + +class CachedDataset(Dataset): + def __init__(self,cache_dir: str = ''): + self.cache_dir = cache_dir + self.cached_data_list = self.get_files_list() + + def get_files_list(self): + tensors_list = [f"{self.cache_dir}/{x}" for x in os.listdir(self.cache_dir) if x.endswith('.pt')] + return sorted(tensors_list) + + def __len__(self): + return len(self.cached_data_list) + + def __getitem__(self, index): + cached_latent = torch.load(self.cached_data_list[index], map_location='cuda:0') + return cached_latent \ No newline at end of file diff --git a/dataset/image_dataset.py b/dataset/image_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..2a061680ce4548e6846c68af6b4637739a12342e --- /dev/null +++ b/dataset/image_dataset.py @@ -0,0 +1,95 @@ +from utils.dataset_utils import * + +class ImageDataset(Dataset): + + def __init__( + self, + tokenizer = None, + width: int = 256, + height: int = 256, + base_width: int = 256, + base_height: int = 256, + use_caption: bool = False, + image_dir: str = '', + single_img_prompt: str = '', + use_bucketing: bool = False, + fallback_prompt: str = '', + **kwargs + ): + self.tokenizer = tokenizer + self.img_types = (".png", ".jpg", ".jpeg", 
'.bmp') + self.use_bucketing = use_bucketing + + self.image_dir = self.get_images_list(image_dir) + self.fallback_prompt = fallback_prompt + + self.use_caption = use_caption + self.single_img_prompt = single_img_prompt + + self.width = width + self.height = height + + def get_images_list(self, image_dir): + if os.path.exists(image_dir): + imgs = [x for x in os.listdir(image_dir) if x.endswith(self.img_types)] + full_img_dir = [] + + for img in imgs: + full_img_dir.append(f"{image_dir}/{img}") + + return sorted(full_img_dir) + + return [''] + + def image_batch(self, index): + train_data = self.image_dir[index] + img = train_data + + try: + img = torchvision.io.read_image(img, mode=torchvision.io.ImageReadMode.RGB) + except: + img = T.transforms.PILToTensor()(Image.open(img).convert("RGB")) + + width = self.width + height = self.height + + if self.use_bucketing: + _, h, w = img.shape + width, height = sensible_buckets(width, height, w, h) + + resize = T.transforms.Resize((height, width), antialias=True) + + img = resize(img) + img = repeat(img, 'c h w -> f c h w', f=16) + + prompt = get_text_prompt( + file_path=train_data, + text_prompt=self.single_img_prompt, + fallback_prompt=self.fallback_prompt, + ext_types=self.img_types, + use_caption=True + ) + prompt_ids = get_prompt_ids(prompt, self.tokenizer) + + return img, prompt, prompt_ids + + @staticmethod + def __getname__(): return 'image' + + def __len__(self): + # Image directory + if os.path.exists(self.image_dir[0]): + return len(self.image_dir) + else: + return 0 + + def __getitem__(self, index): + img, prompt, prompt_ids = self.image_batch(index) + example = { + "pixel_values": (img / 127.5 - 1.0), + "prompt_ids": prompt_ids[0], + "text_prompt": prompt, + 'dataset': self.__getname__() + } + + return example \ No newline at end of file diff --git a/dataset/single_video_dataset.py b/dataset/single_video_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..65584e0c1d1aa2910326e578d353600b00cc45d1 --- /dev/null +++ b/dataset/single_video_dataset.py @@ -0,0 +1,106 @@ +from utils.dataset_utils import * + +class SingleVideoDataset(Dataset): + def __init__( + self, + tokenizer = None, + width: int = 256, + height: int = 256, + n_sample_frames: int = 4, + frame_step: int = 1, + single_video_path: str = "", + single_video_prompt: str = "", + use_caption: bool = False, + use_bucketing: bool = False, + **kwargs + ): + self.tokenizer = tokenizer + self.use_bucketing = use_bucketing + self.frames = [] + self.index = 1 + + self.vid_types = (".mp4", ".avi", ".mov", ".webm", ".flv", ".mjpeg") + self.n_sample_frames = n_sample_frames + self.frame_step = frame_step + + self.single_video_path = single_video_path + self.single_video_prompt = single_video_prompt + + self.width = width + self.height = height + + def create_video_chunks(self): + vr = decord.VideoReader(self.single_video_path) + vr_range = range(0, len(vr), self.frame_step) + + self.frames = list(self.chunk(vr_range, self.n_sample_frames)) + return self.frames + + def chunk(self, it, size): + it = iter(it) + return iter(lambda: tuple(islice(it, size)), ()) + + def get_frame_batch(self, vr, resize=None): + index = self.index + frames = vr.get_batch(self.frames[self.index]) + + if type(frames) == decord.ndarray.NDArray: + frames = torch.from_numpy(frames.asnumpy()) + + video = rearrange(frames, "f h w c -> f c h w") + + if resize is not None: video = resize(video) + return video + + def get_frame_buckets(self, vr): + h, w, c = vr[0].shape + width, height = 
sensible_buckets(self.width, self.height, w, h) + resize = T.transforms.Resize((height, width), antialias=True) + + return resize + + def process_video_wrapper(self, vid_path): + video, vr = process_video( + vid_path, + self.use_bucketing, + self.width, + self.height, + self.get_frame_buckets, + self.get_frame_batch + ) + + return video, vr + + def single_video_batch(self, index): + train_data = self.single_video_path + self.index = index + + if train_data.endswith(self.vid_types): + video, _ = self.process_video_wrapper(train_data) + + prompt = self.single_video_prompt + prompt_ids = get_prompt_ids(prompt, self.tokenizer) + + return video, prompt, prompt_ids + else: + raise ValueError(f"Single video is not a video type. Types: {self.vid_types}") + + @staticmethod + def __getname__(): return 'single_video' + + def __len__(self): + + return len(self.create_video_chunks()) + + def __getitem__(self, index): + + video, prompt, prompt_ids = self.single_video_batch(index) + + example = { + "pixel_values": (video / 127.5 - 1.0), + "prompt_ids": prompt_ids[0], + "text_prompt": prompt, + 'dataset': self.__getname__() + } + + return example \ No newline at end of file diff --git a/dataset/video_folder_dataset.py b/dataset/video_folder_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..2d049f40987d3ee332931e1b03ece2f32158078c --- /dev/null +++ b/dataset/video_folder_dataset.py @@ -0,0 +1,90 @@ +from utils.dataset_utils import * + +class VideoFolderDataset(Dataset): + def __init__( + self, + tokenizer=None, + width: int = 256, + height: int = 256, + n_sample_frames: int = 16, + fps: int = 8, + path: str = "./data", + fallback_prompt: str = "", + use_bucketing: bool = False, + **kwargs + ): + self.tokenizer = tokenizer + self.use_bucketing = use_bucketing + + self.fallback_prompt = fallback_prompt + + self.video_files = glob(f"{path}/*.mp4") + + self.width = width + self.height = height + + self.n_sample_frames = n_sample_frames + self.fps = fps + + def get_frame_buckets(self, vr): + h, w, c = vr[0].shape + width, height = sensible_buckets(self.width, self.height, w, h) + resize = T.transforms.Resize((height, width), antialias=True) + + return resize + + def get_frame_batch(self, vr, resize=None): + n_sample_frames = self.n_sample_frames + native_fps = vr.get_avg_fps() + + every_nth_frame = max(1, round(native_fps / self.fps)) + every_nth_frame = min(len(vr), every_nth_frame) + + effective_length = len(vr) // every_nth_frame + if effective_length < n_sample_frames: + n_sample_frames = effective_length + + effective_idx = random.randint(0, (effective_length - n_sample_frames)) + idxs = every_nth_frame * np.arange(effective_idx, effective_idx + n_sample_frames) + + video = vr.get_batch(idxs) + video = rearrange(video, "f h w c -> f c h w") + + if resize is not None: video = resize(video) + return video, vr + + def process_video_wrapper(self, vid_path): + video, vr = process_video( + vid_path, + self.use_bucketing, + self.width, + self.height, + self.get_frame_buckets, + self.get_frame_batch + ) + return video, vr + + def get_prompt_ids(self, prompt): + return self.tokenizer( + prompt, + truncation=True, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + @staticmethod + def __getname__(): return 'folder' + + def __len__(self): + return len(self.video_files) + + def __getitem__(self, index): + + video, _ = self.process_video_wrapper(self.video_files[index]) + + prompt = self.fallback_prompt + + prompt_ids = 
self.get_prompt_ids(prompt) + + return {"pixel_values": (video[0] / 127.5 - 1.0), "prompt_ids": prompt_ids[0], "text_prompt": prompt, 'dataset': self.__getname__()} \ No newline at end of file diff --git a/dataset/video_json_dataset.py b/dataset/video_json_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..9b693c25b566d02615b79deea830e6c75f060aba --- /dev/null +++ b/dataset/video_json_dataset.py @@ -0,0 +1,183 @@ +from utils.dataset_utils import * + +# https://github.com/ExponentialML/Video-BLIP2-Preprocessor +class VideoJsonDataset(Dataset): + def __init__( + self, + tokenizer = None, + width: int = 256, + height: int = 256, + n_sample_frames: int = 4, + sample_start_idx: int = 1, + frame_step: int = 1, + json_path: str ="", + json_data = None, + vid_data_key: str = "video_path", + preprocessed: bool = False, + use_bucketing: bool = False, + **kwargs + ): + self.vid_types = (".mp4", ".avi", ".mov", ".webm", ".flv", ".mjpeg") + self.use_bucketing = use_bucketing + self.tokenizer = tokenizer + self.preprocessed = preprocessed + + self.vid_data_key = vid_data_key + self.train_data = self.load_from_json(json_path, json_data) + + self.width = width + self.height = height + + self.n_sample_frames = n_sample_frames + self.sample_start_idx = sample_start_idx + self.frame_step = frame_step + + def build_json(self, json_data): + extended_data = [] + for data in json_data['data']: + for nested_data in data['data']: + self.build_json_dict( + data, + nested_data, + extended_data + ) + json_data = extended_data + return json_data + + def build_json_dict(self, data, nested_data, extended_data): + clip_path = nested_data['clip_path'] if 'clip_path' in nested_data else None + + extended_data.append({ + self.vid_data_key: data[self.vid_data_key], + 'frame_index': nested_data['frame_index'], + 'prompt': nested_data['prompt'], + 'clip_path': clip_path + }) + + def load_from_json(self, path, json_data): + try: + with open(path) as jpath: + print(f"Loading JSON from {path}") + json_data = json.load(jpath) + + return self.build_json(json_data) + + except: + self.train_data = [] + print("Non-existant JSON path. Skipping.") + + def validate_json(self, base_path, path): + return os.path.exists(f"{base_path}/{path}") + + def get_frame_range(self, vr): + return get_video_frames( + vr, + self.sample_start_idx, + self.frame_step, + self.n_sample_frames + ) + + def get_vid_idx(self, vr, vid_data=None): + frames = self.n_sample_frames + + if vid_data is not None: + idx = vid_data['frame_index'] + else: + idx = self.sample_start_idx + + return idx + + def get_frame_buckets(self, vr): + _, h, w = vr[0].shape + width, height = sensible_buckets(self.width, self.height, h, w) + # width, height = self.width, self.height + resize = T.transforms.Resize((height, width), antialias=True) + + return resize + + def get_frame_batch(self, vr, resize=None): + frame_range = self.get_frame_range(vr) + frames = vr.get_batch(frame_range) + video = rearrange(frames, "f h w c -> f c h w") + + if resize is not None: video = resize(video) + return video + + def process_video_wrapper(self, vid_path): + video, vr = process_video( + vid_path, + self.use_bucketing, + self.width, + self.height, + self.get_frame_buckets, + self.get_frame_batch + ) + + return video, vr + + def train_data_batch(self, index): + + # If we are training on individual clips. 
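+        # (a preprocessed entry, as produced by the Video-BLIP2-Preprocessor linked at the top of this file,
+        #  carries its own 'clip_path' and 'prompt')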
+ if 'clip_path' in self.train_data[index] and \ + self.train_data[index]['clip_path'] is not None: + + vid_data = self.train_data[index] + + clip_path = vid_data['clip_path'] + + # Get video prompt + prompt = vid_data['prompt'] + + video, _ = self.process_video_wrapper(clip_path) + + prompt_ids = get_prompt_ids(prompt, self.tokenizer) + + return video, prompt, prompt_ids + + # Assign train data + train_data = self.train_data[index] + + # Get the frame of the current index. + self.sample_start_idx = train_data['frame_index'] + + # Initialize resize + resize = None + + video, vr = self.process_video_wrapper(train_data[self.vid_data_key]) + + # Get video prompt + prompt = train_data['prompt'] + vr.seek(0) + + prompt_ids = get_prompt_ids(prompt, self.tokenizer) + + return video, prompt, prompt_ids + + @staticmethod + def __getname__(): return 'json' + + def __len__(self): + if self.train_data is not None: + return len(self.train_data) + else: + return 0 + + def __getitem__(self, index): + + # Initialize variables + video = None + prompt = None + prompt_ids = None + + # Use default JSON training + if self.train_data is not None: + video, prompt, prompt_ids = self.train_data_batch(index) + + example = { + "pixel_values": (video / 127.5 - 1.0), + "prompt_ids": prompt_ids[0], + "text_prompt": prompt, + 'dataset': self.__getname__() + } + + return example \ No newline at end of file diff --git a/inference.py b/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..3b4d9b5ca5eee78f991e5b27bffd2fa4e80017ac --- /dev/null +++ b/inference.py @@ -0,0 +1,133 @@ +import torch +from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler +from train import export_to_video +from models.unet.motion_embeddings import load_motion_embeddings +from noise_init.blend_init import BlendInit +from noise_init.blend_freq_init import BlendFreqInit +from noise_init.fft_init import FFTInit +from noise_init.freq_init import FreqInit +from attn_ctrl import register_attention_control +import numpy as np +import os +from omegaconf import OmegaConf + +def get_pipe(embedding_dir='baseline',config=None,noisy_latent=None, video_round=None): + + # load video generation model + pipe = DiffusionPipeline.from_pretrained(config.model.pretrained_model_path,torch_dtype=torch.float16) + + # use videocrafterv2 unet + if config.model.unet == 'videoCrafter2': + from models.unet.unet_3d_condition import UNet3DConditionModel + # unet = UNet3DConditionModel.from_pretrained("adamdad/videocrafterv2_diffusers",subfolder='unet',torch_dtype=torch.float16) + unet = UNet3DConditionModel.from_pretrained("adamdad/videocrafterv2_diffusers",torch_dtype=torch.float16) + pipe.unet = unet + + # pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) + pipe.enable_model_cpu_offload() + + # memory optimization + pipe.enable_vae_slicing() + + # if 'vanilla' not in embedding_dir: + + noisy_latent = torch.load(f'{embedding_dir}/cached_latents/cached_0.pt')['inversion_noise'][None,] + if video_round is None: + motion_embed = torch.load(f'{embedding_dir}/motion_embed.pt') + else: + motion_embed = torch.load(f'{embedding_dir}/{video_round}/motion_embed.pt') + load_motion_embeddings( + pipe.unet, + motion_embed, + ) + config.model['embedding_layers'] = list(motion_embed.keys()) + + return pipe, config, noisy_latent + +def inference(embedding_dir='vanilla', + video_round=None, + prompt=None, + save_dir=None, + seed=None, + motion_type=None, + inference_steps=30 + ): + + # check motion type is valid + if 
motion_type != 'camera' and \ + motion_type != 'object' and \ + motion_type != 'hybrid': + raise ValueError('Invalid motion type') + + if seed is None: + seed = 0 + + # load motion embedding + noisy_latent = None + + config = OmegaConf.load(f'{embedding_dir}/config.yaml') + + + # different motion type assigns different strategy + if motion_type == 'camera': + config['strategy']['removeMFromV'] = True + + elif motion_type == 'object' or motion_type == 'hybrid': + config['strategy']['vSpatial_frameSubtraction'] = True + + + pipe, config, noisy_latent = get_pipe(embedding_dir=embedding_dir,config=config,noisy_latent=noisy_latent,video_round=video_round) + n_frames = config.val.num_frames + + shape = (config.val.height,config.val.width) + os.makedirs(save_dir,exist_ok=True) + + + cur_save_dir = f'{save_dir}/{"_".join(prompt.split())}.mp4' + + register_attention_control(pipe.unet,config=config) + + if noisy_latent is not None: + torch.manual_seed(seed) + noise = torch.randn_like(noisy_latent) + init_noise = BlendInit(noisy_latent, noise, noise_prior=0.5) + else: + init_noise = None + + input_init_noise = init_noise.clone() if not init_noise is None else None + video_frames = pipe( + prompt=prompt, + num_inference_steps=inference_steps, + guidance_scale=12, + height=shape[0], + width=shape[1], + num_frames=n_frames, + generator=torch.Generator("cuda").manual_seed(seed), + latents=input_init_noise, + ).frames[0] + + video_path = export_to_video(video_frames,output_video_path=cur_save_dir,fps=8) + + return video_path + + +if __name__ =="__main__": + + prompts = ["A skateboard slides along a city lane", + "A tank is running in the desert.", + "A toy train chugs around a roundabout tree"] + + + embedding_dir = './results' + video_round = 'checkpoint-250' + save_dir = f'outputs' + + inference( + embedding_dir=embedding_dir, + prompt=prompts, + video_round=video_round, + save_dir=save_dir, + motion_type='hybrid', + seed=100 + ) + diff --git a/loss/__init__.py b/loss/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3881a220dd37fed32d2d9ea400f92cf3a0bf62e3 --- /dev/null +++ b/loss/__init__.py @@ -0,0 +1,4 @@ +from .base_loss import BaseLoss +from .debiased_hybrid_loss import DebiasedHybridLoss +from .debiased_temporal_loss import DebiasedTemporalLoss +from .motion_distillation_loss import MotionDistillationLoss diff --git a/loss/__pycache__/__init__.cpython-310.pyc b/loss/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71fa6d8f4ec56f37a805d31f4efb6e3eaed31646 Binary files /dev/null and b/loss/__pycache__/__init__.cpython-310.pyc differ diff --git a/loss/__pycache__/base_loss.cpython-310.pyc b/loss/__pycache__/base_loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c13ffe1163662aaa63f1fbb8588fdd7cfc8e2e0 Binary files /dev/null and b/loss/__pycache__/base_loss.cpython-310.pyc differ diff --git a/loss/__pycache__/debiased_hybrid_loss.cpython-310.pyc b/loss/__pycache__/debiased_hybrid_loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86651f2d8886a5425cc5f26b1293880f211d28c1 Binary files /dev/null and b/loss/__pycache__/debiased_hybrid_loss.cpython-310.pyc differ diff --git a/loss/__pycache__/debiased_temporal_loss.cpython-310.pyc b/loss/__pycache__/debiased_temporal_loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66b3f8fca935fe1006937482c779fcf1804f3597 Binary files /dev/null and 
b/loss/__pycache__/debiased_temporal_loss.cpython-310.pyc differ diff --git a/loss/__pycache__/motion_distillation_loss.cpython-310.pyc b/loss/__pycache__/motion_distillation_loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5e0b97e6145db83b1af1f0eb1fbfd78f368210c Binary files /dev/null and b/loss/__pycache__/motion_distillation_loss.cpython-310.pyc differ diff --git a/loss/base_loss.py b/loss/base_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..10fe1dc0f953e116340c6aae7b0295377939bc33 --- /dev/null +++ b/loss/base_loss.py @@ -0,0 +1,75 @@ +import torch +import torch.nn.functional as F +from utils.func_utils import tensor_to_vae_latent, sample_noise + +def BaseLoss( + train_loss_temporal, + accelerator, + optimizers, + lr_schedulers, + unet, + vae, + text_encoder, + noise_scheduler, + batch, + step, + config + ): + cache_latents = config.train.cache_latents + + if not cache_latents: + latents = tensor_to_vae_latent(batch["pixel_values"], vae) + else: + latents = batch["latents"] + + # Sample noise that we'll add to the latents + # use_offset_noise = use_offset_noise and not rescale_schedule + + noise = sample_noise(latents, 0.1, False) + bsz = latents.shape[0] + + # Sample a random timestep for each video + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # *Potentially* Fixes gradient checkpointing training. + # See: https://github.com/prigoyal/pytorch_memonger/blob/master/tutorial/Checkpointing_for_PyTorch_models.ipynb + # if kwargs.get('eval_train', False): + # unet.eval() + # text_encoder.eval() + + # Encode text embeddings + token_ids = batch['prompt_ids'] + encoder_hidden_states = text_encoder(token_ids)[0] + detached_encoder_state = encoder_hidden_states.clone().detach() + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + encoder_hidden_states = detached_encoder_state + + + # optimization + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample + loss_temporal = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + avg_loss_temporal = accelerator.gather(loss_temporal.repeat(config.train.train_batch_size)).mean() + train_loss_temporal += avg_loss_temporal.item() / config.train.gradient_accumulation_steps + + accelerator.backward(loss_temporal) + optimizers[0].step() + lr_schedulers[0].step() + + return loss_temporal, train_loss_temporal + diff --git a/loss/debiased_hybrid_loss.py b/loss/debiased_hybrid_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..1ec910a505a75c3ba96126927682ee13d802a0c0 --- /dev/null +++ b/loss/debiased_hybrid_loss.py @@ -0,0 +1,149 @@ +import torch +from torchvision import transforms +import torch.nn.functional as F +import random + +from utils.lora import extract_lora_child_module +from utils.func_utils import tensor_to_vae_latent, sample_noise + +def DebiasedHybridLoss( + train_loss_temporal, + accelerator, + 
optimizers, + lr_schedulers, + unet, + vae, + text_encoder, + noise_scheduler, + batch, + step, + config, + random_hflip_img=False, + spatial_lora_num=1 + ): + mask_spatial_lora = random.uniform(0, 1) < 0.2 + cache_latents = config.train.cache_latents + + + + if not cache_latents: + latents = tensor_to_vae_latent(batch["pixel_values"], vae) + else: + latents = batch["latents"] + + # Sample noise that we'll add to the latents + # use_offset_noise = use_offset_noise and not rescale_schedule + + noise = sample_noise(latents, 0.1, False) + bsz = latents.shape[0] + + # Sample a random timestep for each video + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # *Potentially* Fixes gradient checkpointing training. + # See: https://github.com/prigoyal/pytorch_memonger/blob/master/tutorial/Checkpointing_for_PyTorch_models.ipynb + # if kwargs.get('eval_train', False): + # unet.eval() + # text_encoder.eval() + + # Encode text embeddings + token_ids = batch['prompt_ids'] + encoder_hidden_states = text_encoder(token_ids)[0] + detached_encoder_state = encoder_hidden_states.clone().detach() + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + encoder_hidden_states = detached_encoder_state + + + # optimization + if mask_spatial_lora: + loras = extract_lora_child_module(unet, target_replace_module=["Transformer2DModel"]) + for lora_i in loras: + lora_i.scale = 0. + loss_spatial = None + else: + loras = extract_lora_child_module(unet, target_replace_module=["Transformer2DModel"]) + + if spatial_lora_num == 1: + for lora_i in loras: + lora_i.scale = 1. + else: + for lora_i in loras: + lora_i.scale = 0. + + for lora_idx in range(0, len(loras), spatial_lora_num): + loras[lora_idx + step].scale = 1. + + loras = extract_lora_child_module(unet, target_replace_module=["TransformerTemporalModel"]) + if len(loras) > 0: + for lora_i in loras: + lora_i.scale = 0. 
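+        # With the temporal LoRAs silenced (scale 0), the spatial pass below supervises only a single
+        # randomly chosen frame (optionally horizontally flipped); the temporal loss further down is
+        # computed on the full clip.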
+ + ran_idx = torch.randint(0, noisy_latents.shape[2], (1,)).item() + + if random.uniform(0, 1) < random_hflip_img: + pixel_values_spatial = transforms.functional.hflip( + batch["pixel_values"][:, ran_idx, :, :, :]).unsqueeze(1) + latents_spatial = tensor_to_vae_latent(pixel_values_spatial, vae) + noise_spatial = sample_noise(latents_spatial, 0.1, False) + noisy_latents_input = noise_scheduler.add_noise(latents_spatial, noise_spatial, timesteps) + target_spatial = noise_spatial + model_pred_spatial = unet(noisy_latents_input, timesteps, + encoder_hidden_states=encoder_hidden_states).sample + loss_spatial = F.mse_loss(model_pred_spatial[:, :, 0, :, :].float(), + target_spatial[:, :, 0, :, :].float(), reduction="mean") + else: + noisy_latents_input = noisy_latents[:, :, ran_idx, :, :] + target_spatial = target[:, :, ran_idx, :, :] + model_pred_spatial = unet(noisy_latents_input.unsqueeze(2), timesteps, + encoder_hidden_states=encoder_hidden_states).sample + loss_spatial = F.mse_loss(model_pred_spatial[:, :, 0, :, :].float(), + target_spatial.float(), reduction="mean") + + + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample + loss_temporal = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + beta = 1 + alpha = (beta ** 2 + 1) ** 0.5 + ran_idx = torch.randint(0, model_pred.shape[2], (1,)).item() + model_pred_decent = alpha * model_pred - beta * model_pred[:, :, ran_idx, :, :].unsqueeze(2) + target_decent = alpha * target - beta * target[:, :, ran_idx, :, :].unsqueeze(2) + loss_ad_temporal = F.mse_loss(model_pred_decent.float(), target_decent.float(), reduction="mean") + loss_temporal = loss_temporal + loss_ad_temporal + + avg_loss_temporal = accelerator.gather(loss_temporal.repeat(config.train.train_batch_size)).mean() + train_loss_temporal += avg_loss_temporal.item() / config.train.gradient_accumulation_steps + + if not mask_spatial_lora: + accelerator.backward(loss_spatial, retain_graph=True) + if spatial_lora_num == 1: + optimizers[1].step() + else: + optimizers[step+1].step() + + accelerator.backward(loss_temporal) + optimizers[0].step() + + if spatial_lora_num == 1: + lr_schedulers[1].step() + else: + lr_schedulers[1 + step].step() + + lr_schedulers[0].step() + + return loss_temporal, train_loss_temporal \ No newline at end of file diff --git a/loss/debiased_temporal_loss.py b/loss/debiased_temporal_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..41ba6459fefd89e2893dd3ff701f5d59a9603d1a --- /dev/null +++ b/loss/debiased_temporal_loss.py @@ -0,0 +1,86 @@ +import torch +import torch.nn.functional as F + +from utils.func_utils import tensor_to_vae_latent, sample_noise + +def DebiasedTemporalLoss( + train_loss_temporal, + accelerator, + optimizers, + lr_schedulers, + unet, + vae, + text_encoder, + noise_scheduler, + batch, + step, + config + ): + cache_latents = config.train.cache_latents + + + + if not cache_latents: + latents = tensor_to_vae_latent(batch["pixel_values"], vae) + else: + latents = batch["latents"] + + # Sample noise that we'll add to the latents + # use_offset_noise = use_offset_noise and not rescale_schedule + + noise = sample_noise(latents, 0.1, False) + bsz = latents.shape[0] + + # Sample a random timestep for each video + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + 
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # *Potentially* Fixes gradient checkpointing training. + # See: https://github.com/prigoyal/pytorch_memonger/blob/master/tutorial/Checkpointing_for_PyTorch_models.ipynb + # if kwargs.get('eval_train', False): + # unet.eval() + # text_encoder.eval() + + # Encode text embeddings + token_ids = batch['prompt_ids'] + encoder_hidden_states = text_encoder(token_ids)[0] + detached_encoder_state = encoder_hidden_states.clone().detach() + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + encoder_hidden_states = detached_encoder_state + + + # optimization + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample + loss_temporal = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + beta = 1 + alpha = (beta ** 2 + 1) ** 0.5 + ran_idx = torch.randint(0, model_pred.shape[2], (1,)).item() + model_pred_decent = alpha * model_pred - beta * model_pred[:, :, ran_idx, :, :].unsqueeze(2) + target_decent = alpha * target - beta * target[:, :, ran_idx, :, :].unsqueeze(2) + loss_ad_temporal = F.mse_loss(model_pred_decent.float(), target_decent.float(), reduction="mean") + loss_temporal = loss_temporal + loss_ad_temporal + + avg_loss_temporal = accelerator.gather(loss_temporal.repeat(config.train.train_batch_size)).mean() + train_loss_temporal += avg_loss_temporal.item() / config.train.gradient_accumulation_steps + + accelerator.backward(loss_temporal) + optimizers[0].step() + + lr_schedulers[0].step() + + return loss_temporal, train_loss_temporal \ No newline at end of file diff --git a/loss/motion_distillation_loss.py b/loss/motion_distillation_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..968bcc5749121445518764f56b49c46f7985d8a8 --- /dev/null +++ b/loss/motion_distillation_loss.py @@ -0,0 +1,79 @@ +import torch +import torch.nn.functional as F +from utils.func_utils import tensor_to_vae_latent, sample_noise + +def MotionDistillationLoss( + train_loss_temporal, + accelerator, + optimizers, + lr_schedulers, + unet, + vae, + text_encoder, + noise_scheduler, + batch, + step, + config + ): + cache_latents = config.train.cache_latents + + if not cache_latents: + latents = tensor_to_vae_latent(batch["pixel_values"], vae) + else: + latents = batch["latents"] + + # Sample noise that we'll add to the latents + # use_offset_noise = use_offset_noise and not rescale_schedule + + noise = sample_noise(latents, 0.1, False) + bsz = latents.shape[0] + + # Sample a random timestep for each video + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # *Potentially* Fixes gradient checkpointing training. 
+ # See: https://github.com/prigoyal/pytorch_memonger/blob/master/tutorial/Checkpointing_for_PyTorch_models.ipynb + # if kwargs.get('eval_train', False): + # unet.eval() + # text_encoder.eval() + + # Encode text embeddings + token_ids = batch['prompt_ids'] + encoder_hidden_states = text_encoder(token_ids)[0] + detached_encoder_state = encoder_hidden_states.clone().detach() + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + encoder_hidden_states = detached_encoder_state + + + # optimization + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample + + loss_temporal = 0 + model_pred_reidual = torch.abs(model_pred[:,:,1:,:,:] - model_pred[:,:,:-1,:,:]) + target_residual = torch.abs(target[:, :, 1:, :, :] - target[:, :, :-1, :, :]) + loss_temporal = loss_temporal + (1 - F.cosine_similarity(model_pred_reidual, target_residual, dim=2).mean) + + avg_loss_temporal = accelerator.gather(loss_temporal.repeat(config.train.train_batch_size)).mean() + train_loss_temporal += avg_loss_temporal.item() / config.train.gradient_accumulation_steps + + accelerator.backward(loss_temporal) + optimizers[0].step() + lr_schedulers[0].step() + + return loss_temporal, train_loss_temporal + diff --git a/models/dit/latte_t2v.py b/models/dit/latte_t2v.py new file mode 100644 index 0000000000000000000000000000000000000000..fc96a3040c341bc881c81c1c5e9eb1b1f215f06d --- /dev/null +++ b/models/dit/latte_t2v.py @@ -0,0 +1,990 @@ +import torch + +import os +import json + +from dataclasses import dataclass +from einops import rearrange, repeat +from typing import Any, Dict, Optional, Tuple +from diffusers.models import Transformer2DModel +from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate +from diffusers.models.embeddings import get_1d_sincos_pos_embed_from_grid, ImagePositionalEmbeddings, CaptionProjection, PatchEmbed, CombinedTimestepSizeEmbeddings +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.attention import BasicTransformerBlock +from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear +from diffusers.utils.torch_utils import maybe_allow_in_graph +from diffusers.models.embeddings import SinusoidalPositionalEmbedding +from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormZero +from diffusers.models.attention_processor import Attention +from diffusers.models.activations import GEGLU, GELU, ApproximateGELU + +from dataclasses import dataclass + +import torch +import torch.nn.functional as F +from torch import nn + +@maybe_allow_in_graph +class GatedSelfAttentionDense(nn.Module): + r""" + A gated self-attention dense layer that combines visual features and object features. + + Parameters: + query_dim (`int`): The number of channels in the query. + context_dim (`int`): The number of channels in the context. + n_heads (`int`): The number of heads to use for attention. + d_head (`int`): The number of channels in each head. 
+ """ + + def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int): + super().__init__() + + # we need a linear projection since we need cat visual feature and obj feature + self.linear = nn.Linear(context_dim, query_dim) + + self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head) + self.ff = FeedForward(query_dim, activation_fn="geglu") + + self.norm1 = nn.LayerNorm(query_dim) + self.norm2 = nn.LayerNorm(query_dim) + + self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0))) + self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0))) + + self.enabled = True + + def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor: + if not self.enabled: + return x + + n_visual = x.shape[1] + objs = self.linear(objs) + + x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :] + x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x)) + + return x + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim) + if activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh") + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim) + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(linear_cls(inner_dim, dim_out)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear) + for module in self.net: + if isinstance(module, compatible_cls): + hidden_states = module(hidden_states, scale) + else: + hidden_states = module(hidden_states) + return hidden_states + +@maybe_allow_in_graph +class BasicTransformerBlock_(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. 
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single' + norm_eps: float = 1e-5, + final_dropout: bool = False, + attention_type: str = "default", + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + self.use_ada_layer_norm_single = norm_type == "ada_norm_single" + self.use_layer_norm = norm_type == "layer_norm" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." + ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) + else: + self.pos_embed = None + + # Define 3 blocks. Each block has its own normalization layer. + # 1. 
Self-Attn + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + + # # 2. Cross-Attn + # if cross_attention_dim is not None or double_self_attention: + # # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # # the second cross attention block. + # self.norm2 = ( + # AdaLayerNorm(dim, num_embeds_ada_norm) + # if self.use_ada_layer_norm + # else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + # ) + # self.attn2 = Attention( + # query_dim=dim, + # cross_attention_dim=cross_attention_dim if not double_self_attention else None, + # heads=num_attention_heads, + # dim_head=attention_head_dim, + # dropout=dropout, + # bias=attention_bias, + # upcast_attention=upcast_attention, + # ) # is self-attn if encoder_hidden_states is none + # else: + # self.norm2 = None + # self.attn2 = None + + # 3. Feed-forward + # if not self.use_ada_layer_norm_single: + # self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + + # 4. Fuser + if attention_type == "gated" or attention_type == "gated-text-image": + self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) + + # 5. Scale-shift for PixArt-Alpha. + if self.use_ada_layer_norm_single: + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + ) -> torch.FloatTensor: + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. 
Self-Attention + batch_size = hidden_states.shape[0] + + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + elif self.use_layer_norm: + norm_hidden_states = self.norm1(hidden_states) + elif self.use_ada_layer_norm_single: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + norm_hidden_states = norm_hidden_states.squeeze(1) + else: + raise ValueError("Incorrect norm used") + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + # 1. Retrieve lora scale. + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + # 2. Prepare GLIGEN inputs + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.use_ada_layer_norm_single: + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 2.5 GLIGEN Control + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + + # # 3. Cross-Attention + # if self.attn2 is not None: + # if self.use_ada_layer_norm: + # norm_hidden_states = self.norm2(hidden_states, timestep) + # elif self.use_ada_layer_norm_zero or self.use_layer_norm: + # norm_hidden_states = self.norm2(hidden_states) + # elif self.use_ada_layer_norm_single: + # # For PixArt norm2 isn't applied here: + # # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 + # norm_hidden_states = hidden_states + # else: + # raise ValueError("Incorrect norm") + + # if self.pos_embed is not None and self.use_ada_layer_norm_single is False: + # norm_hidden_states = self.pos_embed(norm_hidden_states) + + # attn_output = self.attn2( + # norm_hidden_states, + # encoder_hidden_states=encoder_hidden_states, + # attention_mask=encoder_attention_mask, + # **cross_attention_kwargs, + # ) + # hidden_states = attn_output + hidden_states + + # 4. 
Feed-forward + # if not self.use_ada_layer_norm_single: + # norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self.use_ada_layer_norm_single: + # norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = self.norm3(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." + ) + + num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size + ff_output = torch.cat( + [ + self.ff(hid_slice, scale=lora_scale) + for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim) + ], + dim=self._chunk_dim, + ) + else: + ff_output = self.ff(norm_hidden_states, scale=lora_scale) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + elif self.use_ada_layer_norm_single: + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + +class AdaLayerNormSingle(nn.Module): + r""" + Norm layer adaptive layer norm single (adaLN-single). + + As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + use_additional_conditions (`bool`): To use additional conditions for normalization or not. + """ + + def __init__(self, embedding_dim: int, use_additional_conditions: bool = False): + super().__init__() + + self.emb = CombinedTimestepSizeEmbeddings( + embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions + ) + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) + + def forward( + self, + timestep: torch.Tensor, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + batch_size: int = None, + hidden_dtype: Optional[torch.dtype] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # No modulation happening here. + embedded_timestep = self.emb(timestep, batch_size=batch_size, hidden_dtype=hidden_dtype, resolution=None, aspect_ratio=None) + return self.linear(self.silu(embedded_timestep)), embedded_timestep + +@dataclass +class Transformer3DModelOutput(BaseOutput): + """ + The output of [`Transformer2DModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability + distributions for the unnoised latent pixels. + """ + + sample: torch.FloatTensor + + +class LatteT2V(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + """ + A 2D Transformer model for image-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. 
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + num_vector_embeds (`int`, *optional*): + The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. + + During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + attention_type: str = "default", + caption_channels: int = None, + video_length: int = 16, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + self.video_length = video_length + + conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv + linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear + + # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = (in_channels is not None) and (patch_size is None) + self.is_input_vectorized = num_vector_embeds is not None + self.is_input_patches = in_channels is not None and patch_size is not None + + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = ( + f"The configuration file of this model: {self.__class__} is outdated. 
`norm_type` is either not set or" + " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." + " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" + " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" + " would be very nice if you could open a Pull request for the `transformer/config.json` file" + ) + deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) + norm_type = "ada_norm" + + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is None." + ) + elif self.is_input_vectorized and self.is_input_patches: + raise ValueError( + f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" + " sure that either `num_vector_embeds` or `num_patches` is None." + ) + elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: + raise ValueError( + f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" + f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." + ) + + # 2. Define input layers + if self.is_input_continuous: + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = linear_cls(in_channels, inner_dim) + else: + self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" + assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" + + self.height = sample_size + self.width = sample_size + self.num_vector_embeds = num_vector_embeds + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width + ) + elif self.is_input_patches: + assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" + + self.height = sample_size + self.width = sample_size + + self.patch_size = patch_size + interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1 + interpolation_scale = max(interpolation_scale, 1) + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + interpolation_scale=interpolation_scale, + ) + + # 3. 
Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + double_self_attention=double_self_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + attention_type=attention_type, + ) + for d in range(num_layers) + ] + ) + + # Define temporal transformers blocks + self.temporal_transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock_( # one attention + inner_dim, + num_attention_heads, # num_attention_heads + attention_head_dim, # attention_head_dim 72 + dropout=dropout, + cross_attention_dim=None, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + double_self_attention=False, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + attention_type=attention_type, + ) + for d in range(num_layers) + ] + ) + + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + if self.is_input_continuous: + # TODO: should use out_channels for continuous projections + if use_linear_projection: + self.proj_out = linear_cls(inner_dim, in_channels) + else: + self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + elif self.is_input_patches and norm_type != "ada_norm_single": + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) + self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + elif self.is_input_patches and norm_type == "ada_norm_single": + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + + # 5. PixArt-Alpha blocks. 
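+        # When norm_type == "ada_norm_single", a single shared adaLN module (the AdaLayerNormSingle defined above, following PixArt-Alpha) supplies the timestep shift/scale/gate for every block instead of a per-block AdaLayerNorm.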
+ self.adaln_single = None + self.use_additional_conditions = False + if norm_type == "ada_norm_single": + self.use_additional_conditions = self.config.sample_size == 128 # False, 128 -> 1024 + # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use + # additional conditions until we find better name + self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions) + + self.caption_projection = None + if caption_channels is not None: + self.caption_projection = CaptionProjection(in_features=caption_channels, hidden_size=inner_dim) + + self.gradient_checkpointing = False + + # define temporal positional embedding + temp_pos_embed = self.get_1d_sincos_temp_embed(inner_dim, video_length) # 1152 hidden size + self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False) + + + def _set_gradient_checkpointing(self, module, value=False): + self.gradient_checkpointing = value + + + def forward( + self, + hidden_states: torch.Tensor, + timestep: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + use_image_num: int = 0, + enable_temporal_attentions: bool = True, + return_dict: bool = True, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, frame, channel, height, width)` if continuous): + Input `hidden_states`. + encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. 
+ return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + input_batch_size, c, frame, h, w = hidden_states.shape + frame = frame - use_image_num + hidden_states = rearrange(hidden_states, 'b c f h w -> (b f) c h w').contiguous() + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: # ndim == 2 means no image joint + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + encoder_attention_mask = repeat(encoder_attention_mask, 'b 1 l -> (b f) 1 l', f=frame).contiguous() + elif encoder_attention_mask is not None and encoder_attention_mask.ndim == 3: # ndim == 3 means image joint + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask_video = encoder_attention_mask[:, :1, ...] + encoder_attention_mask_video = repeat(encoder_attention_mask_video, 'b 1 l -> b (1 f) l', f=frame).contiguous() + encoder_attention_mask_image = encoder_attention_mask[:, 1:, ...] + encoder_attention_mask = torch.cat([encoder_attention_mask_video, encoder_attention_mask_image], dim=1) + encoder_attention_mask = rearrange(encoder_attention_mask, 'b n l -> (b n) l').contiguous().unsqueeze(1) + + + # Retrieve lora scale. + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + # 1. Input + if self.is_input_patches: # here + height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size + num_patches = height * width + + hidden_states = self.pos_embed(hidden_states) # alrady add positional embeddings + + if self.adaln_single is not None: + if self.use_additional_conditions and added_cond_kwargs is None: + raise ValueError( + "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." + ) + # batch_size = hidden_states.shape[0] + batch_size = input_batch_size + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + + # 2. 
Blocks + if self.caption_projection is not None: + batch_size = hidden_states.shape[0] + encoder_hidden_states = self.caption_projection(encoder_hidden_states) # 3 120 1152 + + if use_image_num != 0 and self.training: + encoder_hidden_states_video = encoder_hidden_states[:, :1, ...] + encoder_hidden_states_video = repeat(encoder_hidden_states_video, 'b 1 t d -> b (1 f) t d', f=frame).contiguous() + encoder_hidden_states_image = encoder_hidden_states[:, 1:, ...] + encoder_hidden_states = torch.cat([encoder_hidden_states_video, encoder_hidden_states_image], dim=1) + encoder_hidden_states_spatial = rearrange(encoder_hidden_states, 'b f t d -> (b f) t d').contiguous() + else: + encoder_hidden_states_spatial = repeat(encoder_hidden_states, 'b t d -> (b f) t d', f=frame).contiguous() + + # prepare timesteps for spatial and temporal block + timestep_spatial = repeat(timestep, 'b d -> (b f) d', f=frame + use_image_num).contiguous() + timestep_temp = repeat(timestep, 'b d -> (b p) d', p=num_patches).contiguous() + + for i, (spatial_block, temp_block) in enumerate(zip(self.transformer_blocks, self.temporal_transformer_blocks)): + + if self.training and self.gradient_checkpointing: + hidden_states = torch.utils.checkpoint.checkpoint( + spatial_block, + hidden_states, + attention_mask, + encoder_hidden_states_spatial, + encoder_attention_mask, + timestep_spatial, + cross_attention_kwargs, + class_labels, + use_reentrant=False, + ) + + if enable_temporal_attentions: + hidden_states = rearrange(hidden_states, '(b f) t d -> (b t) f d', b=input_batch_size).contiguous() + + if use_image_num != 0: # image-video joitn training + hidden_states_video = hidden_states[:, :frame, ...] + hidden_states_image = hidden_states[:, frame:, ...] + + if i == 0: + hidden_states_video = hidden_states_video + self.temp_pos_embed + + hidden_states_video = torch.utils.checkpoint.checkpoint( + temp_block, + hidden_states_video, + None, # attention_mask + None, # encoder_hidden_states + None, # encoder_attention_mask + timestep_temp, + cross_attention_kwargs, + class_labels, + use_reentrant=False, + ) + + hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1) + hidden_states = rearrange(hidden_states, '(b t) f d -> (b f) t d', b=input_batch_size).contiguous() + + else: + if i == 0: + hidden_states = hidden_states + self.temp_pos_embed + + hidden_states = torch.utils.checkpoint.checkpoint( + temp_block, + hidden_states, + None, # attention_mask + None, # encoder_hidden_states + None, # encoder_attention_mask + timestep_temp, + cross_attention_kwargs, + class_labels, + use_reentrant=False, + ) + + hidden_states = rearrange(hidden_states, '(b t) f d -> (b f) t d', b=input_batch_size).contiguous() + else: + hidden_states = spatial_block( + hidden_states, + attention_mask, + encoder_hidden_states_spatial, + encoder_attention_mask, + timestep_spatial, + cross_attention_kwargs, + class_labels, + ) + + if enable_temporal_attentions: + + hidden_states = rearrange(hidden_states, '(b f) t d -> (b t) f d', b=input_batch_size).contiguous() + + if use_image_num != 0 and self.training: + hidden_states_video = hidden_states[:, :frame, ...] + hidden_states_image = hidden_states[:, frame:, ...] 
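+                    # After the '(b f) t d -> (b t) f d' rearrange the frame axis is the sequence dimension, so the temporal block attends across frames for each spatial token; only the first 'frame' entries are video frames, while the trailing image frames are excluded here and concatenated back after temporal attention.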
+ + hidden_states_video = temp_block( + hidden_states_video, + None, # attention_mask + None, # encoder_hidden_states + None, # encoder_attention_mask + timestep_temp, + cross_attention_kwargs, + class_labels, + ) + + hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1) + hidden_states = rearrange(hidden_states, '(b t) f d -> (b f) t d', b=input_batch_size).contiguous() + + else: + if i == 0: + hidden_states = hidden_states + self.temp_pos_embed + + hidden_states = temp_block( + hidden_states, + None, # attention_mask + None, # encoder_hidden_states + None, # encoder_attention_mask + timestep_temp, + cross_attention_kwargs, + class_labels, + ) + + hidden_states = rearrange(hidden_states, '(b t) f d -> (b f) t d', b=input_batch_size).contiguous() + + + if self.is_input_patches: + if self.config.norm_type != "ada_norm_single": + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + hidden_states = self.proj_out_2(hidden_states) + elif self.config.norm_type == "ada_norm_single": + embedded_timestep = repeat(embedded_timestep, 'b d -> (b f) d', f=frame + use_image_num).contiguous() + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + + # unpatchify + if self.adaln_single is None: + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + output = rearrange(output, '(b f) c h w -> b c f h w', b=input_batch_size).contiguous() + + if not return_dict: + return (output,) + + return Transformer3DModelOutput(sample=output) + + def get_1d_sincos_temp_embed(self, embed_dim, length): + pos = torch.arange(0, length).unsqueeze(1) + return get_1d_sincos_pos_embed_from_grid(embed_dim, pos) + + @classmethod + def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, **kwargs): + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + + + config_file = os.path.join(pretrained_model_path, 'config.json') + if not os.path.isfile(config_file): + raise RuntimeError(f"{config_file} does not exist") + with open(config_file, "r") as f: + config = json.load(f) + + model = cls.from_config(config, **kwargs) + + # model_files = [ + # os.path.join(pretrained_model_path, 'diffusion_pytorch_model.bin'), + # os.path.join(pretrained_model_path, 'diffusion_pytorch_model.safetensors') + # ] + + # model_file = None + + # for fp in model_files: + # if os.path.exists(fp): + # model_file = fp + + # if not model_file: + # raise RuntimeError(f"{model_file} does not exist") + + # if model_file.split(".")[-1] == "safetensors": + # from safetensors import safe_open + # state_dict = {} + # with safe_open(model_file, framework="pt", device="cpu") as f: + # for key in f.keys(): + # state_dict[key] = f.get_tensor(key) + # else: + # state_dict = torch.load(model_file, map_location="cpu") + + # for k, v in model.state_dict().items(): + # if 
'temporal_transformer_blocks' in k: + # state_dict.update({k: v}) + + # model.load_state_dict(state_dict) + + return model \ No newline at end of file diff --git a/models/unet/__pycache__/motion_embeddings.cpython-310.pyc b/models/unet/__pycache__/motion_embeddings.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a635fe4425617ac036c56f7dc701e0f58973f24 Binary files /dev/null and b/models/unet/__pycache__/motion_embeddings.cpython-310.pyc differ diff --git a/models/unet/__pycache__/unet_3d_blocks.cpython-310.pyc b/models/unet/__pycache__/unet_3d_blocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0377d25ccc4c684bb5d29a62146f2f8847682eed Binary files /dev/null and b/models/unet/__pycache__/unet_3d_blocks.cpython-310.pyc differ diff --git a/models/unet/__pycache__/unet_3d_condition.cpython-310.pyc b/models/unet/__pycache__/unet_3d_condition.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7972674d04c11f8ccacfa048544ec863ca67a2b5 Binary files /dev/null and b/models/unet/__pycache__/unet_3d_condition.cpython-310.pyc differ diff --git a/models/unet/motion_embeddings.py b/models/unet/motion_embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..c3c757166536b16199d01fe569fb63f67e0dc7bf --- /dev/null +++ b/models/unet/motion_embeddings.py @@ -0,0 +1,283 @@ +import re +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +class MotionEmbedding(nn.Module): + + def __init__(self, embed_dim: int = None, max_seq_length: int = 32, wh: int = 1): + super().__init__() + self.embed = nn.Parameter(torch.zeros(wh, max_seq_length, embed_dim)) + print('register spatial motion embedding with', wh) + + self.scale = 1.0 + self.trained_length = -1 + + def set_scale(self, scale: float): + self.scale = scale + + def set_lengths(self, trained_length: int): + if trained_length > self.embed.shape[1] or trained_length <= 0: + raise ValueError("Trained length is out of bounds") + self.trained_length = trained_length + + def forward(self, x): + _, seq_length, _ = x.shape # seq_length here is the target sequence length for x + # print('seq_length',seq_length) + # Assuming self.embed is [batch, frames, dim] + embeddings = self.embed[:, :seq_length] # Initial slice, may not be necessary depending on the interpolation logic + + # Check if interpolation is needed + if self.trained_length != -1 and seq_length != self.trained_length: + # Interpolate embeddings to match x's sequence length + # Ensure embeddings is [batch, dim, frames] for 1D interpolation across frames + embeddings = embeddings.permute(0, 2, 1) # Now [batch, dim, frames] + embeddings = F.interpolate(embeddings, size=(seq_length,), mode='linear', align_corners=False) + embeddings = embeddings.permute(0, 2, 1) # Revert to [batch, frames, dim] + + # Ensure the interpolated embeddings match the sequence length of x + if embeddings.shape[1] != seq_length: + raise ValueError(f"Interpolated embeddings sequence length {embeddings.shape[1]} does not match x's sequence length {seq_length}") + + if x.shape[0] != embeddings.shape[0]: + x = x + embeddings.repeat(x.shape[0]//embeddings.shape[0],1,1) * self.scale + else: + # Now embeddings should have the shape [batch, seq_length, dim] matching x + x = x + embeddings * self.scale # Assuming broadcasting is desired over the batch and dim dimensions + + return x + + + def forward_average(self, x): + _, seq_length, _ = x.shape # seq_length here is the target 
sequence length for x + # print('seq_length',seq_length) + # Assuming self.embed is [batch, frames, dim] + embeddings = self.embed[:, :seq_length] # Initial slice, may not be necessary depending on the interpolation logic + + # Check if interpolation is needed + if self.trained_length != -1 and seq_length != self.trained_length: + # Interpolate embeddings to match x's sequence length + # Ensure embeddings is [batch, dim, frames] for 1D interpolation across frames + embeddings = embeddings.permute(0, 2, 1) # Now [batch, dim, frames] + embeddings = F.interpolate(embeddings, size=(seq_length,), mode='linear', align_corners=False) + embeddings = embeddings.permute(0, 2, 1) # Revert to [batch, frames, dim] + + # Ensure the interpolated embeddings match the sequence length of x + if embeddings.shape[1] != seq_length: + raise ValueError(f"Interpolated embeddings sequence length {embeddings.shape[1]} does not match x's sequence length {seq_length}") + + embeddings_mean = embeddings.mean(dim=1, keepdim=True) + embeddings = embeddings - embeddings_mean + if x.shape[0] != embeddings.shape[0]: + x = x + embeddings.repeat(x.shape[0]//embeddings.shape[0],1,1) * self.scale + else: + # Now embeddings should have the shape [batch, seq_length, dim] matching x + x = x + embeddings * self.scale # Assuming broadcasting is desired over the batch and dim dimensions + + return x + + def forward_frameSubtraction(self, x): + _, seq_length, _ = x.shape # seq_length here is the target sequence length for x + # print('seq_length',seq_length) + # Assuming self.embed is [batch, frames, dim] + embeddings = self.embed[:, :seq_length] # Initial slice, may not be necessary depending on the interpolation logic + + # Check if interpolation is needed + if self.trained_length != -1 and seq_length != self.trained_length: + # Interpolate embeddings to match x's sequence length + # Ensure embeddings is [batch, dim, frames] for 1D interpolation across frames + embeddings = embeddings.permute(0, 2, 1) # Now [batch, dim, frames] + embeddings = F.interpolate(embeddings, size=(seq_length,), mode='linear', align_corners=False) + embeddings = embeddings.permute(0, 2, 1) # Revert to [batch, frames, dim] + + # Ensure the interpolated embeddings match the sequence length of x + if embeddings.shape[1] != seq_length: + raise ValueError(f"Interpolated embeddings sequence length {embeddings.shape[1]} does not match x's sequence length {seq_length}") + + embeddings_subtraction = embeddings[:,1:] - embeddings[:,:-1] + + embeddings = embeddings.clone().detach() + embeddings[:,1:] = embeddings_subtraction + + # first frame minus mean + # embeddings[:,0:1] = embeddings[:,0:1] - embeddings.mean(dim=1, keepdim=True) + + if x.shape[0] != embeddings.shape[0]: + x = x + embeddings.repeat(x.shape[0]//embeddings.shape[0],1,1) * self.scale + else: + # Now embeddings should have the shape [batch, seq_length, dim] matching x + x = x + embeddings * self.scale # Assuming broadcasting is desired over the batch and dim dimensions + + return x + +class MotionEmbeddingSpatial(nn.Module): + + def __init__(self, h: int = None, w: int = None, embed_dim: int = None, max_seq_length: int = 32): + super().__init__() + self.embed = nn.Parameter(torch.zeros(h*w, max_seq_length, embed_dim)) + self.scale = 1.0 + self.trained_length = -1 + + def set_scale(self, scale: float): + self.scale = scale + + def set_lengths(self, trained_length: int): + if trained_length > self.embed.shape[1] or trained_length <= 0: + raise ValueError("Trained length is out of bounds") + 
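# remember how many frames the embedding was trained on; forward() linearly interpolates along the frame axis whenever the incoming sequence length differs from this value +        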
self.trained_length = trained_length + + def forward(self, x): + _, seq_length, _ = x.shape # seq_length here is the target sequence length for x + + # Assuming self.embed is [batch, frames, dim] + embeddings = self.embed[:, :seq_length] # Initial slice, may not be necessary depending on the interpolation logic + + # Check if interpolation is needed + if self.trained_length != -1 and seq_length != self.trained_length: + # Interpolate embeddings to match x's sequence length + # Ensure embeddings is [batch, dim, frames] for 1D interpolation across frames + embeddings = embeddings.permute(0, 2, 1) # Now [batch, dim, frames] + embeddings = F.interpolate(embeddings, size=(seq_length,), mode='linear', align_corners=False) + embeddings = embeddings.permute(0, 2, 1) # Revert to [batch, frames, dim] + + # Ensure the interpolated embeddings match the sequence length of x + if embeddings.shape[1] != seq_length: + raise ValueError(f"Interpolated embeddings sequence length {embeddings.shape[1]} does not match x's sequence length {seq_length}") + + if x.shape[0] != embeddings.shape[0]: + x = x + embeddings.repeat(x.shape[0]//embeddings.shape[0],1,1) * self.scale + else: + # Now embeddings should have the shape [batch, seq_length, dim] matching x + x = x + embeddings * self.scale # Assuming broadcasting is desired over the batch and dim dimensions + + return x + + +def inject_motion_embeddings(model, combinations=None, config=None): + spatial_shape=np.array([config.dataset.height,config.dataset.width]) + shape32 = np.ceil(spatial_shape/32).astype(int) + shape16 = np.ceil(spatial_shape/16).astype(int) + spatial_name = 'vSpatial' + replacement_dict = {} + # support for 32 frames + max_seq_length = 32 + inject_layers = [] + for name, module in model.named_modules(): + + # check if the module is temp_attention + PETemporal = '.temp_attentions.' in name + + if not(PETemporal and re.search(r'transformer_blocks\.\d+$', name)): + continue + + if not ([name.split('_')[0], module.norm1.normalized_shape[0]] in combinations): + continue + + replacement_dict[f'{name}.pos_embed'] = MotionEmbedding(max_seq_length=max_seq_length, embed_dim=module.norm1.normalized_shape[0]).to(dtype=model.dtype, device=model.device) + + replacement_keys = list(set(replacement_dict.keys())) + temp_attn_list = [name.replace('pos_embed','attn1') for name in replacement_keys] + \ + [name.replace('pos_embed','attn2') for name in replacement_keys] + embed_dims = [replacement_dict[replacement_keys[i]].embed.shape[2] for i in range(len(replacement_keys))] + + for temp_attn_index,temp_attn in enumerate(temp_attn_list): + place_in_net = temp_attn.split('_')[0] + pattern = r'(\d+)\.temp_attentions' + match = re.search(pattern, temp_attn) + place_in_net = temp_attn.split('_')[0] + index_in_net = match.group(1) + h,w = None,None + if place_in_net == 'up': + if index_in_net == "1": + h, w = shape32 + elif index_in_net == "2": + h, w = shape16 + elif place_in_net == 'down': + if index_in_net == "1": + h, w = shape16 + elif index_in_net == "2": + h, w = shape32 + + replacement_dict[temp_attn+'.'+spatial_name] = \ + MotionEmbedding( + wh=h*w, + embed_dim=embed_dims[temp_attn_index%len(replacement_keys)] + ).to(dtype=model.dtype, device=model.device) + + for name, new_module in replacement_dict.items(): + parent_name = name.rsplit('.', 1)[0] if '.' 
in name else '' + module_name = name.rsplit('.', 1)[-1] + parent_module = model + if parent_name: + parent_module = dict(model.named_modules())[parent_name] + + if [parent_name.split('_')[0], new_module.embed.shape[-1]] in combinations: + inject_layers.append(name) + setattr(parent_module, module_name, new_module) + + inject_layers = list(set(inject_layers)) + for name in inject_layers: + print(f"Injecting motion embedding at {name}") + + parameters_list = [] + for name, para in model.named_parameters(): + if 'pos_embed' in name or spatial_name in name: + parameters_list.append(para) + para.requires_grad = True + else: + para.requires_grad = False + + return parameters_list, inject_layers + +def save_motion_embeddings(model, file_path): + # Extract motion embedding from all instances of MotionEmbedding + motion_embeddings = { + name: module.embed + for name, module in model.named_modules() + if isinstance(module, MotionEmbedding) or isinstance(module, MotionEmbeddingSpatial) + } + # Save the motion embeddings to the specified file path + torch.save(motion_embeddings, file_path) + +def load_motion_embeddings(model, saved_embeddings): + for key, embedding in saved_embeddings.items(): + # Extract parent module and module name from the key + parent_name = key.rsplit('.', 1)[0] if '.' in key else '' + module_name = key.rsplit('.', 1)[-1] + + # Retrieve the parent module + parent_module = model + if parent_name: + parent_module = dict(model.named_modules())[parent_name] + + # Create a new MotionEmbedding instance with the correct dimensions + + new_module = MotionEmbedding(wh = embedding.shape[0],embed_dim=embedding.shape[-1], max_seq_length=embedding.shape[-2]) + + # Properly assign the loaded embeddings to the 'embed' parameter wrapped in nn.Parameter + # Ensure the embedding is on the correct device and has the correct dtype + new_module.embed = nn.Parameter(embedding.to(dtype=model.dtype, device=model.device)) + + # Replace the corresponding module in the model with the new MotionEmbedding instance + setattr(parent_module, module_name, new_module) + +def set_motion_embedding_scale(model, scale_value): + # Iterate over all modules in the model + for _, module in model.named_modules(): + # Check if the module is an instance of MotionEmbedding + if isinstance(module, MotionEmbedding): + # Set the scale attribute to the specified value + module.scale = scale_value + +def set_motion_embedding_length(model, trained_length): + # Iterate over all modules in the model + for _, module in model.named_modules(): + # Check if the module is an instance of MotionEmbedding + if isinstance(module, MotionEmbedding): + # Set the length to the specified value + module.trained_length = trained_length + + + + + diff --git a/models/unet/unet_3d_blocks.py b/models/unet/unet_3d_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..8e246dbf36b985afa4537bef3810f95289975415 --- /dev/null +++ b/models/unet/unet_3d_blocks.py @@ -0,0 +1,842 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.utils.checkpoint as checkpoint +from torch import nn +from diffusers.models.resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D +from diffusers.models.transformer_2d import Transformer2DModel +from diffusers.models.transformer_temporal import TransformerTemporalModel + +# Assign gradient checkpoint function to simple variable for readability. +g_c = checkpoint.checkpoint + +def use_temporal(module, num_frames, x): + if num_frames == 1: + if isinstance(module, TransformerTemporalModel): + return {"sample": x} + else: + return x + +def custom_checkpoint(module, mode=None): + if mode == None: raise ValueError('Mode for gradient checkpointing cannot be none.') + custom_forward = None + + if mode == 'resnet': + def custom_forward(hidden_states, temb): + inputs = module(hidden_states, temb) + return inputs + + if mode == 'attn': + def custom_forward( + hidden_states, + encoder_hidden_states=None, + cross_attention_kwargs=None + ): + inputs = module( + hidden_states, + encoder_hidden_states, + cross_attention_kwargs + ) + return inputs + + if mode == 'temp': + def custom_forward(hidden_states, num_frames=None): + inputs = use_temporal(module, num_frames, hidden_states) + if inputs is None: inputs = module( + hidden_states, + num_frames=num_frames + ) + return inputs + + return custom_forward + +def transformer_g_c(transformer, sample, num_frames): + sample = g_c(custom_checkpoint(transformer, mode='temp'), + sample, num_frames, use_reentrant=False + )['sample'] + + return sample + +def cross_attn_g_c( + attn, + temp_attn, + resnet, + temp_conv, + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + temb, + num_frames, + inverse_temp=False + ): + + def ordered_g_c(idx): + + # Self and CrossAttention + if idx == 0: return g_c(custom_checkpoint(attn, mode='attn'), + hidden_states, encoder_hidden_states,cross_attention_kwargs, use_reentrant=False + )['sample'] + + # Temporal Self and CrossAttention + if idx == 1: return g_c(custom_checkpoint(temp_attn, mode='temp'), + hidden_states, num_frames, use_reentrant=False)['sample'] + + # Resnets + if idx == 2: return g_c(custom_checkpoint(resnet, mode='resnet'), + hidden_states, temb, use_reentrant=False) + + # Temporal Convolutions + if idx == 3: return g_c(custom_checkpoint(temp_conv, mode='temp'), + hidden_states, num_frames, use_reentrant=False + ) + + # Here we call the function depending on the order in which they are called. + # For some layers, the orders are different, so we access the appropriate one by index. 
+ + if not inverse_temp: + for idx in [0,1,2,3]: hidden_states = ordered_g_c(idx) + else: + for idx in [2,3,0,1]: hidden_states = ordered_g_c(idx) + + return hidden_states + +def up_down_g_c(resnet, temp_conv, hidden_states, temb, num_frames): + hidden_states = g_c(custom_checkpoint(resnet, mode='resnet'), hidden_states, temb, use_reentrant=False) + hidden_states = g_c(custom_checkpoint(temp_conv, mode='temp'), + hidden_states, num_frames, use_reentrant=False + ) + return hidden_states + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=True, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", +): + if down_block_type == "DownBlock3D": + return DownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "CrossAttnDownBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") + return CrossAttnDownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=True, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", +): + if up_block_type == "UpBlock3D": + return UpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "CrossAttnUpBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") + return CrossAttnUpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + 
use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock3DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=True, + upcast_attention=False, + ): + super().__init__() + + self.gradient_checkpointing = False + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + temp_convs = [ + TemporalConvLayer( + in_channels, + in_channels, + dropout=0.1 + ) + ] + attentions = [] + temp_attentions = [] + + for _ in range(num_layers): + attentions.append( + Transformer2DModel( + in_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + in_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + in_channels, + in_channels, + dropout=0.1 + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + num_frames=1, + cross_attention_kwargs=None, + ): + if self.gradient_checkpointing: + hidden_states = up_down_g_c( + self.resnets[0], + self.temp_convs[0], + hidden_states, + temb, + num_frames + ) + else: + hidden_states = self.resnets[0](hidden_states, temb) + hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames) + + for attn, temp_attn, resnet, temp_conv in zip( + self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:] + ): + if self.gradient_checkpointing: + hidden_states = cross_attn_g_c( + attn, + temp_attn, + resnet, + temp_conv, + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + temb, + num_frames + ) + else: + hidden_states = attn( + hidden_states, + 
encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + if num_frames > 1: + hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample + + hidden_states = resnet(hidden_states, temb) + + if num_frames > 1: + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + return hidden_states + + +class CrossAttnDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + temp_attentions = [] + temp_convs = [] + + self.gradient_checkpointing = False + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1 + ) + ) + attentions.append( + Transformer2DModel( + out_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + out_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + num_frames=1, + cross_attention_kwargs=None, + ): + # TODO(Patrick, William) - attention mask is not used + output_states = () + + for resnet, temp_conv, attn, temp_attn in zip( + self.resnets, self.temp_convs, self.attentions, self.temp_attentions + ): + + if self.gradient_checkpointing: + hidden_states = cross_attn_g_c( + attn, + temp_attn, + resnet, + temp_conv, + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + temb, + num_frames, + inverse_temp=True + ) + else: + hidden_states = resnet(hidden_states, temb) + + if num_frames > 1: + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + hidden_states = attn( + hidden_states, + 
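# spatial Transformer2DModel: self-attention plus text cross-attention within each frame; the TransformerTemporalModel below then mixes information across the num_frames axis +                    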
encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + if num_frames > 1: + hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + temp_convs = [] + + self.gradient_checkpointing = False + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1 + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states, temb=None, num_frames=1): + output_states = () + + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + if self.gradient_checkpointing: + hidden_states = up_down_g_c(resnet, temp_conv, hidden_states, temb, num_frames) + else: + hidden_states = resnet(hidden_states, temb) + + if num_frames > 1: + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + ): + super().__init__() + resnets = [] + temp_convs = [] + attentions = [] + temp_attentions = [] + + self.gradient_checkpointing = False + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + 
temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1 + ) + ) + attentions.append( + Transformer2DModel( + out_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + out_channels // attn_num_head_channels, + attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + attention_mask=None, + num_frames=1, + cross_attention_kwargs=None, + ): + # TODO(Patrick, William) - attention mask is not used + for resnet, temp_conv, attn, temp_attn in zip( + self.resnets, self.temp_convs, self.attentions, self.temp_attentions + ): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.gradient_checkpointing: + hidden_states = cross_attn_g_c( + attn, + temp_attn, + resnet, + temp_conv, + hidden_states, + encoder_hidden_states, + cross_attention_kwargs, + temb, + num_frames, + inverse_temp=True + ) + else: + hidden_states = resnet(hidden_states, temb) + + if num_frames > 1: + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + if num_frames > 1: + hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + temp_convs = [] + self.gradient_checkpointing = False + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + 
dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1 + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1): + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.gradient_checkpointing: + hidden_states = up_down_g_c(resnet, temp_conv, hidden_states, temb, num_frames) + else: + hidden_states = resnet(hidden_states, temb) + + if num_frames > 1: + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states diff --git a/models/unet/unet_3d_condition.py b/models/unet/unet_3d_condition.py new file mode 100644 index 0000000000000000000000000000000000000000..d1bf3c6515df3bfbf21e8b6e272b1451afe0cc25 --- /dev/null +++ b/models/unet/unet_3d_condition.py @@ -0,0 +1,500 @@ +# Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved. +# Copyright 2023 The ModelScope Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput, logging +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.transformer_temporal import TransformerTemporalModel +from .unet_3d_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, + transformer_g_c +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNet3DConditionOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`): + Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.FloatTensor + + +class UNet3DConditionModel(ModelMixin, ConfigMixin): + r""" + UNet3DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep + and returns sample shaped output. 
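A minimal sketch of the forward contract this class exposes, illustrative only and not part of the diff; the spatial size, frame count, and prompt length below are arbitrary assumptions chosen to keep the example small:

```python
import torch
from models.unet.unet_3d_condition import UNet3DConditionModel

# Defaults from the constructor below: 4 latent channels, cross_attention_dim=1024.
unet = UNet3DConditionModel(sample_size=32)

sample = torch.randn(1, 4, 8, 32, 32)             # (batch, channels, num_frames, height, width)
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 77, 1024)  # (batch, seq_len, cross_attention_dim)

with torch.no_grad():
    out = unet(sample, timestep, encoder_hidden_states).sample

print(out.shape)  # torch.Size([1, 4, 8, 32, 32]) -- same shape as the input latents
```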
+ + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the models (such as downloading or saving, etc.) + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): + The tuple of upsample blocks to use. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, it will skip the normalization and activation layers in post-processing + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"), + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1024, + attention_head_dim: Union[int, Tuple[int]] = 64, + ): + super().__init__() + + self.sample_size = sample_size + self.gradient_checkpointing = False + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." 
+ ) + + # input + conv_in_kernel = 3 + conv_out_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps(block_out_channels[0], True, 0) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + ) + + self.transformer_in = TransformerTemporalModel( + num_attention_heads=8, + attention_head_dim=attention_head_dim, + in_channels=block_out_channels[0], + num_layers=1, + ) + + # class embedding + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=False, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=False, + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=False, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + self.conv_act = nn.SiLU() + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = 
nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + + num_slicable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_slicable_layers * [1] + + slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. 
+ # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, value=False): + self.gradient_checkpointing = value + self.mid_block.gradient_checkpointing = value + for module in self.down_blocks + self.up_blocks: + if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet3DConditionOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, num_frames, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet3DConditionOutput`] instead of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Returns: + [`~models.unet_2d_condition.UNet3DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + num_frames = sample.shape[2] + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + emb = emb.repeat_interleave(repeats=num_frames, dim=0) + encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0) + + # 2. pre-process + sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:]) + sample = self.conv_in(sample) + + if num_frames > 1: + if self.gradient_checkpointing: + sample = transformer_g_c(self.transformer_in, sample, num_frames) + else: + sample = self.transformer_in(sample, num_frames=num_frames).sample + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames) + + down_block_res_samples += res_samples + + if down_block_additional_residuals is not None: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples += (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + + if mid_block_additional_residual is not None: + sample = sample + mid_block_additional_residual + + # 5. 
up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + num_frames=num_frames, + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + + sample = self.conv_out(sample) + + # reshape to (batch, channel, framerate, width, height) + sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4) + + if not return_dict: + return (sample,) + + return UNet3DConditionOutput(sample=sample) diff --git a/noise_init/__init__.py b/noise_init/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c55b1cb491d217bf6756ea4b44cd08c71004d9da --- /dev/null +++ b/noise_init/__init__.py @@ -0,0 +1,4 @@ +from .freq_init import FreqInit +from .blend_init import BlendInit +from .blend_freq_init import BlendFreqInit +from .fft_init import FFTInit \ No newline at end of file diff --git a/noise_init/__pycache__/__init__.cpython-310.pyc b/noise_init/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c35aee7006bbccef5b38a7f9c1b9798629b53177 Binary files /dev/null and b/noise_init/__pycache__/__init__.cpython-310.pyc differ diff --git a/noise_init/__pycache__/blend_freq_init.cpython-310.pyc b/noise_init/__pycache__/blend_freq_init.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e42fb2597636abbed74542d1433a93422c0aa08e Binary files /dev/null and b/noise_init/__pycache__/blend_freq_init.cpython-310.pyc differ diff --git a/noise_init/__pycache__/blend_init.cpython-310.pyc b/noise_init/__pycache__/blend_init.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4aacd4bebe5485cf7c909e9f29b85d7ebb1fc4b3 Binary files /dev/null and b/noise_init/__pycache__/blend_init.cpython-310.pyc differ diff --git a/noise_init/__pycache__/fft_init.cpython-310.pyc b/noise_init/__pycache__/fft_init.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60a32485d5cb1431967b8c97cce249693d490832 Binary files /dev/null and b/noise_init/__pycache__/fft_init.cpython-310.pyc differ diff --git a/noise_init/__pycache__/freq_init.cpython-310.pyc b/noise_init/__pycache__/freq_init.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..072f349980e68a64f97509a00a88e1bd25e09076 Binary files /dev/null and b/noise_init/__pycache__/freq_init.cpython-310.pyc differ diff --git a/noise_init/blend_freq_init.py b/noise_init/blend_freq_init.py new file mode 100644 index 
0000000000000000000000000000000000000000..cc8d420f7e7779fe4148a61a16e871f573848725 --- /dev/null +++ b/noise_init/blend_freq_init.py @@ -0,0 +1,45 @@ +import math + +import torch +import torch.fft as fft +import torch.nn.functional as F + +from einops import rearrange + + +def BlendFreqInit(noisy_latent, noise, noise_prior=0.5, downsample_factor=4): + f = noisy_latent.shape[2] + new_h, new_w = ( + noisy_latent.shape[-2] // downsample_factor, + noisy_latent.shape[-1] // downsample_factor, + ) + + noise = rearrange(noise, "b c f h w -> (b f) c h w") + noise_down = F.interpolate(noise, size=(new_h, new_w), mode="bilinear", align_corners=True, antialias=True) + noise_up = F.interpolate( + noise_down, size=(noise.shape[-2], noise.shape[-1]), mode="bilinear", align_corners=True, antialias=True + ) + noise_high_freqs = noise - noise_up + + + noisy_latent = rearrange(noisy_latent, "b c f h w -> (b f) c h w") + noisy_latent_down = F.interpolate( + noisy_latent, size=(new_h, new_w), mode="bilinear", align_corners=True, antialias=True + ) + latents_low_freqs = F.interpolate( + noisy_latent_down, + size=(noisy_latent.shape[-2], noisy_latent.shape[-1]), + mode="bilinear", + align_corners=True, + antialias=True, + ) + + latent_high_freqs = noisy_latent - latents_low_freqs + + noisy_latent = latents_low_freqs + (noise_prior) ** 0.5 * latent_high_freqs + ( + 1-noise_prior) ** 0.5 * noise_high_freqs + + + noisy_latent = rearrange(noisy_latent, "(b f) c h w -> b c f h w", f=f) + + return noisy_latent \ No newline at end of file diff --git a/noise_init/blend_init.py b/noise_init/blend_init.py new file mode 100644 index 0000000000000000000000000000000000000000..12be37662c7fd7613cfb9776ae7c19bc84c7ca12 --- /dev/null +++ b/noise_init/blend_init.py @@ -0,0 +1,10 @@ +""" +https://arxiv.org/abs/2310.08465 +""" + +def BlendInit(noisy_latent, noise, noise_prior=0.5): + + latents = (noise_prior) ** 0.5 * noisy_latent + ( + 1-noise_prior) ** 0.5 * noise + + return latents \ No newline at end of file diff --git a/noise_init/fft_init.py b/noise_init/fft_init.py new file mode 100644 index 0000000000000000000000000000000000000000..19f3162cd1511b530fe7afcd041809fb2a0995b9 --- /dev/null +++ b/noise_init/fft_init.py @@ -0,0 +1,179 @@ +""" +https://arxiv.org/abs/2312.07537 +""" + +import math + +import torch +import torch.fft as fft +import torch.nn.functional as F + +def freq_mix_3d(x, noise, LPF): + """ + Noise reinitialization. + + Args: + x: diffused latent + noise: randomly sampled noise + LPF: low pass filter + """ + # FFT + x_freq = fft.fftn(x, dim=(-3, -2, -1)) + x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1)) + noise_freq = fft.fftn(noise, dim=(-3, -2, -1)) + noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1)) + + # frequency mix + HPF = 1 - LPF + x_freq_low = x_freq * LPF + noise_freq_high = noise_freq * HPF + x_freq_mixed = x_freq_low + noise_freq_high # mix in freq domain + + # IFFT + x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1)) + x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real + + return x_mixed + +def get_freq_filter(shape, device, filter_type, n, d_s, d_t): + """ + Form the frequency filter for noise reinitialization. 
+ + Args: + shape: shape of latent (B, C, T, H, W) + filter_type: type of the freq filter + n: (only for butterworth) order of the filter, larger n ~ ideal, smaller n ~ gaussian + d_s: normalized stop frequency for spatial dimensions (0.0-1.0) + d_t: normalized stop frequency for temporal dimension (0.0-1.0) + """ + if filter_type == "gaussian": + return gaussian_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device) + elif filter_type == "ideal": + return ideal_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device) + elif filter_type == "box": + return box_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device) + elif filter_type == "butterworth": + return butterworth_low_pass_filter(shape=shape, n=n, d_s=d_s, d_t=d_t).to(device) + else: + raise NotImplementedError + +def gaussian_low_pass_filter(shape, d_s=0.25, d_t=0.25): + """ + Compute the gaussian low pass filter mask. + + Args: + shape: shape of the filter (volume) + d_s: normalized stop frequency for spatial dimensions (0.0-1.0) + d_t: normalized stop frequency for temporal dimension (0.0-1.0) + """ + T, H, W = shape[-3], shape[-2], shape[-1] + mask = torch.zeros(shape) + if d_s==0 or d_t==0: + return mask + for t in range(T): + for h in range(H): + for w in range(W): + d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2) + mask[..., t,h,w] = math.exp(-1/(2*d_s**2) * d_square) + return mask + +def butterworth_low_pass_filter(shape, n=4, d_s=0.25, d_t=0.25): + """ + Compute the butterworth low pass filter mask. + + Args: + shape: shape of the filter (volume) + n: order of the filter, larger n ~ ideal, smaller n ~ gaussian + d_s: normalized stop frequency for spatial dimensions (0.0-1.0) + d_t: normalized stop frequency for temporal dimension (0.0-1.0) + """ + T, H, W = shape[-3], shape[-2], shape[-1] + mask = torch.zeros(shape) + if d_s==0 or d_t==0: + return mask + for t in range(T): + for h in range(H): + for w in range(W): + d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2) + mask[..., t,h,w] = 1 / (1 + (d_square / d_s**2)**n) + return mask + +def ideal_low_pass_filter(shape, d_s=0.25, d_t=0.25): + """ + Compute the ideal low pass filter mask. + + Args: + shape: shape of the filter (volume) + d_s: normalized stop frequency for spatial dimensions (0.0-1.0) + d_t: normalized stop frequency for temporal dimension (0.0-1.0) + """ + T, H, W = shape[-3], shape[-2], shape[-1] + mask = torch.zeros(shape) + if d_s==0 or d_t==0: + return mask + for t in range(T): + for h in range(H): + for w in range(W): + d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2) + mask[..., t,h,w] = 1 if d_square <= d_s*2 else 0 + return mask + +def box_low_pass_filter(shape, d_s=0.25, d_t=0.25): + """ + Compute the ideal low pass filter mask (approximated version). 
+ + Args: + shape: shape of the filter (volume) + d_s: normalized stop frequency for spatial dimensions (0.0-1.0) + d_t: normalized stop frequency for temporal dimension (0.0-1.0) + """ + T, H, W = shape[-3], shape[-2], shape[-1] + mask = torch.zeros(shape) + if d_s==0 or d_t==0: + return mask + + threshold_s = round(int(H // 2) * d_s) + threshold_t = round(T // 2 * d_t) + + cframe, crow, ccol = T // 2, H // 2, W //2 + mask[..., cframe - threshold_t:cframe + threshold_t, crow - threshold_s:crow + threshold_s, ccol - threshold_s:ccol + threshold_s] = 1.0 + + return mask + +@torch.no_grad() +def init_filter(video_length, height, width, filter_params_method="gaussian", filter_params_n=4, filter_params_d_s=0.25, filter_params_d_t=0.25, num_channels_latents=4, device='cpu'): + # initialize frequency filter for noise reinitialization + batch_size = 1 + num_channels_latents = num_channels_latents + filter_shape = [ + batch_size, + num_channels_latents, + video_length, + height, + width, + ] + freq_filter = get_freq_filter( + filter_shape, + device=device, + filter_type=filter_params_method, + n=filter_params_n if filter_params_method=="butterworth" else None, + d_s=filter_params_d_s, + d_t=filter_params_d_t + ) + return freq_filter + +def FFTInit(noisy_latent, noise): + + dtype = noisy_latent.dtype + freq_filter = init_filter( + video_length=noisy_latent.shape[2], + height=noisy_latent.shape[3], + width=noisy_latent.shape[4], + device=noisy_latent.device + ) + + # make it float32 to accept any kinds of resolution + latents = freq_mix_3d(noisy_latent.to(dtype=torch.float32), noise.to(dtype=torch.float32), LPF=freq_filter) + latents = latents.to(dtype) + + return latents \ No newline at end of file diff --git a/noise_init/freq_init.py b/noise_init/freq_init.py new file mode 100644 index 0000000000000000000000000000000000000000..08927ad8f77ecf69e46c1a61c0ba48c67a26ca28 --- /dev/null +++ b/noise_init/freq_init.py @@ -0,0 +1,40 @@ +""" +https://arxiv.org/abs/2311.17009 +""" + +import math + +import torch +import torch.fft as fft +import torch.nn.functional as F + +from einops import rearrange + + +def FreqInit(noisy_latent, noise, downsample_factor=4, num_frames=24): + + new_h, new_w = ( + noisy_latent.shape[-2] // downsample_factor, + noisy_latent.shape[-1] // downsample_factor, + ) + noise = rearrange(noise, "b c f h w -> (b f) c h w") + noise_down = F.interpolate(noise, size=(new_h, new_w), mode="bilinear", align_corners=True, antialias=True) + noise_up = F.interpolate( + noise_down, size=(noise.shape[-2], noise.shape[-1]), mode="bilinear", align_corners=True, antialias=True + ) + high_freqs = noise - noise_up + noisy_latent = rearrange(noisy_latent, "b c f h w -> (b f) c h w") + noisy_latent_down = F.interpolate( + noisy_latent, size=(new_h, new_w), mode="bilinear", align_corners=True, antialias=True + ) + low_freqs = F.interpolate( + noisy_latent_down, + size=(noisy_latent.shape[-2], noisy_latent.shape[-1]), + mode="bilinear", + align_corners=True, + antialias=True, + ) + noisy_latent = low_freqs + high_freqs + noisy_latent = rearrange(noisy_latent, "(b f) c h w -> b c f h w", f=num_frames) + + return noisy_latent \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000000000000000000000000000000000000..2dc876aed19e6a16b676d89aa0282b194c7e3cc0 --- /dev/null +++ b/train.py @@ -0,0 +1,467 @@ +import argparse +import logging +import math +import os +import gc +import copy + +from omegaconf import OmegaConf + +import torch +import torch.utils.checkpoint +import 
diffusers +import transformers +from tqdm.auto import tqdm + +from accelerate import Accelerator +from accelerate.logging import get_logger + +from models.unet.unet_3d_condition import UNet3DConditionModel +from diffusers.models import AutoencoderKL +from diffusers import DDIMScheduler, TextToVideoSDPipeline + + +from transformers import CLIPTextModel, CLIPTokenizer +from utils.ddim_utils import inverse_video +from utils.gpu_utils import handle_memory_attention, unet_and_text_g_c +from utils.func_utils import * + +import imageio +import numpy as np + +from dataset import * +from loss import * +from noise_init import * + +from attn_ctrl import register_attention_control + +import shutil +logger = get_logger(__name__, log_level="INFO") + +def log_validation(accelerator, config, batch, global_step, text_prompt, unet, text_encoder, vae, output_dir): + with accelerator.autocast(): + unet.eval() + text_encoder.eval() + unet_and_text_g_c(unet, text_encoder, False, False) + + # handle spatial lora + if config.loss.type =='DebiasedHybrid': + loras = extract_lora_child_module(unet, target_replace_module=["Transformer2DModel"]) + for lora_i in loras: + lora_i.scale = 0 + + pipeline = TextToVideoSDPipeline.from_pretrained( + config.model.pretrained_model_path, + text_encoder=text_encoder, + vae=vae, + unet=unet + ) + + prompt_list = text_prompt if len(config.val.prompt) <= 0 else config.val.prompt + for seed in config.val.seeds: + noisy_latent = batch['inversion_noise'] + shape = noisy_latent.shape + noise = torch.randn( + shape, + device=noisy_latent.device, + generator=torch.Generator(noisy_latent.device).manual_seed(seed) + ).to(noisy_latent.dtype) + + # handle different noise initialization strategy + init_func_name = f'{config.noise_init.type}' + # Assuming config.dataset is a DictConfig object + init_params_dict = OmegaConf.to_container(config.noise_init, resolve=True) + # Remove the 'type' key + init_params_dict.pop('type', None) # 'None' ensures no error if 'type' key doesn't exist + + init_func_to_call = globals().get(init_func_name) + init_noise = init_func_to_call(noisy_latent, noise, **init_params_dict) + + for prompt in prompt_list: + file_name = f"{prompt.replace(' ', '_')}_seed_{seed}.mp4" + file_path = f"{output_dir}/samples_{global_step}/" + if not os.path.exists(file_path): + os.makedirs(file_path) + + with torch.no_grad(): + video_frames = pipeline( + prompt=prompt, + negative_prompt=config.val.negative_prompt, + width=config.val.width, + height=config.val.height, + num_frames=config.val.num_frames, + num_inference_steps=config.val.num_inference_steps, + guidance_scale=config.val.guidance_scale, + latents=init_noise, + ).frames[0] + export_to_video(video_frames, os.path.join(file_path, file_name), config.dataset.fps) + logger.info(f"Saved a new sample to {os.path.join(file_path, file_name)}") + del pipeline + torch.cuda.empty_cache() + +def create_logging(logging, logger, accelerator): + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + +def accelerate_set_verbose(accelerator): + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + +def export_to_video(video_frames, output_video_path, fps): + video_writer = 
imageio.get_writer(output_video_path, fps=fps) + for img in video_frames: + video_writer.append_data(np.array(img)) + video_writer.close() + return output_video_path + +def create_output_folders(output_dir, config): + out_dir = os.path.join(output_dir) + os.makedirs(out_dir, exist_ok=True) + OmegaConf.save(config, os.path.join(out_dir, 'config.yaml')) + shutil.copyfile(config.dataset.single_video_path, os.path.join(out_dir,'source.mp4')) + return out_dir + +def load_primary_models(pretrained_model_path): + noise_scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler") + tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer") + text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder") + vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae") + unet = UNet3DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet") + + return noise_scheduler, tokenizer, text_encoder, vae, unet + +def freeze_models(models_to_freeze): + for model in models_to_freeze: + if model is not None: model.requires_grad_(False) + +def is_mixed_precision(accelerator): + weight_dtype = torch.float32 + + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + return weight_dtype + +def cast_to_gpu_and_type(model_list, accelerator, weight_dtype): + for model in model_list: + if model is not None: model.to(accelerator.device, dtype=weight_dtype) + +def handle_cache_latents( + should_cache, + output_dir, + train_dataloader, + train_batch_size, + vae, + unet, + pretrained_model_path, + cached_latent_dir=None, +): + # Cache latents by storing them in VRAM. + # Speeds up training and saves memory by not encoding during the train loop. + if not should_cache: return None + vae.to('cuda', dtype=torch.float16) + vae.enable_slicing() + + pipe = TextToVideoSDPipeline.from_pretrained( + pretrained_model_path, + vae=vae, + unet=copy.deepcopy(unet).to('cuda', dtype=torch.float16) + ) + pipe.text_encoder.to('cuda', dtype=torch.float16) + + cached_latent_dir = ( + os.path.abspath(cached_latent_dir) if cached_latent_dir is not None else None + ) + + if cached_latent_dir is None: + cache_save_dir = f"{output_dir}/cached_latents" + os.makedirs(cache_save_dir, exist_ok=True) + + for i, batch in enumerate(tqdm(train_dataloader, desc="Caching Latents.")): + save_name = f"cached_{i}" + full_out_path = f"{cache_save_dir}/{save_name}.pt" + + pixel_values = batch['pixel_values'].to('cuda', dtype=torch.float16) + batch['latents'] = tensor_to_vae_latent(pixel_values, vae) + + batch['inversion_noise'] = inverse_video(pipe, batch['latents'], 50) + for k, v in batch.items(): batch[k] = v[0] + + torch.save(batch, full_out_path) + del pixel_values + del batch + + # We do this to avoid fragmentation from casting latents between devices. 
+ torch.cuda.empty_cache() + else: + cache_save_dir = cached_latent_dir + + return torch.utils.data.DataLoader( + CachedDataset(cache_dir=cache_save_dir), + batch_size=train_batch_size, + shuffle=True, + num_workers=0 + ) + +def should_sample(global_step, validation_steps, validation_data): + return (global_step == 1 or global_step % validation_steps == 0) and validation_data.sample_preview + +def save_pipe( + path, + global_step, + accelerator, + unet, + text_encoder, + vae, + output_dir, + is_checkpoint=False, + save_pretrained_model=False, + **extra_params +): + if is_checkpoint: + save_path = os.path.join(output_dir, f"checkpoint-{global_step}") + os.makedirs(save_path, exist_ok=True) + else: + save_path = output_dir + + # Save the dtypes so we can continue training at the same precision. + u_dtype, t_dtype, v_dtype = unet.dtype, text_encoder.dtype, vae.dtype + + # Copy the model without creating a reference to it. This allows keeping the state of our lora training if enabled. + unet_out = copy.deepcopy(accelerator.unwrap_model(unet.cpu(), keep_fp32_wrapper=False)) + text_encoder_out = copy.deepcopy(accelerator.unwrap_model(text_encoder.cpu(), keep_fp32_wrapper=False)) + + pipeline = TextToVideoSDPipeline.from_pretrained( + path, + unet=unet_out, + text_encoder=text_encoder_out, + vae=vae, + ).to(torch_dtype=torch.float32) + + lora_managers_spatial = extra_params.get('lora_managers_spatial', [None]) + lora_manager_spatial = lora_managers_spatial[-1] + if lora_manager_spatial is not None: + lora_manager_spatial.save_lora_weights(model=copy.deepcopy(pipeline), save_path=save_path+'/spatial', step=global_step) + + save_motion_embeddings(unet_out, os.path.join(save_path, 'motion_embed.pt')) + + if save_pretrained_model: + pipeline.save_pretrained(save_path) + + if is_checkpoint: + unet, text_encoder = accelerator.prepare(unet, text_encoder) + models_to_cast_back = [(unet, u_dtype), (text_encoder, t_dtype), (vae, v_dtype)] + [x[0].to(accelerator.device, dtype=x[1]) for x in models_to_cast_back] + + logger.info(f"Saved model at {save_path} on step {global_step}") + + del pipeline + del unet_out + del text_encoder_out + torch.cuda.empty_cache() + gc.collect() + +def main(config): + # Initialize the Accelerator + accelerator = Accelerator( + gradient_accumulation_steps=config.train.gradient_accumulation_steps, + mixed_precision=config.train.mixed_precision, + log_with=config.train.logger_type, + project_dir=config.train.output_dir + ) + + video_path = config.dataset.single_video_path + cap = cv2.VideoCapture(video_path) + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + fps = 8 + frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + + config.dataset.width = width + config.dataset.height = height + config.dataset.fps = fps + config.dataset.n_sample_frames = frame_count + + config.dataset.single_video_path = video_path + + config.val.width = width + config.val.height = height + config.val.num_frames = frame_count + + # Create output directories and set up logging + if accelerator.is_main_process: + output_dir = create_output_folders(config.train.output_dir, config) + create_logging(logging, logger, accelerator) + accelerate_set_verbose(accelerator) + + # Load primary models + noise_scheduler, tokenizer, text_encoder, vae, unet = load_primary_models(config.model.pretrained_model_path) + # Load videoCrafter2 unet for better video quality, if needed + if config.model.unet == 'videoCrafter2': + unet = 
UNet3DConditionModel.from_pretrained("/hpc2hdd/home/lwang592/ziyang/cache/videocrafterv2",subfolder='unet') + elif config.model.unet == 'zeroscope_v2_576w': + # by default, we use zeroscope_v2_576w, thus this unet is already loaded + pass + else: + raise ValueError("Invalid UNet model") + + freeze_models([vae, text_encoder]) + handle_memory_attention(unet) + + train_dataloader, train_dataset = prepare_data(config, tokenizer) + + # Handle latents caching + cached_data_loader = handle_cache_latents( + config.train.cache_latents, + output_dir, + train_dataloader, + config.train.train_batch_size, + vae, + unet, + config.model.pretrained_model_path, + config.train.cached_latent_dir, + ) + if cached_data_loader is not None: + train_dataloader = cached_data_loader + + # Prepare parameters and optimization + params, extra_params = prepare_params(unet, config, train_dataset) + optimizers, lr_schedulers = prepare_optimizers(params, config, **extra_params) + + + # Prepare models and data for training + unet, optimizers, train_dataloader, lr_schedulers, text_encoder = accelerator.prepare( + unet, optimizers, train_dataloader, lr_schedulers, text_encoder + ) + + # Additional model setups + unet_and_text_g_c(unet, text_encoder) + vae.enable_slicing() + + # Setup for mixed precision training + weight_dtype = is_mixed_precision(accelerator) + cast_to_gpu_and_type([text_encoder, vae], accelerator, weight_dtype) + + # Recalculate training steps and epochs + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / config.train.gradient_accumulation_steps) + num_train_epochs = math.ceil(config.train.max_train_steps / num_update_steps_per_epoch) + + # Initialize trackers and store configuration + if accelerator.is_main_process: + accelerator.init_trackers("motion-inversion") + + # Train! + total_batch_size = config.train.train_batch_size * accelerator.num_processes * config.train.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {config.train.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {config.train.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {config.train.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Only show the progress bar once on each machine. 
+ progress_bar = tqdm(range(global_step, config.train.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") + + # Register the attention control, for Motion Value Embedding(s) + register_attention_control(unet, config=config) + for epoch in range(first_epoch, num_train_epochs): + train_loss_temporal = 0.0 + + for step, batch in enumerate(train_dataloader): + # Skip steps until we reach the resumed step + if config.train.resume_from_checkpoint and epoch == first_epoch and step < config.train.resume_step: + if step % config.train.gradient_accumulation_steps == 0: + progress_bar.update(1) + continue + with accelerator.accumulate(unet), accelerator.accumulate(text_encoder): + + for optimizer in optimizers: + optimizer.zero_grad(set_to_none=True) + + with accelerator.autocast(): + if global_step == 0: + unet.train() + + loss_func_to_call = globals().get(f'{config.loss.type}') + + loss_temporal, train_loss_temporal = loss_func_to_call( + train_loss_temporal, + accelerator, + optimizers, + lr_schedulers, + unet, + vae, + text_encoder, + noise_scheduler, + batch, + step, + config + ) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss_temporal}, step=global_step) + train_loss_temporal = 0.0 + if global_step % config.train.checkpointing_steps == 0 and global_step > 0: + save_pipe( + config.model.pretrained_model_path, + global_step, + accelerator, + unet, + text_encoder, + vae, + output_dir, + is_checkpoint=True, + **extra_params + ) + + if loss_temporal is not None: + accelerator.log({"loss_temporal": loss_temporal.detach().item()}, step=step) + + if global_step >= config.train.max_train_steps: + break + + # Create the pipeline using the trained modules and save it. 
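Both the loss function and the noise-initialization strategy are resolved by name at runtime (`globals().get(f'{config.loss.type}')` above, and the same pattern for `config.noise_init.type` in `log_validation`), so the config must name one of the callables exported by `loss` or `noise_init`. A small sketch of that dispatch using the initializers added in this diff; the tensor shape is a placeholder, and `noise_prior` is the only extra argument `BlendInit` accepts:

```python
import torch
from omegaconf import OmegaConf
import noise_init  # exports FreqInit, BlendInit, BlendFreqInit, FFTInit

cfg = OmegaConf.create({"noise_init": {"type": "BlendInit", "noise_prior": 0.5}})

noisy_latent = torch.randn(1, 4, 16, 40, 72)   # (batch, channels, frames, height, width)
noise = torch.randn_like(noisy_latent)

# Mirror of the lookup in log_validation(): drop 'type', pass the rest as kwargs.
params = OmegaConf.to_container(cfg.noise_init, resolve=True)
params.pop("type", None)
init_fn = getattr(noise_init, cfg.noise_init.type)  # equivalent to globals().get(...) after the star import

init_noise = init_fn(noisy_latent, noise, **params)
print(init_noise.shape)  # torch.Size([1, 4, 16, 40, 72])
```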
+ accelerator.wait_for_everyone() + accelerator.end_training() + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, default='configs/config.yaml') + parser.add_argument("--single_video_path", type=str) + parser.add_argument("--prompts", type=str, help="JSON string of prompts") + args = parser.parse_args() + + # Load and merge configurations + config = OmegaConf.load(args.config) + + # Update the config with the command-line arguments + if args.single_video_path: + config.dataset.single_video_path = args.single_video_path + # Set the output dir + config.train.output_dir = os.path.join(config.train.output_dir, os.path.basename(args.single_video_path).split('.')[0]) + + if args.prompts: + config.val.prompt = json.loads(args.prompts) + + + + main(config) diff --git a/utils/__pycache__/convert_diffusers_to_original_ms_text_to_video.cpython-310.pyc b/utils/__pycache__/convert_diffusers_to_original_ms_text_to_video.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..028728ecfb6d1cf43a80f9f0b83dd0807e74e31d Binary files /dev/null and b/utils/__pycache__/convert_diffusers_to_original_ms_text_to_video.cpython-310.pyc differ diff --git a/utils/__pycache__/dataset_utils.cpython-310.pyc b/utils/__pycache__/dataset_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bac8163cbdb8e485aabe4e34a67777af174bf58a Binary files /dev/null and b/utils/__pycache__/dataset_utils.cpython-310.pyc differ diff --git a/utils/__pycache__/ddim_utils.cpython-310.pyc b/utils/__pycache__/ddim_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..caffd14f2e2834585458a6dc8047579cfae9582a Binary files /dev/null and b/utils/__pycache__/ddim_utils.cpython-310.pyc differ diff --git a/utils/__pycache__/func_utils.cpython-310.pyc b/utils/__pycache__/func_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e108c44fc77cfcbf60def9b43d27dd812920bcc3 Binary files /dev/null and b/utils/__pycache__/func_utils.cpython-310.pyc differ diff --git a/utils/__pycache__/gpu_utils.cpython-310.pyc b/utils/__pycache__/gpu_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2751c6719884a3f06053337ad16d51df486e4e3 Binary files /dev/null and b/utils/__pycache__/gpu_utils.cpython-310.pyc differ diff --git a/utils/__pycache__/lora.cpython-310.pyc b/utils/__pycache__/lora.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c17259183c16e8b0a211e52a1d27a98b12443bbf Binary files /dev/null and b/utils/__pycache__/lora.cpython-310.pyc differ diff --git a/utils/__pycache__/lora_handler.cpython-310.pyc b/utils/__pycache__/lora_handler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..841afa8bade334da2ce0f86b1a67630b38cd8a7e Binary files /dev/null and b/utils/__pycache__/lora_handler.cpython-310.pyc differ diff --git a/utils/convert_diffusers_to_original_ms_text_to_video.py b/utils/convert_diffusers_to_original_ms_text_to_video.py new file mode 100644 index 0000000000000000000000000000000000000000..83758fe0c0d06d5824dcd26215f517a998fb94c8 --- /dev/null +++ b/utils/convert_diffusers_to_original_ms_text_to_video.py @@ -0,0 +1,465 @@ +# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint. +# *Only* converts the UNet, and Text Encoder. +# Does not convert optimizer state or any other thing. 
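For reference, the `train.py` entry point above is launched with a config file plus the optional per-video overrides; `configs/config.yaml` is the argparse default, while the video path and prompt below are placeholders:

```
python train.py --config configs/config.yaml \
    --single_video_path path/to/source_video.mp4 \
    --prompts '["a prompt describing the desired appearance"]'
```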
+ +import argparse +import os.path as osp +import re + +import torch +from safetensors.torch import load_file, save_file + +# =================# +# UNet Conversion # +# =================# + +print ('Initializing the conversion map') + +unet_conversion_map = [ + # (ModelScope, HF Diffusers) + + # from Vanilla ModelScope/StableDiffusion + ("time_embed.0.weight", "time_embedding.linear_1.weight"), + ("time_embed.0.bias", "time_embedding.linear_1.bias"), + ("time_embed.2.weight", "time_embedding.linear_2.weight"), + ("time_embed.2.bias", "time_embedding.linear_2.bias"), + + + # from Vanilla ModelScope/StableDiffusion + ("input_blocks.0.0.weight", "conv_in.weight"), + ("input_blocks.0.0.bias", "conv_in.bias"), + + + # from Vanilla ModelScope/StableDiffusion + ("out.0.weight", "conv_norm_out.weight"), + ("out.0.bias", "conv_norm_out.bias"), + ("out.2.weight", "conv_out.weight"), + ("out.2.bias", "conv_out.bias"), +] + +unet_conversion_map_resnet = [ + # (ModelScope, HF Diffusers) + + # SD + ("in_layers.0", "norm1"), + ("in_layers.2", "conv1"), + ("out_layers.0", "norm2"), + ("out_layers.3", "conv2"), + ("emb_layers.1", "time_emb_proj"), + ("skip_connection", "conv_shortcut"), + + # MS + #("temopral_conv", "temp_convs"), # ROFL, they have a typo here --kabachuha +] + +unet_conversion_map_layer = [] + +# Convert input TemporalTransformer +unet_conversion_map_layer.append(('input_blocks.0.1', 'transformer_in')) + +# Reference for the default settings + +# "model_cfg": { +# "unet_in_dim": 4, +# "unet_dim": 320, +# "unet_y_dim": 768, +# "unet_context_dim": 1024, +# "unet_out_dim": 4, +# "unet_dim_mult": [1, 2, 4, 4], +# "unet_num_heads": 8, +# "unet_head_dim": 64, +# "unet_res_blocks": 2, +# "unet_attn_scales": [1, 0.5, 0.25], +# "unet_dropout": 0.1, +# "temporal_attention": "True", +# "num_timesteps": 1000, +# "mean_type": "eps", +# "var_type": "fixed_small", +# "loss_type": "mse" +# } + +# hardcoded number of downblocks and resnets/attentions... +# would need smarter logic for other networks. +for i in range(4): + # loop over downblocks/upblocks + + for j in range(2): + # loop over resnets/attentions for downblocks + + # Spacial SD stuff + hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." + sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." + unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) + + if i < 3: + # no attention layers in down_blocks.3 + hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." + sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." + unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) + + # Temporal MS stuff + hf_down_res_prefix = f"down_blocks.{i}.temp_convs.{j}." + sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0.temopral_conv." + unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) + + if i < 3: + # no attention layers in down_blocks.3 + hf_down_atn_prefix = f"down_blocks.{i}.temp_attentions.{j}." + sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.2." + unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) + + for j in range(3): + # loop over resnets/attentions for upblocks + + # Spacial SD stuff + hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." + sd_up_res_prefix = f"output_blocks.{3*i + j}.0." + unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) + + if i > 0: + # no attention layers in up_blocks.0 + hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." + sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." 
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) + + # loop over resnets/attentions for upblocks + hf_up_res_prefix = f"up_blocks.{i}.temp_convs.{j}." + sd_up_res_prefix = f"output_blocks.{3*i + j}.0.temopral_conv." + unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) + + if i > 0: + # no attention layers in up_blocks.0 + hf_up_atn_prefix = f"up_blocks.{i}.temp_attentions.{j}." + sd_up_atn_prefix = f"output_blocks.{3*i + j}.2." + unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) + + # Up/Downsamplers are 2D, so don't need to touch them + if i < 3: + # no downsample in down_blocks.3 + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." + sd_downsample_prefix = f"input_blocks.{3*(i+1)}.op." + unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) + + # no upsample in up_blocks.3 + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 3}." + unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) + + +# Handle the middle block + +# Spacial +hf_mid_atn_prefix = "mid_block.attentions.0." +sd_mid_atn_prefix = "middle_block.1." +unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) + +for j in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{j}." + sd_mid_res_prefix = f"middle_block.{3*j}." + unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) + +# Temporal +hf_mid_atn_prefix = "mid_block.temp_attentions.0." +sd_mid_atn_prefix = "middle_block.2." +unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) + +for j in range(2): + hf_mid_res_prefix = f"mid_block.temp_convs.{j}." + sd_mid_res_prefix = f"middle_block.{3*j}.temopral_conv." + unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) + +# The pipeline +def convert_unet_state_dict(unet_state_dict, strict_mapping=False): + print ('Converting the UNET') + # buyer beware: this is a *brittle* function, + # and correct output requires that all of these pieces interact in + # the exact order in which I have arranged them. 
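# To make the mapping above concrete (illustrative example only): with i = 0, j = 1 the pair
# ("input_blocks.2.1.", "down_blocks.0.attentions.1.") is appended, so the HF Diffusers key
#
#     down_blocks.0.attentions.1.proj_in.weight
#
# is rewritten by the prefix replacement below to the ModelScope key
#
#     input_blocks.2.1.proj_in.weight
#
# (3*i + j + 1 = 2 selects the matching input block; the ".1." suffix selects its spatial
# attention, while ".0." and ".2." select the resnet/temporal-conv and the temporal attention).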
+ mapping = {k: k for k in unet_state_dict.keys()} + + for sd_name, hf_name in unet_conversion_map: + if strict_mapping: + if hf_name in mapping: + mapping[hf_name] = sd_name + else: + mapping[hf_name] = sd_name + for k, v in mapping.items(): + if "resnets" in k: + for sd_part, hf_part in unet_conversion_map_resnet: + v = v.replace(hf_part, sd_part) + mapping[k] = v + # elif "temp_convs" in k: + # for sd_part, hf_part in unet_conversion_map_resnet: + # v = v.replace(hf_part, sd_part) + # mapping[k] = v + for k, v in mapping.items(): + for sd_part, hf_part in unet_conversion_map_layer: + v = v.replace(hf_part, sd_part) + mapping[k] = v + + + # there must be a pattern, but I don't want to bother atm + do_not_unsqueeze = [f'output_blocks.{i}.1.proj_out.weight' for i in range(3, 12)] + [f'output_blocks.{i}.1.proj_in.weight' for i in range(3, 12)] + ['middle_block.1.proj_in.weight', 'middle_block.1.proj_out.weight'] + [f'input_blocks.{i}.1.proj_out.weight' for i in [1, 2, 4, 5, 7, 8]] + [f'input_blocks.{i}.1.proj_in.weight' for i in [1, 2, 4, 5, 7, 8]] + print (do_not_unsqueeze) + + new_state_dict = {v: (unet_state_dict[k].unsqueeze(-1) if ('proj_' in k and ('bias' not in k) and (k not in do_not_unsqueeze)) else unet_state_dict[k]) for k, v in mapping.items()} + # HACK: idk why the hell it does not work with list comprehension + for k, v in new_state_dict.items(): + has_k = False + for n in do_not_unsqueeze: + if k == n: + has_k = True + + if has_k: + v = v.squeeze(-1) + new_state_dict[k] = v + + return new_state_dict + +# TODO: VAE conversion. We doesn't train it in the most cases, but may be handy for the future --kabachuha + +# =========================# +# Text Encoder Conversion # +# =========================# + +# IT IS THE SAME CLIP ENCODER, SO JUST COPYPASTING IT --kabachuha + +# =========================# +# Text Encoder Conversion # +# =========================# + + +textenc_conversion_lst = [ + # (stable-diffusion, HF Diffusers) + ("resblocks.", "text_model.encoder.layers."), + ("ln_1", "layer_norm1"), + ("ln_2", "layer_norm2"), + (".c_fc.", ".fc1."), + (".c_proj.", ".fc2."), + (".attn", ".self_attn"), + ("ln_final.", "transformer.text_model.final_layer_norm."), + ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"), + ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"), +] +protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst} +textenc_pattern = re.compile("|".join(protected.keys())) + +# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp +code2idx = {"q": 0, "k": 1, "v": 2} + + +def convert_text_enc_state_dict_v20(text_enc_dict): + #print ('Converting the text encoder') + new_state_dict = {} + capture_qkv_weight = {} + capture_qkv_bias = {} + for k, v in text_enc_dict.items(): + if ( + k.endswith(".self_attn.q_proj.weight") + or k.endswith(".self_attn.k_proj.weight") + or k.endswith(".self_attn.v_proj.weight") + ): + k_pre = k[: -len(".q_proj.weight")] + k_code = k[-len("q_proj.weight")] + if k_pre not in capture_qkv_weight: + capture_qkv_weight[k_pre] = [None, None, None] + capture_qkv_weight[k_pre][code2idx[k_code]] = v + continue + + if ( + k.endswith(".self_attn.q_proj.bias") + or k.endswith(".self_attn.k_proj.bias") + or k.endswith(".self_attn.v_proj.bias") + ): + k_pre = k[: -len(".q_proj.bias")] + k_code = k[-len("q_proj.bias")] + if k_pre not in capture_qkv_bias: + capture_qkv_bias[k_pre] = [None, None, None] + capture_qkv_bias[k_pre][code2idx[k_code]] = 
v + continue + + relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k) + new_state_dict[relabelled_key] = v + + for k_pre, tensors in capture_qkv_weight.items(): + if None in tensors: + raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing") + relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre) + new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors) + + for k_pre, tensors in capture_qkv_bias.items(): + if None in tensors: + raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing") + relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre) + new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors) + + return new_state_dict + + +def convert_text_enc_state_dict(text_enc_dict): + return text_enc_dict + +textenc_conversion_lst = [ + # (stable-diffusion, HF Diffusers) + ("resblocks.", "text_model.encoder.layers."), + ("ln_1", "layer_norm1"), + ("ln_2", "layer_norm2"), + (".c_fc.", ".fc1."), + (".c_proj.", ".fc2."), + (".attn", ".self_attn"), + ("ln_final.", "transformer.text_model.final_layer_norm."), + ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"), + ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"), +] +protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst} +textenc_pattern = re.compile("|".join(protected.keys())) + +# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp +code2idx = {"q": 0, "k": 1, "v": 2} + + +def convert_text_enc_state_dict_v20(text_enc_dict): + new_state_dict = {} + capture_qkv_weight = {} + capture_qkv_bias = {} + for k, v in text_enc_dict.items(): + if ( + k.endswith(".self_attn.q_proj.weight") + or k.endswith(".self_attn.k_proj.weight") + or k.endswith(".self_attn.v_proj.weight") + ): + k_pre = k[: -len(".q_proj.weight")] + k_code = k[-len("q_proj.weight")] + if k_pre not in capture_qkv_weight: + capture_qkv_weight[k_pre] = [None, None, None] + capture_qkv_weight[k_pre][code2idx[k_code]] = v + continue + + if ( + k.endswith(".self_attn.q_proj.bias") + or k.endswith(".self_attn.k_proj.bias") + or k.endswith(".self_attn.v_proj.bias") + ): + k_pre = k[: -len(".q_proj.bias")] + k_code = k[-len("q_proj.bias")] + if k_pre not in capture_qkv_bias: + capture_qkv_bias[k_pre] = [None, None, None] + capture_qkv_bias[k_pre][code2idx[k_code]] = v + continue + + relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k) + new_state_dict[relabelled_key] = v + + for k_pre, tensors in capture_qkv_weight.items(): + if None in tensors: + raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing") + relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre) + new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors) + + for k_pre, tensors in capture_qkv_bias.items(): + if None in tensors: + raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing") + relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre) + new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors) + + return new_state_dict + + +def convert_text_enc_state_dict(text_enc_dict): + return text_enc_dict + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--model_path", default=None, type=str, 
required=True, help="Path to the model to convert.") + parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.") + parser.add_argument("--clip_checkpoint_path", default=None, type=str, help="Path to the output CLIP model.") + parser.add_argument("--half", action="store_true", help="Save weights in half precision.") + parser.add_argument( + "--use_safetensors", action="store_true", help="Save weights use safetensors, default is ckpt." + ) + + args = parser.parse_args() + + assert args.model_path is not None, "Must provide a model path!" + + assert args.checkpoint_path is not None, "Must provide a checkpoint path!" + + assert args.clip_checkpoint_path is not None, "Must provide a CLIP checkpoint path!" + + # Path for safetensors + unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.safetensors") + #vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.safetensors") + text_enc_path = osp.join(args.model_path, "text_encoder", "model.safetensors") + + # Load models from safetensors if it exists, if it doesn't pytorch + if osp.exists(unet_path): + unet_state_dict = load_file(unet_path, device="cpu") + else: + unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin") + unet_state_dict = torch.load(unet_path, map_location="cpu") + + # if osp.exists(vae_path): + # vae_state_dict = load_file(vae_path, device="cpu") + # else: + # vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin") + # vae_state_dict = torch.load(vae_path, map_location="cpu") + + if osp.exists(text_enc_path): + text_enc_dict = load_file(text_enc_path, device="cpu") + else: + text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin") + text_enc_dict = torch.load(text_enc_path, map_location="cpu") + + # Convert the UNet model + unet_state_dict = convert_unet_state_dict(unet_state_dict) + #unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()} + + # Convert the VAE model + # vae_state_dict = convert_vae_state_dict(vae_state_dict) + # vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()} + + # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper + is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict + + if is_v20_model: + + # MODELSCOPE always uses the 2.X encoder, btw --kabachuha + + # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm + text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()} + text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict) + #text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()} + else: + text_enc_dict = convert_text_enc_state_dict(text_enc_dict) + #text_enc_dict = {"cond_stage_model.transformer." 
+ k: v for k, v in text_enc_dict.items()} + + # DON'T PUT TOGETHER FOR THE NEW CHECKPOINT AS MODELSCOPE USES THEM IN THE SPLITTED FORM --kabachuha + # Save CLIP and the Diffusion model to their own files + + #state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict} + print ('Saving UNET') + state_dict = {**unet_state_dict} + + if args.half: + state_dict = {k: v.half() for k, v in state_dict.items()} + + if args.use_safetensors: + save_file(state_dict, args.checkpoint_path) + else: + #state_dict = {"state_dict": state_dict} + torch.save(state_dict, args.checkpoint_path) + + # TODO: CLIP conversion doesn't work atm + # print ('Saving CLIP') + # state_dict = {**text_enc_dict} + + # if args.half: + # state_dict = {k: v.half() for k, v in state_dict.items()} + + # if args.use_safetensors: + # save_file(state_dict, args.checkpoint_path) + # else: + # #state_dict = {"state_dict": state_dict} + # torch.save(state_dict, args.clip_checkpoint_path) + + print('Operation successfull') diff --git a/utils/dataset_utils.py b/utils/dataset_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a6d1d198b94d79e087b098a419090e9a6d05f014 --- /dev/null +++ b/utils/dataset_utils.py @@ -0,0 +1,113 @@ +import os +import json +import decord +decord.bridge.set_bridge('torch') + +import torch +from torch.utils.data import Dataset +import torchvision +import torchvision.transforms as T + +from itertools import islice + +from glob import glob +from PIL import Image +from einops import rearrange, repeat + + +def read_caption_file(caption_file): + with open(caption_file, 'r', encoding="utf8") as t: + return t.read() + +def get_text_prompt( + text_prompt: str = '', + fallback_prompt: str= '', + file_path:str = '', + ext_types=['.mp4'], + use_caption=False + ): + try: + if use_caption: + if len(text_prompt) > 1: return text_prompt + caption_file = '' + # Use caption on per-video basis (One caption PER video) + for ext in ext_types: + maybe_file = file_path.replace(ext, '.txt') + if maybe_file.endswith(ext_types): continue + if os.path.exists(maybe_file): + caption_file = maybe_file + break + + if os.path.exists(caption_file): + return read_caption_file(caption_file) + + # Return fallback prompt if no conditions are met. + return fallback_prompt + + return text_prompt + except: + print(f"Couldn't read prompt caption for {file_path}. 
Using fallback.") + return fallback_prompt + +def get_video_frames(vr, start_idx, sample_rate=1, max_frames=24): + max_range = len(vr) + frame_number = sorted((0, start_idx, max_range))[1] + + frame_range = range(frame_number, max_range, sample_rate) + frame_range_indices = list(frame_range)[:max_frames] + + return frame_range_indices + +def get_prompt_ids(prompt, tokenizer): + prompt_ids = tokenizer( + prompt, + truncation=True, + padding="max_length", + max_length=tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + return prompt_ids + +def process_video(vid_path, use_bucketing, w, h, get_frame_buckets, get_frame_batch): + if use_bucketing: + vr = decord.VideoReader(vid_path) + resize = get_frame_buckets(vr) + video = get_frame_batch(vr, resize=resize) + + else: + vr = decord.VideoReader(vid_path, width=w, height=h) + video = get_frame_batch(vr) + + return video, vr + +def min_res(size, min_size): return 192 if size < 192 else size + +def up_down_bucket(m_size, in_size, direction): + if direction == 'down': return abs(int(m_size - in_size)) + if direction == 'up': return abs(int(m_size + in_size)) + +def get_bucket_sizes(size, direction: 'down', min_size): + multipliers = [64, 128] + for i, m in enumerate(multipliers): + res = up_down_bucket(m, size, direction) + multipliers[i] = min_res(res, min_size=min_size) + return multipliers + +def closest_bucket(m_size, size, direction, min_size): + lst = get_bucket_sizes(m_size, direction, min_size) + return lst[min(range(len(lst)), key=lambda i: abs(lst[i]-size))] + +def resolve_bucket(i,h,w): return (i / (h / w)) + +def sensible_buckets(m_width, m_height, w, h, min_size=192): + if h > w: + w = resolve_bucket(m_width, h, w) + w = closest_bucket(m_width, w, 'down', min_size=min_size) + return w, m_height + if h < w: + h = resolve_bucket(m_height, w, h) + h = closest_bucket(m_height, h, 'down', min_size=min_size) + return m_width, h + + return m_width, m_height \ No newline at end of file diff --git a/utils/ddim_utils.py b/utils/ddim_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2c0163c123351e1129b2257a38b06a1f63878d0c --- /dev/null +++ b/utils/ddim_utils.py @@ -0,0 +1,76 @@ +import numpy as np +from typing import Union + +import torch + +from tqdm import tqdm +from diffusers import DDIMScheduler + + +# DDIM Inversion +@torch.no_grad() +def init_prompt(prompt, pipeline): + uncond_input = pipeline.tokenizer( + [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length, + return_tensors="pt" + ) + uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0] + text_input = pipeline.tokenizer( + [prompt], + padding="max_length", + max_length=pipeline.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0] + context = torch.cat([uncond_embeddings, text_embeddings]) + + return context + + +def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, + sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler): + timestep, next_timestep = min( + timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep + alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod + alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep] + beta_prod_t = 1 - alpha_prod_t + next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / 
alpha_prod_t ** 0.5 + next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output + next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction + return next_sample + + +def get_noise_pred_single(latents, t, context, unet): + noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"] + return noise_pred + + +@torch.no_grad() +def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt): + context = init_prompt(prompt, pipeline) + uncond_embeddings, cond_embeddings = context.chunk(2) + all_latent = [latent] + latent = latent.clone().detach() + for i in tqdm(range(num_inv_steps)): + t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1] + noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet) + latent = next_step(noise_pred, t, latent, ddim_scheduler) + all_latent.append(latent) + return all_latent + + +@torch.no_grad() +def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""): + ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt) + return ddim_latents + + +def inverse_video(pipe, latents, num_steps): + ddim_inv_scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + ddim_inv_scheduler.set_timesteps(num_steps) + + ddim_inv_latent = ddim_inversion( + pipe, ddim_inv_scheduler, video_latent=latents.to(pipe.device), + num_inv_steps=num_steps, prompt="")[-1] + return ddim_inv_latent \ No newline at end of file diff --git a/utils/extract_16frames.py b/utils/extract_16frames.py new file mode 100644 index 0000000000000000000000000000000000000000..7ef02dd27bfb58c1204061e3cfcefe285a735f9d --- /dev/null +++ b/utils/extract_16frames.py @@ -0,0 +1,36 @@ +import cv2 +import imageio + + +def get_total_frames(video_path): + cap = cv2.VideoCapture(video_path) + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + cap.release() + return total_frames + + +def extract_frames(input_path, output_path, target_fps, selected_frames): + total_frames = get_total_frames(input_path) + + video_reader = imageio.get_reader(input_path) + fps = video_reader.get_meta_data()['fps'] + + target_total_frames = selected_frames + frame_interval = max(1, int(fps / target_fps)) + selected_indices = [int(i * frame_interval) for i in range(target_total_frames)] + + target_frames = [video_reader.get_data(i) for i in selected_indices] + with imageio.get_writer(output_path, fps=target_fps) as video_writer: + for frame in target_frames: + video_writer.append_data(frame) + + +if __name__ == "__main__": + input_video_path = "" + output_video_path = "" + target_fps = 8 + selected_frames = 24 + + extract_frames(input_video_path, output_video_path, target_fps, selected_frames) + + diff --git a/utils/func_utils.py b/utils/func_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a974984254bd36cd8668c695b93e2b6d3df87985 --- /dev/null +++ b/utils/func_utils.py @@ -0,0 +1,231 @@ +import torch +import random +import cv2 +import fnmatch +import torch.nn.functional as F +from torchvision import transforms +import torchvision.transforms.functional as TF +from diffusers.optimization import get_scheduler +from einops import rearrange, repeat +from omegaconf import OmegaConf +from dataset import * +from models.unet.motion_embeddings import * +from .lora import * +from .lora_handler import * + +def find_videos(directory, extensions=('.mp4', '.mkv', '.avi', '.mov', '.flv', '.wmv', '.gif')): + video_files = [] + for root, dirs, files in os.walk(directory): + for 
extension in extensions: + for filename in fnmatch.filter(files, '*' + extension): + video_files.append(os.path.join(root, filename)) + return video_files + +def param_optim(model, condition, extra_params=None, is_lora=False, negation=None): + extra_params = extra_params if len(extra_params.keys()) > 0 else None + return { + "model": model, + "condition": condition, + 'extra_params': extra_params, + 'is_lora': is_lora, + "negation": negation + } + +def create_optim_params(name='param', params=None, lr=5e-6, extra_params=None): + params = { + "name": name, + "params": params, + "lr": lr + } + if extra_params is not None: + for k, v in extra_params.items(): + params[k] = v + + return params + +def create_optimizer_params(model_list, lr): + import itertools + optimizer_params = [] + + for optim in model_list: + model, condition, extra_params, is_lora, negation = optim.values() + # Check if we are doing LoRA training. + if is_lora and condition and isinstance(model, list): + params = create_optim_params( + params=itertools.chain(*model), + extra_params=extra_params + ) + optimizer_params.append(params) + continue + + if is_lora and condition and not isinstance(model, list): + for n, p in model.named_parameters(): + if 'lora' in n: + params = create_optim_params(n, p, lr, extra_params) + optimizer_params.append(params) + continue + + # If this is true, we can train it. + if condition: + for n, p in model.named_parameters(): + should_negate = 'lora' in n and not is_lora + if should_negate: continue + + params = create_optim_params(n, p, lr, extra_params) + optimizer_params.append(params) + + return optimizer_params + +def get_optimizer(use_8bit_adam): + if use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. 
You can do so by running `pip install bitsandbytes`" + ) + + return bnb.optim.AdamW8bit + else: + return torch.optim.AdamW + +# Initialize the optimizer +def prepare_optimizers(params, config, **extra_params): + optimizer_cls = get_optimizer(config.train.use_8bit_adam) + + optimizer_temporal = optimizer_cls( + params, + lr=config.loss.learning_rate + ) + + lr_scheduler_temporal = get_scheduler( + config.loss.lr_scheduler, + optimizer=optimizer_temporal, + num_warmup_steps=config.loss.lr_warmup_steps * config.train.gradient_accumulation_steps, + num_training_steps=config.train.max_train_steps * config.train.gradient_accumulation_steps, + ) + + # Insert Spatial LoRAs + if config.loss.type == 'DebiasedHybrid': + unet_lora_params_spatial_list = extra_params.get('unet_lora_params_spatial_list', []) + spatial_lora_num = extra_params.get('spatial_lora_num', 1) + + optimizer_spatial_list = [] + lr_scheduler_spatial_list = [] + for i in range(spatial_lora_num): + unet_lora_params_spatial = unet_lora_params_spatial_list[i] + + optimizer_spatial = optimizer_cls( + create_optimizer_params( + [ + param_optim( + unet_lora_params_spatial, + config.loss.use_unet_lora, + is_lora=True, + extra_params={**{"lr": config.loss.learning_rate_spatial}} + ) + ], + config.loss.learning_rate_spatial + ), + lr=config.loss.learning_rate_spatial + ) + optimizer_spatial_list.append(optimizer_spatial) + + # Scheduler + lr_scheduler_spatial = get_scheduler( + config.loss.lr_scheduler, + optimizer=optimizer_spatial, + num_warmup_steps=config.loss.lr_warmup_steps * config.train.gradient_accumulation_steps, + num_training_steps=config.train.max_train_steps * config.train.gradient_accumulation_steps, + ) + lr_scheduler_spatial_list.append(lr_scheduler_spatial) + + else: + optimizer_spatial_list = [] + lr_scheduler_spatial_list = [] + + + + return [optimizer_temporal] + optimizer_spatial_list, [lr_scheduler_temporal] + lr_scheduler_spatial_list + +def sample_noise(latents, noise_strength, use_offset_noise=False): + b, c, f, *_ = latents.shape + noise_latents = torch.randn_like(latents, device=latents.device) + + if use_offset_noise: + offset_noise = torch.randn(b, c, f, 1, 1, device=latents.device) + noise_latents = noise_latents + noise_strength * offset_noise + + return noise_latents + +@torch.no_grad() +def tensor_to_vae_latent(t, vae): + video_length = t.shape[1] + + t = rearrange(t, "b f c h w -> (b f) c h w") + latents = vae.encode(t).latent_dist.sample() + latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length) + latents = latents * 0.18215 + + return latents + +def prepare_data(config, tokenizer): + # Get the training dataset based on types (json, single_video, image) + + # Assuming config.dataset is a DictConfig object + dataset_params_dict = OmegaConf.to_container(config.dataset, resolve=True) + + # Remove the 'type' key + dataset_params_dict.pop('type', None) # 'None' ensures no error if 'type' key doesn't exist + + train_datasets = [] + + # Loop through all available datasets, get the name, then add to list of data to process. 
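# The loop below instantiates whichever dataset class reports a __getname__() matching an
# entry in config.dataset.type, and only the first match is used (train_dataset =
# train_datasets[0]). Note that len(train_datasets) can never be negative, so the "< 0"
# guard below never fires; an empty match is presumably meant to be caught with "== 0"
# and otherwise surfaces later as an IndexError.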
+ for DataSet in [VideoJsonDataset, SingleVideoDataset, ImageDataset, VideoFolderDataset]: + for dataset in config.dataset.type: + if dataset == DataSet.__getname__(): + train_datasets.append(DataSet(**dataset_params_dict, tokenizer=tokenizer)) + + if len(train_datasets) < 0: + raise ValueError("Dataset type not found: 'json', 'single_video', 'folder', 'image'") + + train_dataset = train_datasets[0] + + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=config.train.train_batch_size, + shuffle=True + ) + + return train_dataloader, train_dataset + +# create parameters for optimziation +def prepare_params(unet, config, train_dataset): + extra_params = {} + + params,embedding_layers = inject_motion_embeddings( + unet, + combinations=config.model.motion_embeddings.combinations, + config=config + ) + + config.model.embedding_layers = embedding_layers + if config.loss.type == "DebiasedHybrid": + if config.loss.spatial_lora_num == -1: + config.loss.spatial_lora_num = train_dataset.__len__() + + lora_managers_spatial, unet_lora_params_spatial_list, unet_negation_all = inject_spatial_loras( + unet=unet, + use_unet_lora=True, + lora_unet_dropout=0.1, + lora_path='', + lora_rank=32, + spatial_lora_num=1, + ) + + extra_params['lora_managers_spatial'] = lora_managers_spatial + extra_params['unet_lora_params_spatial_list'] = unet_lora_params_spatial_list + extra_params['unet_negation_all'] = unet_negation_all + + return params, extra_params \ No newline at end of file diff --git a/utils/gpu_utils.py b/utils/gpu_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..174e672fbf6a4e182258322f106c09f7636d078c --- /dev/null +++ b/utils/gpu_utils.py @@ -0,0 +1,56 @@ +import torch +import torch.nn.functional as F +from diffusers.models.attention_processor import AttnProcessor2_0 +from diffusers.models.attention import BasicTransformerBlock +from diffusers.utils.import_utils import is_xformers_available +from transformers.models.clip.modeling_clip import CLIPEncoder + +GRADIENT_CHECKPOINTING = True +TEXT_ENCODER_GRADIENT_CHECKPOINTING = True +ENABLE_XFORMERS_MEMORY_EFFICIENT_ATTENTION = True +ENABLE_TORCH_2_ATTN = True + +def is_attn(name): + return ('attn1' or 'attn2' == name.split('.')[-1]) + +def unet_and_text_g_c(unet, text_encoder, unet_enable=GRADIENT_CHECKPOINTING, text_enable=TEXT_ENCODER_GRADIENT_CHECKPOINTING): + unet._set_gradient_checkpointing(value=unet_enable) + text_encoder._set_gradient_checkpointing(CLIPEncoder) + +def set_processors(attentions): + for attn in attentions: attn.set_processor(AttnProcessor2_0()) + +def set_torch_2_attn(unet): + optim_count = 0 + + for name, module in unet.named_modules(): + if is_attn(name): + if isinstance(module, torch.nn.ModuleList): + for m in module: + if isinstance(m, BasicTransformerBlock): + set_processors([m.attn1, m.attn2]) + optim_count += 1 + if optim_count > 0: + print(f"{optim_count} Attention layers using Scaled Dot Product Attention.") + +def handle_memory_attention( + unet, + enable_xformers_memory_efficient_attention=ENABLE_XFORMERS_MEMORY_EFFICIENT_ATTENTION, + enable_torch_2_attn=ENABLE_TORCH_2_ATTN + ): + try: + is_torch_2 = hasattr(F, 'scaled_dot_product_attention') + enable_torch_2 = is_torch_2 and enable_torch_2_attn + + if enable_xformers_memory_efficient_attention and not enable_torch_2: + if is_xformers_available(): + from xformers.ops import MemoryEfficientAttentionFlashAttentionOp + unet.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp) 
+ else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + if enable_torch_2: + set_torch_2_attn(unet) + + except: + print("Could not enable memory efficient attention for xformers or Torch 2.0.") \ No newline at end of file diff --git a/utils/lora.py b/utils/lora.py new file mode 100644 index 0000000000000000000000000000000000000000..5ac36c1f83558ab99489d3ab64d8fa6c504e714f --- /dev/null +++ b/utils/lora.py @@ -0,0 +1,1481 @@ +import json +import math +from itertools import groupby +import os +from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union + +import numpy as np +import PIL +import torch +import torch.nn as nn +import torch.nn.functional as F + +try: + from safetensors.torch import safe_open + from safetensors.torch import save_file as safe_save + + safetensors_available = True +except ImportError: + from .safe_open import safe_open + + def safe_save( + tensors: Dict[str, torch.Tensor], + filename: str, + metadata: Optional[Dict[str, str]] = None, + ) -> None: + raise EnvironmentError( + "Saving safetensors requires the safetensors library. Please install with pip or similar." + ) + + safetensors_available = False + +from diffusers.models.lora import LoRACompatibleLinear + +class LoraInjectedLinear(nn.Module): + def __init__( + self, in_features, out_features, bias=False, r=4, dropout_p=0.1, scale=1.0 + ): + super().__init__() + + if r > min(in_features, out_features): + #raise ValueError( + # f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}" + #) + print(f"LoRA rank {r} is too large. setting to: {min(in_features, out_features)}") + r = min(in_features, out_features) + + self.r = r + self.linear = nn.Linear(in_features, out_features, bias) + self.lora_down = nn.Linear(in_features, r, bias=False) + self.dropout = nn.Dropout(dropout_p) + self.lora_up = nn.Linear(r, out_features, bias=False) + self.scale = scale + self.selector = nn.Identity() + + nn.init.normal_(self.lora_down.weight, std=1 / r) + nn.init.zeros_(self.lora_up.weight) + + def forward(self, hidden_states: torch.Tensor, scale: float = 1.0): + return ( + self.linear(hidden_states) + + self.dropout(self.lora_up(self.selector(self.lora_down(hidden_states)))) + * self.scale + ) + + def realize_as_lora(self): + return self.lora_up.weight.data * self.scale, self.lora_down.weight.data + + def set_selector_from_diag(self, diag: torch.Tensor): + # diag is a 1D tensor of size (r,) + assert diag.shape == (self.r,) + self.selector = nn.Linear(self.r, self.r, bias=False) + self.selector.weight.data = torch.diag(diag) + self.selector.weight.data = self.selector.weight.data.to( + self.lora_up.weight.device + ).to(self.lora_up.weight.dtype) + + +class MultiLoraInjectedLinear(nn.Module): + def __init__( + self, in_features, out_features, bias=False, r=4, dropout_p=0.1, lora_num=1, scales=[1.0] + ): + super().__init__() + + if r > min(in_features, out_features): + #raise ValueError( + # f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}" + #) + print(f"LoRA rank {r} is too large. 
setting to: {min(in_features, out_features)}") + r = min(in_features, out_features) + + self.r = r + self.linear = nn.Linear(in_features, out_features, bias) + + for i in range(lora_num): + if i==0: + self.lora_down =[nn.Linear(in_features, r, bias=False)] + self.dropout = [nn.Dropout(dropout_p)] + self.lora_up = [nn.Linear(r, out_features, bias=False)] + self.scale = scales[i] + self.selector = [nn.Identity()] + else: + self.lora_down.append(nn.Linear(in_features, r, bias=False)) + self.dropout.append( nn.Dropout(dropout_p)) + self.lora_up.append( nn.Linear(r, out_features, bias=False)) + self.scale.append(scales[i]) + + nn.init.normal_(self.lora_down.weight, std=1 / r) + nn.init.zeros_(self.lora_up.weight) + + def forward(self, input): + return ( + self.linear(input) + + self.dropout(self.lora_up(self.selector(self.lora_down(input)))) + * self.scale + ) + + def realize_as_lora(self): + return self.lora_up.weight.data * self.scale, self.lora_down.weight.data + + def set_selector_from_diag(self, diag: torch.Tensor): + # diag is a 1D tensor of size (r,) + assert diag.shape == (self.r,) + self.selector = nn.Linear(self.r, self.r, bias=False) + self.selector.weight.data = torch.diag(diag) + self.selector.weight.data = self.selector.weight.data.to( + self.lora_up.weight.device + ).to(self.lora_up.weight.dtype) + + +class LoraInjectedConv2d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups: int = 1, + bias: bool = True, + r: int = 4, + dropout_p: float = 0.1, + scale: float = 1.0, + ): + super().__init__() + if r > min(in_channels, out_channels): + print(f"LoRA rank {r} is too large. setting to: {min(in_channels, out_channels)}") + r = min(in_channels, out_channels) + + self.r = r + self.conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + + self.lora_down = nn.Conv2d( + in_channels=in_channels, + out_channels=r, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=False, + ) + self.dropout = nn.Dropout(dropout_p) + self.lora_up = nn.Conv2d( + in_channels=r, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + self.selector = nn.Identity() + self.scale = scale + + nn.init.normal_(self.lora_down.weight, std=1 / r) + nn.init.zeros_(self.lora_up.weight) + + def forward(self, input): + return ( + self.conv(input) + + self.dropout(self.lora_up(self.selector(self.lora_down(input)))) + * self.scale + ) + + def realize_as_lora(self): + return self.lora_up.weight.data * self.scale, self.lora_down.weight.data + + def set_selector_from_diag(self, diag: torch.Tensor): + # diag is a 1D tensor of size (r,) + assert diag.shape == (self.r,) + self.selector = nn.Conv2d( + in_channels=self.r, + out_channels=self.r, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + self.selector.weight.data = torch.diag(diag) + + # same device + dtype as lora_up + self.selector.weight.data = self.selector.weight.data.to( + self.lora_up.weight.device + ).to(self.lora_up.weight.dtype) + +class LoraInjectedConv3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: (3, 1, 1), + padding: (1, 0, 0), + bias: bool = False, + r: int = 4, + dropout_p: float = 0, + scale: float = 1.0, + ): + super().__init__() + if r > min(in_channels, out_channels): + print(f"LoRA rank 
{r} is too large. setting to: {min(in_channels, out_channels)}") + r = min(in_channels, out_channels) + + self.r = r + self.kernel_size = kernel_size + self.padding = padding + self.conv = nn.Conv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + padding=padding, + ) + + self.lora_down = nn.Conv3d( + in_channels=in_channels, + out_channels=r, + kernel_size=kernel_size, + bias=False, + padding=padding + ) + self.dropout = nn.Dropout(dropout_p) + self.lora_up = nn.Conv3d( + in_channels=r, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + self.selector = nn.Identity() + self.scale = scale + + nn.init.normal_(self.lora_down.weight, std=1 / r) + nn.init.zeros_(self.lora_up.weight) + + def forward(self, input): + return ( + self.conv(input) + + self.dropout(self.lora_up(self.selector(self.lora_down(input)))) + * self.scale + ) + + def realize_as_lora(self): + return self.lora_up.weight.data * self.scale, self.lora_down.weight.data + + def set_selector_from_diag(self, diag: torch.Tensor): + # diag is a 1D tensor of size (r,) + assert diag.shape == (self.r,) + self.selector = nn.Conv3d( + in_channels=self.r, + out_channels=self.r, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + self.selector.weight.data = torch.diag(diag) + + # same device + dtype as lora_up + self.selector.weight.data = self.selector.weight.data.to( + self.lora_up.weight.device + ).to(self.lora_up.weight.dtype) + +UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"} + +UNET_EXTENDED_TARGET_REPLACE = {"ResnetBlock2D", "CrossAttention", "Attention", "GEGLU"} + +TEXT_ENCODER_DEFAULT_TARGET_REPLACE = {"CLIPAttention"} + +TEXT_ENCODER_EXTENDED_TARGET_REPLACE = {"CLIPAttention"} + +DEFAULT_TARGET_REPLACE = UNET_DEFAULT_TARGET_REPLACE + +EMBED_FLAG = "" + + +def _find_children( + model, + search_class: List[Type[nn.Module]] = [nn.Linear], +): + """ + Find all modules of a certain class (or union of classes). + + Returns all matching modules, along with the parent of those moduless and the + names they are referenced by. + """ + # For each target find every linear_class module that isn't a child of a LoraInjectedLinear + for parent in model.modules(): + for name, module in parent.named_children(): + if any([isinstance(module, _class) for _class in search_class]): + yield parent, name, module + + +def _find_modules_v2( + model, + ancestor_class: Optional[Set[str]] = None, + search_class: List[Type[nn.Module]] = [nn.Linear], + exclude_children_of: Optional[List[Type[nn.Module]]] = None, + # [ + # LoraInjectedLinear, + # LoraInjectedConv2d, + # LoraInjectedConv3d + # ], +): + """ + Find all modules of a certain class (or union of classes) that are direct or + indirect descendants of other modules of a certain class (or union of classes). + + Returns all matching modules, along with the parent of those moduless and the + names they are referenced by. + """ + + # Get the targets we should replace all linears under + if ancestor_class is not None: + ancestors = ( + module + for name, module in model.named_modules() + if module.__class__.__name__ in ancestor_class # and ('transformer_in' not in name) + ) + else: + # this, incase you want to naively iterate over all modules. 
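# In either branch, _find_modules_v2 yields (parent_module, attribute_name, child_module)
# triples. Callers such as inject_trainable_lora below rely on this shape to swap the child
# in place via parent._modules[attribute_name] = <LoRA-wrapped replacement>. A minimal,
# hedged usage sketch (illustrative only):
#
#     for parent, name, child in _find_modules(model, {"Transformer2DModel"},
#                                               search_class=[nn.Linear]):
#         print(f"would wrap {name}: {child.in_features} -> {child.out_features}")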
+ ancestors = [module for module in model.modules()] + + # For each target find every linear_class module that isn't a child of a LoraInjectedLinear + for ancestor in ancestors: + for fullname, module in ancestor.named_modules(): + if any([isinstance(module, _class) for _class in search_class]): + continue_flag = True + if 'Transformer2DModel' in ancestor_class and ('attn1' in fullname or 'ff' in fullname): + continue_flag = False + if 'TransformerTemporalModel' in ancestor_class and ('attn1' in fullname or 'attn2' in fullname or 'ff' in fullname): + continue_flag = False + if continue_flag: + continue + # Find the direct parent if this is a descendant, not a child, of target + *path, name = fullname.split(".") + parent = ancestor + while path: + parent = parent.get_submodule(path.pop(0)) + # Skip this linear if it's a child of a LoraInjectedLinear + if exclude_children_of and any( + [isinstance(parent, _class) for _class in exclude_children_of] + ): + continue + if name in ['lora_up', 'dropout', 'lora_down']: + continue + # Otherwise, yield it + yield parent, name, module + + +def _find_modules_old( + model, + ancestor_class: Set[str] = DEFAULT_TARGET_REPLACE, + search_class: List[Type[nn.Module]] = [nn.Linear], + exclude_children_of: Optional[List[Type[nn.Module]]] = [LoraInjectedLinear], +): + ret = [] + for _module in model.modules(): + if _module.__class__.__name__ in ancestor_class: + + for name, _child_module in _module.named_modules(): + if _child_module.__class__ in search_class: + ret.append((_module, name, _child_module)) + print(ret) + return ret + + +_find_modules = _find_modules_v2 + + +def inject_trainable_lora( + model: nn.Module, + target_replace_module: Set[str] = DEFAULT_TARGET_REPLACE, + r: int = 4, + loras=None, # path to lora .pt + verbose: bool = False, + dropout_p: float = 0.0, + scale: float = 1.0, +): + """ + inject lora into model, and returns lora parameter groups. + """ + + require_grad_params = [] + names = [] + + if loras != None: + loras = torch.load(loras) + + for _module, name, _child_module in _find_modules( + model, target_replace_module, search_class=[nn.Linear] + ): + weight = _child_module.weight + bias = _child_module.bias + if verbose: + print("LoRA Injection : injecting lora into ", name) + print("LoRA Injection : weight shape", weight.shape) + _tmp = LoraInjectedLinear( + _child_module.in_features, + _child_module.out_features, + _child_module.bias is not None, + r=r, + dropout_p=dropout_p, + scale=scale, + ) + _tmp.linear.weight = weight + if bias is not None: + _tmp.linear.bias = bias + + # switch the module + _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype) + _module._modules[name] = _tmp + + require_grad_params.append(_module._modules[name].lora_up.parameters()) + require_grad_params.append(_module._modules[name].lora_down.parameters()) + + if loras != None: + _module._modules[name].lora_up.weight = loras.pop(0) + _module._modules[name].lora_down.weight = loras.pop(0) + + _module._modules[name].lora_up.weight.requires_grad = True + _module._modules[name].lora_down.weight.requires_grad = True + names.append(name) + + return require_grad_params, names + + +def inject_trainable_lora_extended( + model: nn.Module, + target_replace_module: Set[str] = UNET_EXTENDED_TARGET_REPLACE, + r: int = 4, + loras=None, # path to lora .pt + dropout_p: float = 0.0, + scale: float = 1.0, +): + """ + inject lora into model, and returns lora parameter groups. 
+ """ + + require_grad_params = [] + names = [] + + if loras != None: + loras = torch.load(loras) + if True: + for target_replace_module_i in target_replace_module: + for _module, name, _child_module in _find_modules( + model, [target_replace_module_i], search_class=[LoRACompatibleLinear, nn.Conv2d, nn.Conv3d] + ): + # if name == 'to_q': + # continue + if _child_module.__class__ == LoRACompatibleLinear: + weight = _child_module.weight + bias = _child_module.bias + _tmp = LoraInjectedLinear( + _child_module.in_features, + _child_module.out_features, + _child_module.bias is not None, + r=r, + dropout_p=dropout_p, + scale=scale, + ) + _tmp.linear.weight = weight + if bias is not None: + _tmp.linear.bias = bias + elif _child_module.__class__ == nn.Conv2d: + weight = _child_module.weight + bias = _child_module.bias + _tmp = LoraInjectedConv2d( + _child_module.in_channels, + _child_module.out_channels, + _child_module.kernel_size, + _child_module.stride, + _child_module.padding, + _child_module.dilation, + _child_module.groups, + _child_module.bias is not None, + r=r, + dropout_p=dropout_p, + scale=scale, + ) + + _tmp.conv.weight = weight + if bias is not None: + _tmp.conv.bias = bias + + elif _child_module.__class__ == nn.Conv3d: + weight = _child_module.weight + bias = _child_module.bias + _tmp = LoraInjectedConv3d( + _child_module.in_channels, + _child_module.out_channels, + bias=_child_module.bias is not None, + kernel_size=_child_module.kernel_size, + padding=_child_module.padding, + r=r, + dropout_p=dropout_p, + scale=scale, + ) + + _tmp.conv.weight = weight + if bias is not None: + _tmp.conv.bias = bias + # switch the module + _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype) + if bias is not None: + _tmp.to(_child_module.bias.device).to(_child_module.bias.dtype) + + _module._modules[name] = _tmp + require_grad_params.append(_module._modules[name].lora_up.parameters()) + require_grad_params.append(_module._modules[name].lora_down.parameters()) + + if loras != None: + _module._modules[name].lora_up.weight = loras.pop(0) + _module._modules[name].lora_down.weight = loras.pop(0) + + _module._modules[name].lora_up.weight.requires_grad = True + _module._modules[name].lora_down.weight.requires_grad = True + names.append(name) + else: + for _module, name, _child_module in _find_modules( + model, target_replace_module, search_class=[nn.Linear, nn.Conv2d, nn.Conv3d] + ): + if _child_module.__class__ == nn.Linear: + weight = _child_module.weight + bias = _child_module.bias + _tmp = LoraInjectedLinear( + _child_module.in_features, + _child_module.out_features, + _child_module.bias is not None, + r=r, + dropout_p=dropout_p, + scale=scale, + ) + _tmp.linear.weight = weight + if bias is not None: + _tmp.linear.bias = bias + elif _child_module.__class__ == nn.Conv2d: + weight = _child_module.weight + bias = _child_module.bias + _tmp = LoraInjectedConv2d( + _child_module.in_channels, + _child_module.out_channels, + _child_module.kernel_size, + _child_module.stride, + _child_module.padding, + _child_module.dilation, + _child_module.groups, + _child_module.bias is not None, + r=r, + dropout_p=dropout_p, + scale=scale, + ) + + _tmp.conv.weight = weight + if bias is not None: + _tmp.conv.bias = bias + + elif _child_module.__class__ == nn.Conv3d: + weight = _child_module.weight + bias = _child_module.bias + _tmp = LoraInjectedConv3d( + _child_module.in_channels, + _child_module.out_channels, + bias=_child_module.bias is not None, + kernel_size=_child_module.kernel_size, + 
padding=_child_module.padding, + r=r, + dropout_p=dropout_p, + scale=scale, + ) + + _tmp.conv.weight = weight + if bias is not None: + _tmp.conv.bias = bias + # switch the module + _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype) + if bias is not None: + _tmp.to(_child_module.bias.device).to(_child_module.bias.dtype) + + _module._modules[name] = _tmp + require_grad_params.append(_module._modules[name].lora_up.parameters()) + require_grad_params.append(_module._modules[name].lora_down.parameters()) + + if loras != None: + _module._modules[name].lora_up.weight = loras.pop(0) + _module._modules[name].lora_down.weight = loras.pop(0) + + _module._modules[name].lora_up.weight.requires_grad = True + _module._modules[name].lora_down.weight.requires_grad = True + names.append(name) + + return require_grad_params, names + + +def inject_inferable_lora( + model, + lora_path='', + unet_replace_modules=["UNet3DConditionModel"], + text_encoder_replace_modules=["CLIPEncoderLayer"], + is_extended=False, + r=16 + ): + from transformers.models.clip import CLIPTextModel + from diffusers import UNet3DConditionModel + + def is_text_model(f): return 'text_encoder' in f and isinstance(model.text_encoder, CLIPTextModel) + def is_unet(f): return 'unet' in f and model.unet.__class__.__name__ == "UNet3DConditionModel" + + if os.path.exists(lora_path): + try: + for f in os.listdir(lora_path): + if f.endswith('.pt'): + lora_file = os.path.join(lora_path, f) + + if is_text_model(f): + monkeypatch_or_replace_lora( + model.text_encoder, + torch.load(lora_file), + target_replace_module=text_encoder_replace_modules, + r=r + ) + print("Successfully loaded Text Encoder LoRa.") + continue + + if is_unet(f): + monkeypatch_or_replace_lora_extended( + model.unet, + torch.load(lora_file), + target_replace_module=unet_replace_modules, + r=r + ) + print("Successfully loaded UNET LoRa.") + continue + + print("Found a .pt file, but doesn't have the correct name format. 
(unet.pt, text_encoder.pt)") + + except Exception as e: + print(e) + print("Couldn't inject LoRA's due to an error.") + +def extract_lora_ups_down(model, target_replace_module=DEFAULT_TARGET_REPLACE): + + loras = [] + + for target_replace_module_i in target_replace_module: + + for _m, _n, _child_module in _find_modules( + model, + [target_replace_module_i], + search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d], + ): + loras.append((_child_module.lora_up, _child_module.lora_down)) + + if len(loras) == 0: + raise ValueError("No lora injected.") + + return loras + + +def extract_lora_child_module(model, target_replace_module=DEFAULT_TARGET_REPLACE): + + loras = [] + + for target_replace_module_i in target_replace_module: + + for _m, _n, _child_module in _find_modules( + model, + [target_replace_module_i], + search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d], + ): + loras.append(_child_module) + + return loras + +def extract_lora_as_tensor( + model, target_replace_module=DEFAULT_TARGET_REPLACE, as_fp16=True +): + + loras = [] + + for _m, _n, _child_module in _find_modules( + model, + target_replace_module, + search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d], + ): + up, down = _child_module.realize_as_lora() + if as_fp16: + up = up.to(torch.float16) + down = down.to(torch.float16) + + loras.append((up, down)) + + if len(loras) == 0: + raise ValueError("No lora injected.") + + return loras + + +def save_lora_weight( + model, + path="./lora.pt", + target_replace_module=DEFAULT_TARGET_REPLACE, + flag=None +): + weights = [] + for _up, _down in extract_lora_ups_down( + model, target_replace_module=target_replace_module + ): + weights.append(_up.weight.to("cpu").to(torch.float32)) + weights.append(_down.weight.to("cpu").to(torch.float32)) + if not flag: + torch.save(weights, path) + else: + weights_new=[] + for i in range(0, len(weights), 4): + subset = weights[i+(flag-1)*2:i+(flag-1)*2+2] + weights_new.extend(subset) + torch.save(weights_new, path) + +def save_lora_as_json(model, path="./lora.json"): + weights = [] + for _up, _down in extract_lora_ups_down(model): + weights.append(_up.weight.detach().cpu().numpy().tolist()) + weights.append(_down.weight.detach().cpu().numpy().tolist()) + + import json + + with open(path, "w") as f: + json.dump(weights, f) + + +def save_safeloras_with_embeds( + modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {}, + embeds: Dict[str, torch.Tensor] = {}, + outpath="./lora.safetensors", +): + """ + Saves the Lora from multiple modules in a single safetensor file. 
+ + modelmap is a dictionary of { + "module name": (module, target_replace_module) + } + """ + weights = {} + metadata = {} + + for name, (model, target_replace_module) in modelmap.items(): + metadata[name] = json.dumps(list(target_replace_module)) + + for i, (_up, _down) in enumerate( + extract_lora_as_tensor(model, target_replace_module) + ): + rank = _down.shape[0] + + metadata[f"{name}:{i}:rank"] = str(rank) + weights[f"{name}:{i}:up"] = _up + weights[f"{name}:{i}:down"] = _down + + for token, tensor in embeds.items(): + metadata[token] = EMBED_FLAG + weights[token] = tensor + + print(f"Saving weights to {outpath}") + safe_save(weights, outpath, metadata) + + +def save_safeloras( + modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {}, + outpath="./lora.safetensors", +): + return save_safeloras_with_embeds(modelmap=modelmap, outpath=outpath) + + +def convert_loras_to_safeloras_with_embeds( + modelmap: Dict[str, Tuple[str, Set[str], int]] = {}, + embeds: Dict[str, torch.Tensor] = {}, + outpath="./lora.safetensors", +): + """ + Converts the Lora from multiple pytorch .pt files into a single safetensor file. + + modelmap is a dictionary of { + "module name": (pytorch_model_path, target_replace_module, rank) + } + """ + + weights = {} + metadata = {} + + for name, (path, target_replace_module, r) in modelmap.items(): + metadata[name] = json.dumps(list(target_replace_module)) + + lora = torch.load(path) + for i, weight in enumerate(lora): + is_up = i % 2 == 0 + i = i // 2 + + if is_up: + metadata[f"{name}:{i}:rank"] = str(r) + weights[f"{name}:{i}:up"] = weight + else: + weights[f"{name}:{i}:down"] = weight + + for token, tensor in embeds.items(): + metadata[token] = EMBED_FLAG + weights[token] = tensor + + print(f"Saving weights to {outpath}") + safe_save(weights, outpath, metadata) + + +def convert_loras_to_safeloras( + modelmap: Dict[str, Tuple[str, Set[str], int]] = {}, + outpath="./lora.safetensors", +): + convert_loras_to_safeloras_with_embeds(modelmap=modelmap, outpath=outpath) + + +def parse_safeloras( + safeloras, +) -> Dict[str, Tuple[List[nn.parameter.Parameter], List[int], List[str]]]: + """ + Converts a loaded safetensor file that contains a set of module Loras + into Parameters and other information + + Output is a dictionary of { + "module name": ( + [list of weights], + [list of ranks], + target_replacement_modules + ) + } + """ + loras = {} + metadata = safeloras.metadata() + + get_name = lambda k: k.split(":")[0] + + keys = list(safeloras.keys()) + keys.sort(key=get_name) + + for name, module_keys in groupby(keys, get_name): + info = metadata.get(name) + + if not info: + raise ValueError( + f"Tensor {name} has no metadata - is this a Lora safetensor?" 
+ ) + + # Skip Textual Inversion embeds + if info == EMBED_FLAG: + continue + + # Handle Loras + # Extract the targets + target = json.loads(info) + + # Build the result lists - Python needs us to preallocate lists to insert into them + module_keys = list(module_keys) + ranks = [4] * (len(module_keys) // 2) + weights = [None] * len(module_keys) + + for key in module_keys: + # Split the model name and index out of the key + _, idx, direction = key.split(":") + idx = int(idx) + + # Add the rank + ranks[idx] = int(metadata[f"{name}:{idx}:rank"]) + + # Insert the weight into the list + idx = idx * 2 + (1 if direction == "down" else 0) + weights[idx] = nn.parameter.Parameter(safeloras.get_tensor(key)) + + loras[name] = (weights, ranks, target) + + return loras + + +def parse_safeloras_embeds( + safeloras, +) -> Dict[str, torch.Tensor]: + """ + Converts a loaded safetensor file that contains Textual Inversion embeds into + a dictionary of embed_token: Tensor + """ + embeds = {} + metadata = safeloras.metadata() + + for key in safeloras.keys(): + # Only handle Textual Inversion embeds + meta = metadata.get(key) + if not meta or meta != EMBED_FLAG: + continue + + embeds[key] = safeloras.get_tensor(key) + + return embeds + + +def load_safeloras(path, device="cpu"): + safeloras = safe_open(path, framework="pt", device=device) + return parse_safeloras(safeloras) + + +def load_safeloras_embeds(path, device="cpu"): + safeloras = safe_open(path, framework="pt", device=device) + return parse_safeloras_embeds(safeloras) + + +def load_safeloras_both(path, device="cpu"): + safeloras = safe_open(path, framework="pt", device=device) + return parse_safeloras(safeloras), parse_safeloras_embeds(safeloras) + + +def collapse_lora(model, alpha=1.0): + + for _module, name, _child_module in _find_modules( + model, + UNET_EXTENDED_TARGET_REPLACE | TEXT_ENCODER_EXTENDED_TARGET_REPLACE, + search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d], + ): + + if isinstance(_child_module, LoraInjectedLinear): + print("Collapsing Lin Lora in", name) + + _child_module.linear.weight = nn.Parameter( + _child_module.linear.weight.data + + alpha + * ( + _child_module.lora_up.weight.data + @ _child_module.lora_down.weight.data + ) + .type(_child_module.linear.weight.dtype) + .to(_child_module.linear.weight.device) + ) + + else: + print("Collapsing Conv Lora in", name) + _child_module.conv.weight = nn.Parameter( + _child_module.conv.weight.data + + alpha + * ( + _child_module.lora_up.weight.data.flatten(start_dim=1) + @ _child_module.lora_down.weight.data.flatten(start_dim=1) + ) + .reshape(_child_module.conv.weight.data.shape) + .type(_child_module.conv.weight.dtype) + .to(_child_module.conv.weight.device) + ) + + +def monkeypatch_or_replace_lora( + model, + loras, + target_replace_module=DEFAULT_TARGET_REPLACE, + r: Union[int, List[int]] = 4, +): + for _module, name, _child_module in _find_modules( + model, target_replace_module, search_class=[nn.Linear, LoraInjectedLinear] + ): + _source = ( + _child_module.linear + if isinstance(_child_module, LoraInjectedLinear) + else _child_module + ) + + weight = _source.weight + bias = _source.bias + _tmp = LoraInjectedLinear( + _source.in_features, + _source.out_features, + _source.bias is not None, + r=r.pop(0) if isinstance(r, list) else r, + ) + _tmp.linear.weight = weight + + if bias is not None: + _tmp.linear.bias = bias + + # switch the module + _module._modules[name] = _tmp + + up_weight = loras.pop(0) + down_weight = loras.pop(0) + + 
_module._modules[name].lora_up.weight = nn.Parameter( + up_weight.type(weight.dtype) + ) + _module._modules[name].lora_down.weight = nn.Parameter( + down_weight.type(weight.dtype) + ) + + _module._modules[name].to(weight.device) + + +def monkeypatch_or_replace_lora_extended( + model, + loras, + target_replace_module=DEFAULT_TARGET_REPLACE, + r: Union[int, List[int]] = 4, +): + for _module, name, _child_module in _find_modules( + model, + target_replace_module, + search_class=[ + nn.Linear, + nn.Conv2d, + nn.Conv3d, + LoraInjectedLinear, + LoraInjectedConv2d, + LoraInjectedConv3d, + ], + ): + + if (_child_module.__class__ == nn.Linear) or ( + _child_module.__class__ == LoraInjectedLinear + ): + if len(loras[0].shape) != 2: + continue + + _source = ( + _child_module.linear + if isinstance(_child_module, LoraInjectedLinear) + else _child_module + ) + + weight = _source.weight + bias = _source.bias + _tmp = LoraInjectedLinear( + _source.in_features, + _source.out_features, + _source.bias is not None, + r=r.pop(0) if isinstance(r, list) else r, + ) + _tmp.linear.weight = weight + + if bias is not None: + _tmp.linear.bias = bias + + elif (_child_module.__class__ == nn.Conv2d) or ( + _child_module.__class__ == LoraInjectedConv2d + ): + if len(loras[0].shape) != 4: + continue + _source = ( + _child_module.conv + if isinstance(_child_module, LoraInjectedConv2d) + else _child_module + ) + + weight = _source.weight + bias = _source.bias + _tmp = LoraInjectedConv2d( + _source.in_channels, + _source.out_channels, + _source.kernel_size, + _source.stride, + _source.padding, + _source.dilation, + _source.groups, + _source.bias is not None, + r=r.pop(0) if isinstance(r, list) else r, + ) + + _tmp.conv.weight = weight + + if bias is not None: + _tmp.conv.bias = bias + + elif _child_module.__class__ == nn.Conv3d or( + _child_module.__class__ == LoraInjectedConv3d + ): + + if len(loras[0].shape) != 5: + continue + + _source = ( + _child_module.conv + if isinstance(_child_module, LoraInjectedConv3d) + else _child_module + ) + + weight = _source.weight + bias = _source.bias + _tmp = LoraInjectedConv3d( + _source.in_channels, + _source.out_channels, + bias=_source.bias is not None, + kernel_size=_source.kernel_size, + padding=_source.padding, + r=r.pop(0) if isinstance(r, list) else r, + ) + + _tmp.conv.weight = weight + + if bias is not None: + _tmp.conv.bias = bias + + # switch the module + _module._modules[name] = _tmp + + up_weight = loras.pop(0) + down_weight = loras.pop(0) + + _module._modules[name].lora_up.weight = nn.Parameter( + up_weight.type(weight.dtype) + ) + _module._modules[name].lora_down.weight = nn.Parameter( + down_weight.type(weight.dtype) + ) + + _module._modules[name].to(weight.device) + + +def monkeypatch_or_replace_safeloras(models, safeloras): + loras = parse_safeloras(safeloras) + + for name, (lora, ranks, target) in loras.items(): + model = getattr(models, name, None) + + if not model: + print(f"No model provided for {name}, contained in Lora") + continue + + monkeypatch_or_replace_lora_extended(model, lora, target, ranks) + + +def monkeypatch_remove_lora(model): + for _module, name, _child_module in _find_modules( + model, search_class=[LoraInjectedLinear, LoraInjectedConv2d, LoraInjectedConv3d] + ): + if isinstance(_child_module, LoraInjectedLinear): + _source = _child_module.linear + weight, bias = _source.weight, _source.bias + + _tmp = nn.Linear( + _source.in_features, _source.out_features, bias is not None + ) + + _tmp.weight = weight + if bias is not None: + _tmp.bias = bias + + 
else: + _source = _child_module.conv + weight, bias = _source.weight, _source.bias + + if isinstance(_source, nn.Conv2d): + _tmp = nn.Conv2d( + in_channels=_source.in_channels, + out_channels=_source.out_channels, + kernel_size=_source.kernel_size, + stride=_source.stride, + padding=_source.padding, + dilation=_source.dilation, + groups=_source.groups, + bias=bias is not None, + ) + + _tmp.weight = weight + if bias is not None: + _tmp.bias = bias + + if isinstance(_source, nn.Conv3d): + _tmp = nn.Conv3d( + _source.in_channels, + _source.out_channels, + bias=_source.bias is not None, + kernel_size=_source.kernel_size, + padding=_source.padding, + ) + + _tmp.weight = weight + if bias is not None: + _tmp.bias = bias + + _module._modules[name] = _tmp + + +def monkeypatch_add_lora( + model, + loras, + target_replace_module=DEFAULT_TARGET_REPLACE, + alpha: float = 1.0, + beta: float = 1.0, +): + for _module, name, _child_module in _find_modules( + model, target_replace_module, search_class=[LoraInjectedLinear] + ): + weight = _child_module.linear.weight + + up_weight = loras.pop(0) + down_weight = loras.pop(0) + + _module._modules[name].lora_up.weight = nn.Parameter( + up_weight.type(weight.dtype).to(weight.device) * alpha + + _module._modules[name].lora_up.weight.to(weight.device) * beta + ) + _module._modules[name].lora_down.weight = nn.Parameter( + down_weight.type(weight.dtype).to(weight.device) * alpha + + _module._modules[name].lora_down.weight.to(weight.device) * beta + ) + + _module._modules[name].to(weight.device) + + +def tune_lora_scale(model, alpha: float = 1.0): + for _module in model.modules(): + if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d", "LoraInjectedConv3d"]: + _module.scale = alpha + + +def set_lora_diag(model, diag: torch.Tensor): + for _module in model.modules(): + if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d", "LoraInjectedConv3d"]: + _module.set_selector_from_diag(diag) + + +def _text_lora_path(path: str) -> str: + assert path.endswith(".pt"), "Only .pt files are supported" + return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"]) + + +def _ti_lora_path(path: str) -> str: + assert path.endswith(".pt"), "Only .pt files are supported" + return ".".join(path.split(".")[:-1] + ["ti", "pt"]) + + +def apply_learned_embed_in_clip( + learned_embeds, + text_encoder, + tokenizer, + token: Optional[Union[str, List[str]]] = None, + idempotent=False, +): + if isinstance(token, str): + trained_tokens = [token] + elif isinstance(token, list): + assert len(learned_embeds.keys()) == len( + token + ), "The number of tokens and the number of embeds should be the same" + trained_tokens = token + else: + trained_tokens = list(learned_embeds.keys()) + + for token in trained_tokens: + print(token) + embeds = learned_embeds[token] + + # cast to dtype of text_encoder + dtype = text_encoder.get_input_embeddings().weight.dtype + num_added_tokens = tokenizer.add_tokens(token) + + i = 1 + if not idempotent: + while num_added_tokens == 0: + print(f"The tokenizer already contains the token {token}.") + token = f"{token[:-1]}-{i}>" + print(f"Attempting to add the token {token}.") + num_added_tokens = tokenizer.add_tokens(token) + i += 1 + elif num_added_tokens == 0 and idempotent: + print(f"The tokenizer already contains the token {token}.") + print(f"Replacing {token} embedding.") + + # resize the token embeddings + text_encoder.resize_token_embeddings(len(tokenizer)) + + # get the id for the token and assign the embeds + token_id = 
tokenizer.convert_tokens_to_ids(token) + text_encoder.get_input_embeddings().weight.data[token_id] = embeds + return token + + +def load_learned_embed_in_clip( + learned_embeds_path, + text_encoder, + tokenizer, + token: Optional[Union[str, List[str]]] = None, + idempotent=False, +): + learned_embeds = torch.load(learned_embeds_path) + apply_learned_embed_in_clip( + learned_embeds, text_encoder, tokenizer, token, idempotent + ) + + +def patch_pipe( + pipe, + maybe_unet_path, + token: Optional[str] = None, + r: int = 4, + patch_unet=True, + patch_text=True, + patch_ti=True, + idempotent_token=True, + unet_target_replace_module=DEFAULT_TARGET_REPLACE, + text_target_replace_module=TEXT_ENCODER_DEFAULT_TARGET_REPLACE, +): + if maybe_unet_path.endswith(".pt"): + # torch format + + if maybe_unet_path.endswith(".ti.pt"): + unet_path = maybe_unet_path[:-6] + ".pt" + elif maybe_unet_path.endswith(".text_encoder.pt"): + unet_path = maybe_unet_path[:-16] + ".pt" + else: + unet_path = maybe_unet_path + + ti_path = _ti_lora_path(unet_path) + text_path = _text_lora_path(unet_path) + + if patch_unet: + print("LoRA : Patching Unet") + monkeypatch_or_replace_lora( + pipe.unet, + torch.load(unet_path), + r=r, + target_replace_module=unet_target_replace_module, + ) + + if patch_text: + print("LoRA : Patching text encoder") + monkeypatch_or_replace_lora( + pipe.text_encoder, + torch.load(text_path), + target_replace_module=text_target_replace_module, + r=r, + ) + if patch_ti: + print("LoRA : Patching token input") + token = load_learned_embed_in_clip( + ti_path, + pipe.text_encoder, + pipe.tokenizer, + token=token, + idempotent=idempotent_token, + ) + + elif maybe_unet_path.endswith(".safetensors"): + safeloras = safe_open(maybe_unet_path, framework="pt", device="cpu") + monkeypatch_or_replace_safeloras(pipe, safeloras) + tok_dict = parse_safeloras_embeds(safeloras) + if patch_ti: + apply_learned_embed_in_clip( + tok_dict, + pipe.text_encoder, + pipe.tokenizer, + token=token, + idempotent=idempotent_token, + ) + return tok_dict + + +def train_patch_pipe(pipe, patch_unet, patch_text): + if patch_unet: + print("LoRA : Patching Unet") + collapse_lora(pipe.unet) + monkeypatch_remove_lora(pipe.unet) + + if patch_text: + print("LoRA : Patching text encoder") + + collapse_lora(pipe.text_encoder) + monkeypatch_remove_lora(pipe.text_encoder) + +@torch.no_grad() +def inspect_lora(model): + moved = {} + + for name, _module in model.named_modules(): + if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d", "LoraInjectedConv3d"]: + ups = _module.lora_up.weight.data.clone() + downs = _module.lora_down.weight.data.clone() + + wght: torch.Tensor = ups.flatten(1) @ downs.flatten(1) + + dist = wght.flatten().abs().mean().item() + if name in moved: + moved[name].append(dist) + else: + moved[name] = [dist] + + return moved + + +def save_all( + unet, + text_encoder, + save_path, + placeholder_token_ids=None, + placeholder_tokens=None, + save_lora=True, + save_ti=True, + target_replace_module_text=TEXT_ENCODER_DEFAULT_TARGET_REPLACE, + target_replace_module_unet=DEFAULT_TARGET_REPLACE, + safe_form=True, +): + if not safe_form: + # save ti + if save_ti: + ti_path = _ti_lora_path(save_path) + learned_embeds_dict = {} + for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids): + learned_embeds = text_encoder.get_input_embeddings().weight[tok_id] + print( + f"Current Learned Embeddings for {tok}:, id {tok_id} ", + learned_embeds[:4], + ) + learned_embeds_dict[tok] = learned_embeds.detach().cpu() + + 
torch.save(learned_embeds_dict, ti_path) + print("Ti saved to ", ti_path) + + # save text encoder + if save_lora: + save_lora_weight( + unet, save_path, target_replace_module=target_replace_module_unet + ) + print("Unet saved to ", save_path) + + save_lora_weight( + text_encoder, + _text_lora_path(save_path), + target_replace_module=target_replace_module_text, + ) + print("Text Encoder saved to ", _text_lora_path(save_path)) + + else: + assert save_path.endswith( + ".safetensors" + ), f"Save path : {save_path} should end with .safetensors" + + loras = {} + embeds = {} + + if save_lora: + + loras["unet"] = (unet, target_replace_module_unet) + loras["text_encoder"] = (text_encoder, target_replace_module_text) + + if save_ti: + for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids): + learned_embeds = text_encoder.get_input_embeddings().weight[tok_id] + print( + f"Current Learned Embeddings for {tok}:, id {tok_id} ", + learned_embeds[:4], + ) + embeds[tok] = learned_embeds.detach().cpu() + + save_safeloras_with_embeds(loras, embeds, save_path) \ No newline at end of file diff --git a/utils/lora_handler.py b/utils/lora_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..8c654fe5903121a434db631ec6c6bd7f1d41d036 --- /dev/null +++ b/utils/lora_handler.py @@ -0,0 +1,294 @@ +import os +import warnings +import torch +from typing import Union +from types import SimpleNamespace +from models.unet.unet_3d_condition import UNet3DConditionModel +from transformers import CLIPTextModel +from .convert_diffusers_to_original_ms_text_to_video import convert_unet_state_dict, convert_text_enc_state_dict_v20 + +from .lora import ( + extract_lora_ups_down, + inject_trainable_lora_extended, + save_lora_weight, + train_patch_pipe, + monkeypatch_or_replace_lora, + monkeypatch_or_replace_lora_extended +) + + +FILE_BASENAMES = ['unet', 'text_encoder'] +LORA_FILE_TYPES = ['.pt', '.safetensors'] +CLONE_OF_SIMO_KEYS = ['model', 'loras', 'target_replace_module', 'r'] +STABLE_LORA_KEYS = ['model', 'target_module', 'search_class', 'r', 'dropout', 'lora_bias'] + +lora_versions = dict( + stable_lora = "stable_lora", + cloneofsimo = "cloneofsimo" +) + +lora_func_types = dict( + loader = "loader", + injector = "injector" +) + +lora_args = dict( + model = None, + loras = None, + target_replace_module = [], + target_module = [], + r = 4, + search_class = [torch.nn.Linear], + dropout = 0, + lora_bias = 'none' +) + +LoraVersions = SimpleNamespace(**lora_versions) +LoraFuncTypes = SimpleNamespace(**lora_func_types) + +LORA_VERSIONS = [LoraVersions.stable_lora, LoraVersions.cloneofsimo] +LORA_FUNC_TYPES = [LoraFuncTypes.loader, LoraFuncTypes.injector] + +def filter_dict(_dict, keys=[]): + if len(keys) == 0: + raise ValueError("Keys cannot be empty when filtering the return dict.")
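+ # every requested key must be one of the known LoRA arguments defined in lora_args above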
+ + for k in keys: + if k not in lora_args.keys(): + raise ValueError(f"{k} does not exist in available LoRA arguments") + + return {k: v for k, v in _dict.items() if k in keys} + +class LoraHandler(object): + def __init__( + self, + version: LORA_VERSIONS = LoraVersions.cloneofsimo, + use_unet_lora: bool = False, + use_text_lora: bool = False, + save_for_webui: bool = False, + only_for_webui: bool = False, + lora_bias: str = 'none', + unet_replace_modules: list = None, + text_encoder_replace_modules: list = None + ): + self.version = version + self.lora_loader = self.get_lora_func(func_type=LoraFuncTypes.loader) + self.lora_injector = self.get_lora_func(func_type=LoraFuncTypes.injector) + self.lora_bias = lora_bias + self.use_unet_lora = use_unet_lora + self.use_text_lora = use_text_lora + self.save_for_webui = save_for_webui + self.only_for_webui = only_for_webui + self.unet_replace_modules = unet_replace_modules + self.text_encoder_replace_modules = text_encoder_replace_modules + self.use_lora = any([use_text_lora, use_unet_lora]) + + def is_cloneofsimo_lora(self): + return self.version == LoraVersions.cloneofsimo + + + def get_lora_func(self, func_type: LORA_FUNC_TYPES = LoraFuncTypes.loader): + + if self.is_cloneofsimo_lora(): + + if func_type == LoraFuncTypes.loader: + return monkeypatch_or_replace_lora_extended + + if func_type == LoraFuncTypes.injector: + return inject_trainable_lora_extended + + raise ValueError("LoRA Version does not exist.") + + def check_lora_ext(self, lora_file: str): + return lora_file.endswith(tuple(LORA_FILE_TYPES)) + + def get_lora_file_path( + self, + lora_path: str, + model: Union[UNet3DConditionModel, CLIPTextModel] + ): + if os.path.exists(lora_path): + lora_filenames = [fns for fns in os.listdir(lora_path)] + is_lora = self.check_lora_ext(lora_path) + + is_unet = isinstance(model, UNet3DConditionModel) + is_text = isinstance(model, CLIPTextModel) + idx = 0 if is_unet else 1 + + base_name = FILE_BASENAMES[idx] + + for lora_filename in lora_filenames: + is_lora = self.check_lora_ext(lora_filename) + if not is_lora: + continue + + if base_name in lora_filename: + return os.path.join(lora_path, lora_filename) + + return None + + def handle_lora_load(self, file_name:str, lora_loader_args: dict = None): + self.lora_loader(**lora_loader_args) + print(f"Successfully loaded LoRA from: {file_name}") + + def load_lora(self, model, lora_path: str = '', lora_loader_args: dict = None,): + try: + lora_file = self.get_lora_file_path(lora_path, model) + + if lora_file is not None: + lora_loader_args.update({"lora_path": lora_file}) + self.handle_lora_load(lora_file, lora_loader_args) + + else: + print(f"Could not load LoRAs for {model.__class__.__name__}.
Injecting new ones instead...") + + except Exception as e: + print(f"An error occurred while loading a LoRA file: {e}") + + def get_lora_func_args(self, lora_path, use_lora, model, replace_modules, r, dropout, lora_bias, scale): + return_dict = lora_args.copy() + + if self.is_cloneofsimo_lora(): + return_dict = filter_dict(return_dict, keys=CLONE_OF_SIMO_KEYS) + return_dict.update({ + "model": model, + "loras": self.get_lora_file_path(lora_path, model), + "target_replace_module": replace_modules, + "r": r, + "scale": scale, + "dropout_p": dropout, + }) + + return return_dict + + def do_lora_injection( + self, + model, + replace_modules, + bias='none', + dropout=0, + r=4, + lora_loader_args=None, + ): + REPLACE_MODULES = replace_modules + + params = None + negation = None + is_injection_hybrid = False + + if self.is_cloneofsimo_lora(): + is_injection_hybrid = True + injector_args = lora_loader_args + + params, negation = self.lora_injector(**injector_args) # inject_trainable_lora_extended + for _up, _down in extract_lora_ups_down( + model, + target_replace_module=REPLACE_MODULES): + + if all(x is not None for x in [_up, _down]): + print(f"Lora successfully injected into {model.__class__.__name__}.") + + break + + return params, negation, is_injection_hybrid + + return params, negation, is_injection_hybrid + + def add_lora_to_model(self, use_lora, model, replace_modules, dropout=0.0, lora_path='', r=16, scale=1.0): + + params = None + negation = None + + lora_loader_args = self.get_lora_func_args( + lora_path, + use_lora, + model, + replace_modules, + r, + dropout, + self.lora_bias, + scale + ) + + if use_lora: + params, negation, is_injection_hybrid = self.do_lora_injection( + model, + replace_modules, + bias=self.lora_bias, + lora_loader_args=lora_loader_args, + dropout=dropout, + r=r + ) + + if not is_injection_hybrid: + self.load_lora(model, lora_path=lora_path, lora_loader_args=lora_loader_args) + + params = model if params is None else params + return params, negation + + def save_cloneofsimo_lora(self, model, save_path, step, flag): + + def save_lora(model, name, condition, replace_modules, step, save_path, flag=None): + if condition and replace_modules is not None: + save_path = f"{save_path}/{step}_{name}.pt" + save_lora_weight(model, save_path, replace_modules, flag) + + save_lora( + model.unet, + FILE_BASENAMES[0], + self.use_unet_lora, + self.unet_replace_modules, + step, + save_path, + flag + ) + save_lora( + model.text_encoder, + FILE_BASENAMES[1], + self.use_text_lora, + self.text_encoder_replace_modules, + step, + save_path, + flag + ) + + # train_patch_pipe(model, self.use_unet_lora, self.use_text_lora) + + def save_lora_weights(self, model: None, save_path: str ='',step: str = '', flag=None): + save_path = f"{save_path}/lora" + os.makedirs(save_path, exist_ok=True) + + if self.is_cloneofsimo_lora(): + if any([self.save_for_webui, self.only_for_webui]): + warnings.warn( + """ + You have 'save_for_webui' enabled, but are using cloneofsimo's LoRA implemention. + Only 'stable_lora' is supported for saving to a compatible webui file. 
+ """ + ) + self.save_cloneofsimo_lora(model, save_path, step, flag) + + + +def inject_spatial_loras(unet, use_unet_lora, lora_unet_dropout, lora_path, lora_rank, spatial_lora_num): + + lora_managers_spatial = [] + unet_lora_params_spatial_list = [] + for i in range(spatial_lora_num): + lora_manager_spatial = LoraHandler( + use_unet_lora=use_unet_lora, + unet_replace_modules=["Transformer2DModel"] + ) + lora_managers_spatial.append(lora_manager_spatial) + + unet_lora_params_spatial, unet_negation_spatial = lora_manager_spatial.add_lora_to_model( + use_unet_lora, + unet, + lora_manager_spatial.unet_replace_modules, + lora_unet_dropout, + lora_path + '/spatial/lora/', + r=lora_rank + ) + unet_lora_params_spatial_list.append(unet_lora_params_spatial) + + return lora_managers_spatial, unet_lora_params_spatial_list, unet_negation_spatial \ No newline at end of file