import torch
import torch.nn as nn
import numpy as np
import os
from typing import List

from diffusers import StableDiffusionPipeline
from diffusers.pipelines.controlnet import MultiControlNetModel
from PIL import Image
from safetensors import safe_open
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from foleycrafter.models.adapters.resampler import Resampler
from foleycrafter.models.adapters.utils import is_torch2_available

class IPAdapter(torch.nn.Module):
    """IP-Adapter"""

    def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
        super().__init__()
        self.unet = unet
        self.image_proj_model = image_proj_model
        self.adapter_modules = adapter_modules

        if ckpt_path is not None:
            self.load_from_checkpoint(ckpt_path)

    def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
        ip_tokens = self.image_proj_model(image_embeds)
        encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
        # Predict the noise residual
        noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
        return noise_pred

    def load_from_checkpoint(self, ckpt_path: str):
        # Calculate original checksums
        orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
        orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))

        state_dict = torch.load(ckpt_path, map_location="cpu")

        # Load state dict for image_proj_model and adapter_modules
        self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
        self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)

        # Calculate new checksums
        new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
        new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))

        # Verify that the weights have changed
        assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
        assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"

        print(f"Successfully loaded weights from checkpoint {ckpt_path}")

class VideoProjModel(torch.nn.Module):
    """Projection Model"""

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=1, video_frame=50):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

        self.video_frame = video_frame

    def forward(self, image_embeds):
        embeds = image_embeds
        clip_extra_context_tokens = self.proj(embeds)
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens
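
# Note: unlike ImageProjModel below, there is no reshape here, so this module
# effectively produces one context token per frame embedding. With the default
# clip_extra_context_tokens=1, an input of shape (B, F, 1024) maps to
# (B, F, 1024); values greater than 1 would make the projection output
# incompatible with the LayerNorm over cross_attention_dim.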

class ImageProjModel(torch.nn.Module):
    """Projection Model"""

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        embeds = image_embeds
        clip_extra_context_tokens = self.proj(embeds).reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens
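
# Shape sketch (with the defaults above): a pooled CLIP image embedding of
# shape (B, 1024) is projected to (B, 4 * 1024) and reshaped to (B, 4, 1024),
# i.e. four extra context tokens per image, which IPAdapter.forward
# concatenates onto the text encoder hidden states.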

class MLPProjModel(torch.nn.Module):
    """SD model with image prompt"""

    # Helper initializers written without an explicit `self`: when invoked as
    # instance methods (see the commented-out call in __init__), the instance
    # itself is passed in as `module`.
    def zero_initialize(module):
        for param in module.parameters():
            param.data.zero_()

    def zero_initialize_last_layer(module):
        last_layer = None
        for module_name, layer in module.named_modules():
            if isinstance(layer, torch.nn.Linear):
                last_layer = layer

        if last_layer is not None:
            last_layer.weight.data.zero_()
            last_layer.bias.data.zero_()

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
        super().__init__()

        self.proj = torch.nn.Sequential(
            torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
            torch.nn.GELU(),
            torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
            torch.nn.LayerNorm(cross_attention_dim),
        )
        # zero initialize the last layer
        # self.zero_initialize_last_layer()

    def forward(self, image_embeds):
        clip_extra_context_tokens = self.proj(image_embeds)
        return clip_extra_context_tokens
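
# Compared with ImageProjModel, this MLP keeps the token count unchanged (one
# output token per input embedding) and maps each clip_embeddings_dim feature
# into the cross_attention_dim space through a single GELU MLP.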

class V2AMapperMLP(torch.nn.Module):
    def __init__(self, cross_attention_dim=512, clip_embeddings_dim=512, mult=4):
        super().__init__()

        self.proj = torch.nn.Sequential(
            torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim * mult),
            torch.nn.GELU(),
            torch.nn.Linear(clip_embeddings_dim * mult, cross_attention_dim),
            torch.nn.LayerNorm(cross_attention_dim),
        )

    def forward(self, image_embeds):
        clip_extra_context_tokens = self.proj(image_embeds)
        return clip_extra_context_tokens
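
# Presumably the video-to-audio (V2A) mapper suggested by the class name: with
# the defaults above it is a 512 -> 2048 -> 512 MLP that translates 512-d
# visual embeddings into a 512-d space consumed by the audio branch.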

class TimeProjModel(torch.nn.Module):
    def __init__(self, positive_len, out_dim, feature_type="text-only", frame_nums: int = 64):
        super().__init__()
        self.positive_len = positive_len
        self.out_dim = out_dim

        self.position_dim = frame_nums

        if isinstance(out_dim, tuple):
            out_dim = out_dim[0]

        if feature_type == "text-only":
            self.linears = nn.Sequential(
                nn.Linear(self.positive_len + self.position_dim, 512),
                nn.SiLU(),
                nn.Linear(512, 512),
                nn.SiLU(),
                nn.Linear(512, out_dim),
            )
            self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))

        elif feature_type == "text-image":
            self.linears_text = nn.Sequential(
                nn.Linear(self.positive_len + self.position_dim, 512),
                nn.SiLU(),
                nn.Linear(512, 512),
                nn.SiLU(),
                nn.Linear(512, out_dim),
            )
            self.linears_image = nn.Sequential(
                nn.Linear(self.positive_len + self.position_dim, 512),
                nn.SiLU(),
                nn.Linear(512, 512),
                nn.SiLU(),
                nn.Linear(512, out_dim),
            )
            self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
            self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))

        # self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))
    def forward(
        self,
        boxes,
        masks,
        positive_embeddings=None,
    ):
        masks = masks.unsqueeze(-1)

        # # embedding position (it may include padding as a placeholder)
        # xyxy_embedding = self.fourier_embedder(boxes)  # B*N*4 -> B*N*C
        # # learnable null embedding
        # xyxy_null = self.null_position_feature.view(1, 1, -1)
        # # replace padding with learnable null embedding
        # xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
        time_embeds = boxes

        # positionet with text-only information
        if positive_embeddings is not None:
            # learnable null embedding
            positive_null = self.null_positive_feature.view(1, 1, -1)

            # replace padding with learnable null embedding
            positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null

            objs = self.linears(torch.cat([positive_embeddings, time_embeds], dim=-1))
        # positionet with text and image information
        else:
            raise NotImplementedError

        return objs
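
# Shape sketch for TimeProjModel.forward (assuming batched inputs): `boxes`
# carries per-entity time embeddings of shape (B, N, frame_nums), `masks` is
# (B, N) with 1 for real entries and 0 for padding, and `positive_embeddings`
# is (B, N, positive_len). The concatenation feeds the
# (positive_len + frame_nums) -> out_dim MLP, so `objs` has shape (B, N, out_dim).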