import gc
import os
import yaml
import inspect
import torch
import torch.nn as nn
import numpy as np
from diffusers import DDIMScheduler
from PIL import Image
# from basicsr.utils import tensor2img
from diffusers import AutoencoderKL
from diffusers.utils.torch_utils import randn_tensor
from transformers import (
    CLIPTextModel,
    CLIPTokenizer,
    AutoTokenizer,
    CLIPVisionModelWithProjection,
    CLIPImageProcessor,
    ClapTextModelWithProjection,
    RobertaTokenizer,
    RobertaTokenizerFast,
    SpeechT5HifiGan,
)
from diffusers.utils.import_utils import is_xformers_available
from src.module.unet.unet_2d_condition import (
    CustomUNet2DConditionModel,
    UNet2DConditionModel,
)
from src.module.unet.estimator import _UNet2DConditionModel
from src.utils.inversion import DDIMInversion
from src.module.unet.attention_processor import (
    IPAttnProcessor,
    AttnProcessor,
    Resampler,
)
from src.model.sampler import Sampler
from src.utils.audio_processing import (
    extract_fbank,
    wav_to_fbank,
    TacotronSTFT,
    maybe_add_dimension,
)
import sys

sys.path.append("src/module/tango")
from tools.torch_tools import wav_to_fbank as tng_wav_to_fbank

CWD = os.getcwd()

class TangoPipeline:
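    """Audio-editing pipeline built on the pretrained Tango diffusion model.

    Downloads the Tango checkpoint from the Hugging Face Hub, swaps in a
    custom UNet plus a separate feature-estimator UNet, and wires everything
    into a `Sampler` used for DDIM-inversion-based audio editing.
    """
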
    def __init__(
        self,
        sd_id="declare-lab/tango",
        NUM_DDIM_STEPS=100,
        precision=torch.float32,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        **kwargs,
    ):
        import json
        from huggingface_hub import snapshot_download

        sys.path.append("./src/module/tango")
        from tango2.models import AudioDiffusion
        from audioldm.audio.stft import TacotronSTFT as tng_TacotronSTFT
        from audioldm.variational_autoencoder import AutoencoderKL

        path = snapshot_download(repo_id=sd_id)
        vae_config = json.load(open("{}/vae_config.json".format(path)))
        stft_config = json.load(open("{}/stft_config.json".format(path)))
        main_config = json.load(open("{}/main_config.json".format(path)))
        main_config["unet_model_config_path"] = os.path.join(
            CWD, "src/module/tango", main_config["unet_model_config_path"]
        )

        unet = self._set_unet2dconditional_model(
            CustomUNet2DConditionModel,
            unet_model_name=main_config["unet_model_name"],
            unet_model_config_path=main_config["unet_model_config_path"],
        ).to(device)
        feature_estimator = self._set_unet2dconditional_model(
            _UNet2DConditionModel,
            unet_model_name=main_config["unet_model_name"],
            unet_model_config_path=main_config["unet_model_config_path"],
        ).to(device)

        ##### Load pretrained model #####
        vae = AutoencoderKL(**vae_config).to(device)
        vae.dtype = torch.float32  # avoid a missing-attribute error
        stft = tng_TacotronSTFT(**stft_config).to(device)
        model = AudioDiffusion(**main_config).to(device)
        model.unet = unet  # replace the UNet with the custom UNet

        vae_weights = torch.load(
            "{}/pytorch_model_vae.bin".format(path), map_location=device
        )
        stft_weights = torch.load(
            "{}/pytorch_model_stft.bin".format(path), map_location=device
        )
        main_weights = torch.load(
            "{}/pytorch_model_main.bin".format(path), map_location=device
        )
        vae.load_state_dict(vae_weights)
        stft.load_state_dict(stft_weights)
        model.load_state_dict(main_weights)
        unet_weights = {
            ".".join(layer.split(".")[1:]): param
            for layer, param in model.named_parameters()
            if "unet" in layer
        }
        feature_estimator.load_state_dict(unet_weights)

        vae.eval()
        stft.eval()
        model.eval()
        feature_estimator.eval()

        # Free memory
        del vae_weights
        del stft_weights
        del main_weights
        del unet_weights

        feature_estimator.scheduler = DDIMScheduler.from_pretrained(
            main_config["scheduler_name"], subfolder="scheduler"
        )

        # Create the pipeline for audio editing
        onestep_pipe = Sampler(
            vae=vae,
            tokenizer=model.tokenizer,
            text_encoder=model.text_encoder,
            unet=model.unet,
            feature_estimator=feature_estimator,
            scheduler=DDIMScheduler.from_pretrained(
                main_config["scheduler_name"], subfolder="scheduler"
            ),
            device=device,
            precision=precision,
        )
        onestep_pipe.use_cross_attn = True
        gc.collect()
        onestep_pipe.enable_attention_slicing()
        if is_xformers_available():
            onestep_pipe.feature_estimator.enable_xformers_memory_efficient_attention()
            onestep_pipe.enable_xformers_memory_efficient_attention()

        self.pipe = onestep_pipe
        self.fn_STFT = stft
        self.vae_scale_factor = vae_config["ddconfig"]["ch_mult"][-1]
        self.NUM_DDIM_STEPS = NUM_DDIM_STEPS
        self.num_tokens = 512  # flan-t5
        self.precision = precision
        self.device = device
        # self.load_adapter()  # replace the first self-attn layers with cross-attn over the difference trajectory

    def _set_unet2dconditional_model(
        self,
        cls_obj: UNet2DConditionModel,
        *,
        unet_model_name=None,
        unet_model_config_path=None,
    ):
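        r"""Instantiate a UNet either from a local config file (randomly
        initialized) or from a pretrained checkpoint (which additionally gets
        `group_in`/`group_out` projection layers); the config path takes
        precedence when both are given."""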
        assert (
            unet_model_name is not None or unet_model_config_path is not None
        ), "Either a UNet pretrained model name or a config file path is required"
        if unet_model_config_path:
            unet_config = cls_obj.load_config(unet_model_config_path)
            unet = cls_obj.from_config(unet_config, subfolder="unet")
            unet.set_from = "random"
        else:
            unet = cls_obj.from_pretrained(unet_model_name, subfolder="unet")
            unet.set_from = "pre-trained"
            unet.group_in = nn.Sequential(nn.Linear(8, 512), nn.Linear(512, 4))
            unet.group_out = nn.Sequential(nn.Linear(4, 512), nn.Linear(512, 8))
        return unet

    def decode_latents(self, latents):
        return self.pipe.vae.decode_first_stage(latents)

    def mel_spectrogram_to_waveform(self, mel_spectrogram):
        return self.pipe.vae.decode_to_waveform(mel_spectrogram)

    def get_fbank(self, audio_or_path, stft_cfg, return_intermediate=False):
        r"""Helper function to get an fbank from an audio tensor or file path."""
        if isinstance(audio_or_path, torch.Tensor):
            return maybe_add_dimension(audio_or_path, 4)
        if isinstance(audio_or_path, str):
            fbank, log_stft, wav = tng_wav_to_fbank(
                [audio_or_path],
                fn_STFT=self.fn_STFT,
                target_length=stft_cfg.filter_length,
            )
            fbank = maybe_add_dimension(fbank, 4)  # (B, C, T, F)
            if return_intermediate:
                return fbank, log_stft, wav
            return fbank

    def encode_fbank(self, fbank):
        return self.pipe.vae.get_first_stage_encoding(
            self.pipe.vae.encode_first_stage(fbank)
        )

    def fbank2latent(self, fbank):
        latent = self.encode_fbank(fbank)
        return latent

    def ddim_inv(
        self,
        latent,
        prompt,
        emb_im=None,
        save_kv=True,
        mode="mix",
        prediction_type="v_prediction",
    ):
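        r"""Invert a latent to DDIM noise under `prompt`; defaults to
        `"v_prediction"` to match the Tango scheduler (the AudioLDM pipeline
        below defaults to `"epsilon"`)."""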
        ddim_inv = DDIMInversion(model=self.pipe, NUM_DDIM_STEPS=self.NUM_DDIM_STEPS)
        ddim_latents = ddim_inv.invert(
            ddim_latents=latent.unsqueeze(2), prompt=prompt, emb_im=emb_im,
            save_kv=save_kv, mode=mode, prediction_type=prediction_type,
        )
        return ddim_latents

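    # NOTE: `init_proj` reads `self.image_encoder`, which is never set in
    # `__init__`; it is presumably attached externally (e.g. by an adapter
    # loader) before this is called.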
    def init_proj(self, precision):
        image_proj_model = Resampler(
            dim=self.pipe.unet.config.cross_attention_dim,
            depth=4,
            dim_head=64,
            heads=12,
            num_queries=self.num_tokens,
            embedding_dim=self.image_encoder.config.hidden_size,
            output_dim=self.pipe.unet.config.cross_attention_dim,
            ff_mult=4,
        ).to("cuda", dtype=precision)
        return image_proj_model

    def load_adapter(self):
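        r"""Replace the UNet attention processors: self-attention layers
        (``attn1``) keep the vanilla ``AttnProcessor``, while all other layers
        get an ``IPAttnProcessor`` for adapter-style conditioning."""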
        scale = 1.0
        attn_procs = {}
        for name in self.pipe.unet.attn_processors.keys():
            cross_attention_dim = None
            if name.startswith("mid_block"):
                hidden_size = self.pipe.unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(self.pipe.unet.config.block_out_channels))[
                    block_id
                ]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = self.pipe.unet.config.block_out_channels[block_id]
            # Only the self-attention (attn1) layers should be used for
            # cross-attending the difference trajectory
            if name.endswith("attn1.processor"):
                attn_procs[name] = AttnProcessor()
            else:
                attn_procs[name] = IPAttnProcessor(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    scale=scale,
                    num_tokens=self.num_tokens,
                ).to("cuda", dtype=self.precision)
        self.pipe.unet.set_attn_processor(attn_procs)

class AudioLDMPipeline:
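    """Audio-editing pipeline built on pretrained AudioLDM checkpoints.

    All components (VAE, CLAP text encoder, UNet, feature estimator, vocoder)
    are loaded from the Hub via `sd_id`; editing again relies on DDIM
    inversion through the shared `Sampler`.
    """
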
    def __init__(
        self,
        sd_id="cvssp/audioldm-l-full",
        ip_id="cvssp/audioldm-l-full",
        NUM_DDIM_STEPS=50,
        precision=torch.float32,
        ip_scale=0,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    ):
        onestep_pipe = Sampler(
            vae=AutoencoderKL.from_pretrained(
                sd_id, subfolder="vae", torch_dtype=precision
            ),
            tokenizer=RobertaTokenizerFast.from_pretrained(
                sd_id, subfolder="tokenizer"
            ),
            text_encoder=ClapTextModelWithProjection.from_pretrained(
                sd_id, subfolder="text_encoder", torch_dtype=precision
            ),
            unet=CustomUNet2DConditionModel.from_pretrained(
                sd_id, subfolder="unet", torch_dtype=precision
            ),
            feature_estimator=_UNet2DConditionModel.from_pretrained(
                sd_id,
                subfolder="unet",
                vae=None,
                text_encoder=None,
                tokenizer=None,
                scheduler=DDIMScheduler.from_pretrained(sd_id, subfolder="scheduler"),
                safety_checker=None,
                feature_extractor=None,
            ),
            scheduler=DDIMScheduler.from_pretrained(sd_id, subfolder="scheduler"),
            device=device,
            precision=precision,
        )
        onestep_pipe.vocoder = SpeechT5HifiGan.from_pretrained(
            sd_id, subfolder="vocoder", torch_dtype=precision
        )
        onestep_pipe.use_cross_attn = False
        gc.collect()
        onestep_pipe = onestep_pipe.to(device)
        onestep_pipe.vocoder.to(device)
        onestep_pipe.enable_attention_slicing()
        if is_xformers_available():
            onestep_pipe.feature_estimator.enable_xformers_memory_efficient_attention()
            onestep_pipe.enable_xformers_memory_efficient_attention()

        self.pipe = onestep_pipe
        self.vae_scale_factor = 2 ** (len(self.pipe.vae.config.block_out_channels) - 1)
        self.NUM_DDIM_STEPS = NUM_DDIM_STEPS
        self.precision = precision
        self.device = device
        self.num_tokens = 64
        # These STFT settings are fixed by the pretrained model
        self.fn_STFT = TacotronSTFT(
            filter_length=1024,
            hop_length=160,
            win_length=1024,
            n_mel_channels=64,
            sampling_rate=16000,
            mel_fmin=0,
            mel_fmax=8000,
        )
        # self.load_adapter()

    def decode_latents(self, latents):
        latents = 1 / self.pipe.vae.config.scaling_factor * latents
        mel_spectrogram = self.pipe.vae.decode(latents).sample
        return mel_spectrogram

    def mel_spectrogram_to_waveform(self, mel_spectrogram):
        if mel_spectrogram.dim() == 4:
            mel_spectrogram = mel_spectrogram.squeeze(1)
        waveform = self.pipe.vocoder(
            mel_spectrogram.to(device=self.device, dtype=self.precision)
        )
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        waveform = waveform.cpu().float()
        return waveform

    def fbank2latent(self, fbank):
        latent = self.encode_fbank(fbank)
        return latent

    def get_fbank(self, audio_or_path, stft_cfg, return_intermediate=False):
        r"""Helper function to get an fbank from an audio tensor or file path."""
        if isinstance(audio_or_path, torch.Tensor):
            return maybe_add_dimension(audio_or_path, 3)
        if isinstance(audio_or_path, str):
            fbank, log_stft, wav = extract_fbank(
                audio_or_path,
                fn_STFT=self.fn_STFT,
                target_length=stft_cfg.filter_length,
                hop_size=stft_cfg.hop_length,
            )
            fbank = maybe_add_dimension(fbank, 3)  # (C, T, F)
            if return_intermediate:
                return fbank, log_stft, wav
            return fbank

    def wav2fbank(self, wav, target_length):
        fbank, log_magnitudes_stft = wav_to_fbank(wav, target_length, self.fn_STFT)
        return fbank, log_magnitudes_stft

    def encode_fbank(self, fbank):
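        r"""Encode a mel-spectrogram fbank into the VAE latent space, using the
        posterior mean rather than a sample."""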
        latent = self.pipe.vae.encode(fbank)["latent_dist"].mean
        # NOTE: scale the latent by the scheduler's initial noise sigma
        latent = latent * self.pipe.scheduler.init_noise_sigma
        return latent

    def ddim_inv(
        self,
        latent,
        prompt,
        emb_im=None,
        save_kv=True,
        mode="mix",
        prediction_type="epsilon",
    ):
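        r"""Invert a latent to DDIM noise under `prompt`; defaults to
        `"epsilon"` prediction to match the AudioLDM scheduler."""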
        ddim_inv = DDIMInversion(model=self.pipe, NUM_DDIM_STEPS=self.NUM_DDIM_STEPS)
        ddim_latents = ddim_inv.invert(
            ddim_latents=latent.unsqueeze(2), prompt=prompt, emb_im=emb_im,
            save_kv=save_kv, mode=mode, prediction_type=prediction_type,
        )
        return ddim_latents

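    # NOTE: as in `TangoPipeline`, `init_proj` assumes `self.image_encoder`
    # has been attached externally.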
    def init_proj(self, precision):
        image_proj_model = Resampler(
            dim=self.pipe.unet.config.cross_attention_dim,
            depth=4,
            dim_head=64,
            heads=12,
            num_queries=self.num_tokens,
            embedding_dim=self.image_encoder.config.hidden_size,
            output_dim=self.pipe.unet.config.cross_attention_dim,
            ff_mult=4,
        ).to("cuda", dtype=precision)
        return image_proj_model

    # @torch.inference_mode()
    # def get_image_embeds(self, pil_image):
    #     if isinstance(pil_image, Image.Image):
    #         pil_image = [pil_image]
    #     clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
    #     clip_image = clip_image.to('cuda', dtype=self.precision)
    #     clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
    #     image_prompt_embeds = self.image_proj_model(clip_image_embeds)
    #     uncond_clip_image_embeds = self.image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[-2].detach()
    #     uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds).detach()
    #     return image_prompt_embeds, uncond_image_prompt_embeds

    def load_adapter(self):
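        r"""Same attention-processor replacement scheme as
        `TangoPipeline.load_adapter`."""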
        scale = 1.0
        attn_procs = {}
        for name in self.pipe.unet.attn_processors.keys():
            cross_attention_dim = None
            if name.startswith("mid_block"):
                hidden_size = self.pipe.unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(self.pipe.unet.config.block_out_channels))[
                    block_id
                ]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = self.pipe.unet.config.block_out_channels[block_id]
            # Only the self-attention (attn1) layers should be used for
            # cross-attending the difference trajectory
            if name.endswith("attn1.processor"):
                attn_procs[name] = AttnProcessor()
            else:
                attn_procs[name] = IPAttnProcessor(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    scale=scale,
                    num_tokens=self.num_tokens,
                ).to("cuda", dtype=self.precision)
        self.pipe.unet.set_attn_processor(attn_procs)

    # def load_adapter(self, model_path, scale=1.0):
    #     from src.unet.attention_processor import IPAttnProcessor, AttnProcessor, Resampler
    #     attn_procs = {}
    #     for name in self.pipe.unet.attn_processors.keys():
    #         cross_attention_dim = None if name.endswith("attn1.processor") else self.pipe.unet.config.cross_attention_dim
    #         if name.startswith("mid_block"):
    #             hidden_size = self.pipe.unet.config.block_out_channels[-1]
    #         elif name.startswith("up_blocks"):
    #             block_id = int(name[len("up_blocks.")])
    #             hidden_size = list(reversed(self.pipe.unet.config.block_out_channels))[block_id]
    #         elif name.startswith("down_blocks"):
    #             block_id = int(name[len("down_blocks.")])
    #             hidden_size = self.pipe.unet.config.block_out_channels[block_id]
    #         if cross_attention_dim is None:
    #             attn_procs[name] = AttnProcessor()
    #         else:
    #             attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,
    #                                                scale=scale, num_tokens=self.num_tokens).to('cuda', dtype=self.precision)
    #     self.pipe.unet.set_attn_processor(attn_procs)
    #     state_dict = torch.load(model_path, map_location="cpu")
    #     self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
    #     ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
    #     ip_layers.load_state_dict(state_dict["ip_adapter"], strict=True)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in the DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(
            inspect.signature(self.pipe.scheduler.step).parameters.keys()
        )
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta
        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(
            inspect.signature(self.pipe.scheduler.step).parameters.keys()
        )
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with width->self.vocoder.config.model_in_dim
    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        dtype,
        device,
        generator,
        latents=None,
    ):
        shape = (
            batch_size,
            num_channels_latents,
            height // self.vae_scale_factor,
            self.pipe.vocoder.config.model_in_dim // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )
        if latents is None:
            latents = randn_tensor(
                shape, generator=generator, device=device, dtype=dtype
            )
        else:
            latents = latents.to(device)
        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.pipe.scheduler.init_noise_sigma
        return latents

if __name__ == "__main__":
    # pipeline = AudioLDMPipeline(
    #     sd_id="cvssp/audioldm-l-full", ip_id="cvssp/audioldm-l-full", NUM_DDIM_STEPS=50
    # )
    pipeline = TangoPipeline(
        sd_id="declare-lab/tango",
        ip_id="declare-lab/tango",
        NUM_DDIM_STEPS=50,
        precision=torch.float16,
    )
    print(pipeline.__dict__)
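
    # A minimal editing sketch beyond the smoke test above (hypothetical usage:
    # "sample.wav" and the SimpleNamespace STFT config are placeholders;
    # `get_fbank` only needs `.filter_length`/`.hop_length` on `stft_cfg`):
    #
    #     from types import SimpleNamespace
    #     stft_cfg = SimpleNamespace(filter_length=1024, hop_length=160)
    #     fbank = pipeline.get_fbank("sample.wav", stft_cfg)
    #     latent = pipeline.fbank2latent(fbank.to(pipeline.device))
    #     inv_latents = pipeline.ddim_inv(latent, prompt="a dog barking")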