import sys

sys.path.append("src")
import torch
import logging
import torch.nn as nn
from qa_mdt.audioldm_train.modules.clap.open_clip import create_model
from qa_mdt.audioldm_train.modules.clap.training.data import get_audio_features
import torchaudio
from transformers import (
    RobertaTokenizer,
    AutoTokenizer,
    T5EncoderModel,
    MT5EncoderModel,
)
import torch.nn.functional as F
from qa_mdt.audioldm_train.modules.audiomae.AudioMAE import Vanilla_AudioMAE
from qa_mdt.audioldm_train.modules.phoneme_encoder.encoder import TextEncoder
from transformers import SpeechT5Processor, AutoTokenizer, GPT2Model, GPT2Tokenizer
from transformers.models.speecht5.modeling_speecht5 import SpeechT5EncoderWithTextPrenet
from qa_mdt.audioldm_train.modules.audiomae.sequence_gen.model import CLAP2AudioMAE
from qa_mdt.audioldm_train.modules.audiomae.sequence_gen.sequence_input import (
    Sequence2AudioMAE,
)
import numpy as np
from qa_mdt.audioldm_train.modules.audiomae.sequence_gen.model import Prenet
import json

# Checkpoint locations read by conditioners in this module
# (e.g. the "flan_t5" and "roberta-base" entries).
with open(
    "./qa_mdt/offset_pretrained_checkpoints.json", "r", encoding="utf-8"
) as config_file:
    config_data = json.load(config_file)

"""
The model forward function can return three types of data:
1. tensor: used directly as the conditioning signal.
2. dict: one main key holds the condition; other keys can carry loss terms,
   intermediate results, etc.
3. list: of length 2, where the first element is a tensor and the second
   element is its attention mask.

The output shape for the cross-attention condition should be:
x, x_mask = [bs, seq_len, emb_dim], [bs, seq_len]

All returned data that is used as diffusion input needs to be of float type.
"""


class GPT2WordEmbedding(nn.Module):
    """Embed text with GPT-2's token-embedding table (``wte``) only.

    ``forward(list_of_str)`` returns ``[embeddings, attention_mask]`` for the
    batch padded to its longest entry.
    """

    def __init__(self):
        super().__init__()
        self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        # GPT-2 has no pad token; reuse EOS so batched padding works.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = GPT2Model.from_pretrained("gpt2").wte
        self.device = None

    def get_unconditional_condition(self, batchsize):
        """Embed the literal prompt "random" ``batchsize`` times."""
        unconditional_condition = ["random"] * batchsize
        return self(unconditional_condition)

    def forward(self, text):
        """Tokenize ``text`` (a list of str) and look up its embeddings.

        NOTE(review): the attention mask is returned with the tokenizer's
        integer dtype, not float like the other conditioners — confirm.
        """
        assert isinstance(text, list)
        if self.device is None:
            self.device = next(self.model.parameters()).device
        tokenization_result = self.tokenizer(text, return_tensors="pt", padding=True)
        input_ids = tokenization_result["input_ids"].to(self.device)
        attn_mask = tokenization_result["attention_mask"].to(self.device)
        input_embed = self.model(input_ids.long())
        return [input_embed, attn_mask]


class ConcateBandWidthCond(nn.Module):
    """Pass-through conditioner for a bandwidth map concatenated to the latent."""

    def __init__(self, latent_t_size, latent_f_size):
        super().__init__()
        # Unused by forward(); presumably a placeholder so the module owns a
        # parameter. Also used below as a device fallback.
        self.placeholder = nn.Linear(1, 1)
        self.latent_t_size = latent_t_size
        self.latent_f_size = latent_f_size
        self.device = None

    def get_unconditional_condition(self, batchsize):
        """Return an all-zero ``[batchsize, latent_t_size, latent_f_size]`` tensor."""
        # Before forward() has run, self.device is None; fall back to the
        # module's own parameter device instead of silently returning CPU.
        device = (
            self.device if self.device is not None else self.placeholder.weight.device
        )
        return torch.zeros(
            (batchsize, self.latent_t_size, self.latent_f_size)
        ).to(device)

    def forward(self, mel_spec_bandwidth_cond_extra_channel):
        """Return the input unchanged, remembering its device."""
        if self.device is None:
            self.device = mel_spec_bandwidth_cond_extra_channel.device
        return mel_spec_bandwidth_cond_extra_channel


class BandwidthEncoder(nn.Module):
    """Embed a (lower, higher) cutoff pair into a 256-d vector.

    Each cutoff is an integer index into a shared 1000-entry embedding table,
    followed by a shared 128->128 linear layer; the two results are
    concatenated along the last dimension.
    """

    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(1000, 128)
        nn.init.normal_(self.emb.weight, 0.0, 128**-0.5)
        self.linear_bandwidth = nn.Linear(128, 128)
        self.unconditional_condition = torch.zeros((1, 256))
        self.device = None

    def get_unconditional_condition(self, batchsize):
        """Return an all-zero ``[batchsize, 256]`` tensor on the module's device."""
        # Plain tensor attribute (not a buffer), so move it explicitly: this
        # works even before forward() has run.
        device = next(self.linear_bandwidth.parameters()).device
        return self.unconditional_condition.to(device).expand(batchsize, 256)

    def forward(self, bandwidth):
        """Embed ``bandwidth[..., 0]`` (lower) and ``bandwidth[..., 1]`` (higher).

        Cutoffs are cast to ``long`` and must lie in ``[0, 1000)``.
        """
        if self.device is None:
            self.device = next(self.linear_bandwidth.parameters()).device
            self.unconditional_condition = self.unconditional_condition.to(self.device)

        # freq_energy_percentile
        lower_cutoff, higher_cutoff = bandwidth[..., 0], bandwidth[..., 1]
        lower_cutoff_emb = self.linear_bandwidth(self.emb(lower_cutoff.long()))
        higher_cutoff_emb = self.linear_bandwidth(self.emb(higher_cutoff.long()))
        # [bs, 256] when bandwidth is [bs, 2] (assumed caller shape — confirm)
        cutoff_emb = torch.cat([lower_cutoff_emb, higher_cutoff_emb], dim=-1)
        return cutoff_emb


class SpeechT5TextEncoder(nn.Module):
    """Frozen SpeechT5 (``microsoft/speecht5_tts``) text-encoder hidden states."""

    def __init__(self):
        super().__init__()
        self.processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
        self.model = SpeechT5EncoderWithTextPrenet.from_pretrained(
            "microsoft/speecht5_tts"
        )
        for p in self.model.parameters():
            p.requires_grad = False
        self.model.eval()

    # Required
    def get_unconditional_condition(self, batchsize):
        """Return one all-zero 768-d token per item with an all-ones mask."""
        device = self.model.device
        hidden_state = torch.zeros((batchsize, 1, 768)).to(device)
        attention_mask = torch.ones((batchsize, 1)).to(device)
        return [hidden_state.float(), attention_mask.float()]

    def forward(self, text):
        """Encode ``text`` without gradients; returns ``[hidden, mask]`` as float."""
        with torch.no_grad():
            device = self.model.device
            inputs = self.processor(text=text, return_tensors="pt", padding=True)
            input_ids = inputs["input_ids"].to(device)
            attention_mask = inputs["attention_mask"].to(device)
            emb = self.model(input_ids, attention_mask)
            emb = emb.last_hidden_state.detach()
            return [emb.float(), attention_mask.float()]


class PhonemeEncoder(nn.Module):
    """Encode fixed-length phoneme-id sequences with a ``TextEncoder``.

    Inputs must be exactly ``pad_length`` long (the learnable positional
    embedding is broadcast-added over that axis); positions equal to
    ``pad_token_id`` count as padding when computing sequence lengths.

    Example::

        encoder = PhonemeEncoder(vocabs_size=41, pad_length=250, pad_token_id=0)
        data = torch.randint(1, 41, (2, 250))
        emb, mask = encoder(data)  # [2, 250, 192], [2, 250]
    """

    def __init__(self, vocabs_size=41, pad_length=250, pad_token_id=None):
        super().__init__()
        assert pad_token_id is not None
        self.device = None
        self.PAD_LENGTH = int(pad_length)
        self.pad_token_id = pad_token_id
        # Full-length all-pad sequence, used as the unconditional input.
        self.pad_token_sequence = torch.tensor([self.pad_token_id] * self.PAD_LENGTH)

        self.text_encoder = TextEncoder(
            n_vocab=vocabs_size,
            out_channels=192,
            hidden_channels=192,
            filter_channels=768,
            n_heads=2,
            n_layers=6,
            kernel_size=3,
            p_dropout=0.1,
        )

        # Shape [1, 192, pad_length]; added to the encoder output before the
        # final permute.
        self.learnable_positional_embedding = torch.nn.Parameter(
            torch.zeros((1, 192, self.PAD_LENGTH))
        )
        self.learnable_positional_embedding.requires_grad = True

    # Required
    def get_unconditional_condition(self, batchsize):
        """Encode an all-pad batch; the result follows the module's device."""
        # Move the pad sequence to the module's device *before* building the
        # batch: forward() only relocates the attribute, not this local tensor,
        # so a fresh module on GPU would otherwise receive CPU indices.
        device = self.learnable_positional_embedding.device
        unconditional_tokens = self.pad_token_sequence.to(device).expand(
            batchsize, self.PAD_LENGTH
        )
        return self(unconditional_tokens)  # Need to return float type

    def _get_src_mask(self, phoneme):
        """True where ``phoneme`` is not the pad token."""
        src_mask = phoneme != self.pad_token_id
        return src_mask

    def _get_src_length(self, phoneme):
        """Number of non-pad tokens per sequence (sum over the last axis)."""
        src_mask = self._get_src_mask(phoneme)
        length = torch.sum(src_mask, dim=-1)
        return length

    def forward(self, phoneme_idx):
        """Return ``[text_emb, text_emb_mask]``; ``text_emb`` is ``[bs, pad_length, 192]``."""
        if self.device is None:
            self.device = self.learnable_positional_embedding.device
            self.pad_token_sequence = self.pad_token_sequence.to(self.device)

        src_length = self._get_src_length(phoneme_idx)
        text_emb, m, logs, text_emb_mask = self.text_encoder(phoneme_idx, src_length)
        text_emb = text_emb + self.learnable_positional_embedding

        return [
            text_emb.permute(0, 2, 1),
            text_emb_mask.squeeze(1),
        ]  # [2, 250, 192], [2, 250] for the class docstring example
# [2, 250, 192], [2, 250] class FlanT5HiddenState(nn.Module): """ llama = FlanT5HiddenState() data = ["","this is not an empty sentence"] encoder_hidden_states = llama(data) import ipdb;ipdb.set_trace() """ def __init__( self, text_encoder_name=config_data['flan_t5'], freeze_text_encoder=True ): super().__init__() self.freeze_text_encoder = freeze_text_encoder ## MODIFIED self.tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large") self.model = T5EncoderModel.from_pretrained("google/flan-t5-large") if freeze_text_encoder: self.model.eval() for p in self.model.parameters(): p.requires_grad = False else: print("=> The text encoder is learnable") self.empty_hidden_state_cfg = None self.device = None # Required def get_unconditional_condition(self, batchsize): param = next(self.model.parameters()) if self.freeze_text_encoder: assert param.requires_grad == False # device = param.device if self.empty_hidden_state_cfg is None: self.empty_hidden_state_cfg, _ = self([""]) hidden_state = torch.cat([self.empty_hidden_state_cfg] * batchsize).float() attention_mask = ( torch.ones((batchsize, hidden_state.size(1))) .to(hidden_state.device) .float() ) return [hidden_state, attention_mask] # Need to return float type def forward(self, batch): param = next(self.model.parameters()) if self.freeze_text_encoder: assert param.requires_grad == False if self.device is None: self.device = param.device # print("Manually change text") # for i in range(len(batch)): # batch[i] = "dog barking" try: return self.encode_text(batch) except Exception as e: print(e, batch) logging.exception("An error occurred: %s", str(e)) def encode_text(self, prompt): device = self.model.device batch = self.tokenizer( prompt, max_length=128, # self.tokenizer.model_max_length padding=True, truncation=True, return_tensors="pt", ) input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to( device ) # Get text encoding if self.freeze_text_encoder: with torch.no_grad(): 
encoder_hidden_states = self.model( input_ids=input_ids, attention_mask=attention_mask )[0] else: encoder_hidden_states = self.model( input_ids=input_ids, attention_mask=attention_mask )[0] return [ encoder_hidden_states.detach(), attention_mask.float(), ] # Attention mask == 1 means usable token class FlanT5HiddenStatePaddedSameLength(nn.Module): """ llama = FlanT5HiddenState() data = ["","this is not an empty sentence"] encoder_hidden_states = llama(data) import ipdb;ipdb.set_trace() """ def __init__( self, text_encoder_name="google/flan-t5-large", freeze_text_encoder=True ): super().__init__() self.freeze_text_encoder = freeze_text_encoder self.tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large") self.model = T5EncoderModel.from_pretrained("google/flan-t5-large") if freeze_text_encoder: self.model.eval() for p in self.model.parameters(): p.requires_grad = False else: print("=> The text encoder is learnable") self.empty_hidden_state_cfg = None self.device = None # Required def get_unconditional_condition(self, batchsize): param = next(self.model.parameters()) if self.freeze_text_encoder: assert param.requires_grad == False # device = param.device if self.empty_hidden_state_cfg is None: self.empty_hidden_state_cfg, _ = self([""]) hidden_state = torch.cat([self.empty_hidden_state_cfg] * batchsize).float() attention_mask = ( torch.ones((batchsize, hidden_state.size(1))) .to(hidden_state.device) .float() ) return [hidden_state, attention_mask] # Need to return float type def forward(self, batch): param = next(self.model.parameters()) if self.freeze_text_encoder: assert param.requires_grad == False if self.device is None: self.device = param.device # print("Manually change text") # for i in range(len(batch)): # batch[i] = "dog barking" try: text_embed = self.encode_text(batch) return text_embed except Exception as e: print(e, batch) logging.exception("An error occurred: %s", str(e)) def encode_text(self, prompt): device = self.model.device batch = 
self.tokenizer( prompt, max_length=128, padding="max_length", truncation=True, return_tensors="pt", ) input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to( device ) # Get text encoding if self.freeze_text_encoder: with torch.no_grad(): encoder_hidden_states = self.model( input_ids=input_ids, attention_mask=attention_mask )[0] else: encoder_hidden_states = self.model( input_ids=input_ids, attention_mask=attention_mask )[0] return [ encoder_hidden_states.detach(), attention_mask.float(), ] # Attention mask == 1 means usable token class CLAPGenAudioMAECond(CLAP2AudioMAE): def __init__( self, cond_stage_config, learnable=True, pretrained_path=None, use_gt_mae_output=None, # False: does not use AudioMAE GT, True: Use AudioMAE GT use_gt_mae_prob=None, ): # The prob of using AudioMAE GT super().__init__(base_learning_rate=1e-5, cond_stage_config=cond_stage_config) assert use_gt_mae_output is not None and use_gt_mae_prob is not None if pretrained_path is not None: print("Reload CLAPGenAudioMAECond from %s" % pretrained_path) state_dict = torch.load(pretrained_path)["state_dict"] self.load_state_dict(state_dict) self.use_gt_mae_output = use_gt_mae_output self.use_gt_mae_prob = use_gt_mae_prob self.learnable = learnable if not learnable: # Only optimize the GPT2 model for p in self.model.parameters(): p.requires_grad = False self.eval() # Required def get_unconditional_condition(self, batchsize): return_dict = self.cfg_uncond(batchsize) return return_dict def forward(self, batch): # The conditional module can return both tensor or dictionaries # The returned tensor will be corresponding to the cond_stage_key # The returned dict will have keys that correspond to the cond_stage_key ret_dict = {} if self.use_gt_mae_output and torch.rand(1).item() < self.use_gt_mae_prob: cond_dict = self.get_input(batch) # Used as condition ret_dict["crossattn_clap_to_audiomae_feature"] = [ cond_dict["crossattn_audiomae_pooled"][0], 
torch.ones_like(cond_dict["crossattn_audiomae_pooled"][1]).float(), ] # Input sequence and mask else: # Used as condition input_embeds, cond_dict = self.generate(batch) input_embeds_mask = ( torch.ones((input_embeds.size(0), input_embeds.size(1))) .to(input_embeds.device) .float() ) ret_dict["crossattn_clap_to_audiomae_feature"] = [ input_embeds, input_embeds_mask, ] # Input sequence and mask # If the following two keys are not in cond_stage_key, then they will not be used as condition ret_dict["film_clap_cond1"] = cond_dict[ "film_clap_cond1" ] # the clap target latent ret_dict["crossattn_audiomae_pooled"] = cond_dict[ "crossattn_audiomae_pooled" ] # audiomae target latent if self.learnable and self.training: loss = self.training_step(batch, cond_dict=cond_dict) ret_dict["noncond_loss_clap2audiomae"] = loss return ret_dict class SequenceGenAudioMAECond(Sequence2AudioMAE): def __init__( self, cond_stage_config, base_learning_rate, sequence_gen_length, sequence_input_key, sequence_input_embed_dim, batchsize, always_output_audiomae_gt=False, pretrained_path=None, force_reload_pretrain_avoid_overwrite=False, learnable=True, use_warmup=True, use_gt_mae_output=None, # False: does not use AudioMAE GT, True: Use AudioMAE GT use_gt_mae_prob=None, ): # The prob of using AudioMAE GT if use_warmup: print( "Warning: You didn't initialize sequence prediction module with trainer. Set warmup to False. You can still use the warmup scheme from the latent diffusion model." 
) use_warmup = False super().__init__( base_learning_rate=base_learning_rate, cond_stage_config=cond_stage_config, sequence_gen_length=sequence_gen_length, sequence_input_key=sequence_input_key, use_warmup=use_warmup, sequence_input_embed_dim=sequence_input_embed_dim, batchsize=batchsize, ) assert use_gt_mae_output is not None and use_gt_mae_prob is not None self.always_output_audiomae_gt = always_output_audiomae_gt self.force_reload_pretrain_avoid_overwrite = ( force_reload_pretrain_avoid_overwrite ) self.pretrained_path = pretrained_path if self.force_reload_pretrain_avoid_overwrite: self.is_reload = False else: self.is_reload = True self.load_pretrain_model() self.use_gt_mae_output = use_gt_mae_output self.use_gt_mae_prob = use_gt_mae_prob self.learnable = learnable if not learnable: # Only optimize the GPT2 model for p in self.model.parameters(): p.requires_grad = False self.eval() def load_pretrain_model(self): if self.pretrained_path is not None: print("Reload SequenceGenAudioMAECond from %s" % self.pretrained_path) state_dict = torch.load(self.pretrained_path)["state_dict"] self.load_state_dict(state_dict) # Required def get_unconditional_condition(self, batchsize): return_dict = self.cfg_uncond(batchsize) return_dict["crossattn_audiomae_generated"] = [ return_dict["crossattn_audiomae_pooled"][0], torch.ones_like(return_dict["crossattn_audiomae_pooled"][1]).float(), ] return return_dict def forward(self, batch): # The conditional module can return both tensor or dictionaries # The returned tensor will be corresponding to the cond_stage_key # The returned dict will have keys that correspond to the cond_stage_key ret_dict = {} if self.force_reload_pretrain_avoid_overwrite and not self.is_reload: self.load_pretrain_model() self.is_reload = True self.check_module_param_update() if self.always_output_audiomae_gt or ( self.use_gt_mae_output and torch.rand(1).item() < self.use_gt_mae_prob ): cond_dict = self.get_input(batch) ret_dict["crossattn_audiomae_generated"] 
= [ cond_dict["crossattn_audiomae_pooled"][0], torch.ones_like(cond_dict["crossattn_audiomae_pooled"][1]).float(), ] # Input sequence and mask # _, output = self.training_step(batch, cond_dict=cond_dict, return_output=True) # ret_dict["crossattn_audiomae_generated"] = [output, torch.ones_like(cond_dict["crossattn_audiomae_pooled"][1]).float()] # Input sequence and mask else: if not self.training: print("--------------> Generate !!!!!!!!!!!!") input_embeds, cond_dict = self.generate(batch) # print("Generate Partial!!!!"); input_embeds, cond_dict = self.generate_partial(batch) input_embeds_mask = ( torch.ones((input_embeds.size(0), input_embeds.size(1))) .to(input_embeds.device) .float() ) ret_dict["crossattn_audiomae_generated"] = [ input_embeds, input_embeds_mask, ] # Input sequence and mask # If the following two keys are not in cond_stage_key, then they will not be used as condition for key in cond_dict.keys(): ret_dict[key] = cond_dict[key] if self.learnable and self.training: loss = self.training_step(batch, cond_dict=cond_dict) ret_dict["noncond_loss_clap2audiomae"] = loss return ret_dict class SequenceGenAudioMAECond_AudioMAE_PostNet(Sequence2AudioMAE): def __init__( self, cond_stage_config, base_learning_rate, sequence_gen_length, sequence_input_key, sequence_input_embed_dim, batchsize, always_output_audiomae_gt=False, pretrained_path=None, use_ar_gen_loss=False, force_reload_pretrain_avoid_overwrite=False, learnable=True, use_warmup=True, use_gt_mae_output=None, # False: does not use AudioMAE GT, True: Use AudioMAE GT use_gt_mae_prob=None, ): # The prob of using AudioMAE GT if use_warmup: print( "Warning: You didn't initialize sequence prediction module with trainer. Set warmup to False. You can still use the warmup scheme from the latent diffusion model." 
) use_warmup = False super().__init__( base_learning_rate=base_learning_rate, cond_stage_config=cond_stage_config, sequence_gen_length=sequence_gen_length, sequence_input_key=sequence_input_key, use_ar_gen_loss=use_ar_gen_loss, use_warmup=use_warmup, sequence_input_embed_dim=sequence_input_embed_dim, batchsize=batchsize, ) assert use_gt_mae_output is not None and use_gt_mae_prob is not None self.always_output_audiomae_gt = always_output_audiomae_gt self.force_reload_pretrain_avoid_overwrite = ( force_reload_pretrain_avoid_overwrite ) self.pretrained_path = pretrained_path if self.force_reload_pretrain_avoid_overwrite: self.is_reload = False else: self.is_reload = True self.load_pretrain_model() self.prenet = Prenet(in_dim=768, sizes=[768, 768, 768], dropout_rate=0.5) self.use_gt_mae_output = use_gt_mae_output self.use_gt_mae_prob = use_gt_mae_prob self.learnable = learnable if not learnable: # Only optimize the GPT2 model for p in self.model.parameters(): p.requires_grad = False self.eval() def load_pretrain_model(self): if self.pretrained_path is not None: print("Reload SequenceGenAudioMAECond from %s" % self.pretrained_path) state_dict = torch.load(self.pretrained_path)["state_dict"] self.load_state_dict(state_dict) # Required def get_unconditional_condition(self, batchsize): return_dict = self.cfg_uncond(batchsize) return_dict["crossattn_audiomae_generated"] = [ return_dict["crossattn_audiomae_pooled"][0], torch.ones_like(return_dict["crossattn_audiomae_pooled"][1]).float(), ] return return_dict def forward(self, batch): # The conditional module can return both tensor or dictionaries # The returned tensor will be corresponding to the cond_stage_key # The returned dict will have keys that correspond to the cond_stage_key ret_dict = {} if self.force_reload_pretrain_avoid_overwrite and not self.is_reload: self.load_pretrain_model() self.is_reload = True self.check_module_param_update() if self.always_output_audiomae_gt or ( self.use_gt_mae_output and 
torch.rand(1).item() < self.use_gt_mae_prob ): cond_dict = self.get_input(batch) gt_audiomae = self.prenet(cond_dict["crossattn_audiomae_pooled"][0]) ret_dict["crossattn_audiomae_generated"] = [ gt_audiomae, torch.ones_like(cond_dict["crossattn_audiomae_pooled"][1]).float(), ] # Input sequence and mask else: print("--------------> Generate!!!!!!!!!!!!") input_embeds, cond_dict = self.generate(batch) # input_embeds, cond_dict = self.generate_partial(batch) input_embeds = self.prenet(input_embeds) input_embeds_mask = ( torch.ones((input_embeds.size(0), input_embeds.size(1))) .to(input_embeds.device) .float() ) ret_dict["crossattn_audiomae_generated"] = [ input_embeds, input_embeds_mask, ] # Input sequence and mask # If the following two keys are not in cond_stage_key, then they will not be used as condition for key in cond_dict.keys(): ret_dict[key] = cond_dict[key] if self.learnable and self.training: loss = self.training_step(batch, cond_dict=cond_dict) ret_dict["noncond_loss_clap2audiomae"] = loss return ret_dict class AudioMAEConditionCTPoolRandTFSeparated(nn.Module): """ audiomae = AudioMAEConditionCTPool2x2() data = torch.randn((4, 1024, 128)) output = audiomae(data) import ipdb;ipdb.set_trace() exit(0) """ def __init__( self, time_pooling_factors=[1, 2, 4, 8], freq_pooling_factors=[1, 2, 4, 8], eval_time_pooling=None, eval_freq_pooling=None, mask_ratio=0.0, regularization=False, no_audiomae_mask=True, no_audiomae_average=False, ): super().__init__() self.device = None self.time_pooling_factors = time_pooling_factors self.freq_pooling_factors = freq_pooling_factors self.no_audiomae_mask = no_audiomae_mask self.no_audiomae_average = no_audiomae_average self.eval_freq_pooling = eval_freq_pooling self.eval_time_pooling = eval_time_pooling self.mask_ratio = mask_ratio self.use_reg = regularization self.audiomae = Vanilla_AudioMAE() self.audiomae.eval() for p in self.audiomae.parameters(): p.requires_grad = False # Required def get_unconditional_condition(self, 
batchsize): param = next(self.audiomae.parameters()) assert param.requires_grad == False device = param.device # time_pool, freq_pool = max(self.time_pooling_factors), max(self.freq_pooling_factors) time_pool, freq_pool = min(self.eval_time_pooling, 64), min( self.eval_freq_pooling, 8 ) # time_pool = self.time_pooling_factors[np.random.choice(list(range(len(self.time_pooling_factors))))] # freq_pool = self.freq_pooling_factors[np.random.choice(list(range(len(self.freq_pooling_factors))))] token_num = int(512 / (time_pool * freq_pool)) return [ torch.zeros((batchsize, token_num, 768)).to(device).float(), torch.ones((batchsize, token_num)).to(device).float(), ] def pool(self, representation, time_pool=None, freq_pool=None): assert representation.size(-1) == 768 representation = representation[:, 1:, :].transpose(1, 2) bs, embedding_dim, token_num = representation.size() representation = representation.reshape(bs, embedding_dim, 64, 8) if self.training: if time_pool is None and freq_pool is None: time_pool = min( 64, self.time_pooling_factors[ np.random.choice(list(range(len(self.time_pooling_factors)))) ], ) freq_pool = min( 8, self.freq_pooling_factors[ np.random.choice(list(range(len(self.freq_pooling_factors)))) ], ) # freq_pool = min(8, time_pool) # TODO here I make some modification. 
else: time_pool, freq_pool = min(self.eval_time_pooling, 64), min( self.eval_freq_pooling, 8 ) self.avgpooling = nn.AvgPool2d( kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool) ) self.maxpooling = nn.MaxPool2d( kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool) ) pooled = ( self.avgpooling(representation) + self.maxpooling(representation) ) / 2 # [bs, embedding_dim, time_token_num, freq_token_num] pooled = pooled.flatten(2).transpose(1, 2) return pooled # [bs, token_num, embedding_dim] def regularization(self, x): assert x.size(-1) == 768 x = F.normalize(x, p=2, dim=-1) return x # Required def forward(self, batch, time_pool=None, freq_pool=None): assert batch.size(-2) == 1024 and batch.size(-1) == 128 if self.device is None: self.device = batch.device batch = batch.unsqueeze(1) with torch.no_grad(): representation = self.audiomae( batch, mask_ratio=self.mask_ratio, no_mask=self.no_audiomae_mask, no_average=self.no_audiomae_average, ) representation = self.pool(representation, time_pool, freq_pool) if self.use_reg: representation = self.regularization(representation) return [ representation, torch.ones((representation.size(0), representation.size(1))) .to(representation.device) .float(), ] class AudioMAEConditionCTPoolRand(nn.Module): """ audiomae = AudioMAEConditionCTPool2x2() data = torch.randn((4, 1024, 128)) output = audiomae(data) import ipdb;ipdb.set_trace() exit(0) """ def __init__( self, time_pooling_factors=[1, 2, 4, 8], freq_pooling_factors=[1, 2, 4, 8], eval_time_pooling=None, eval_freq_pooling=None, mask_ratio=0.0, regularization=False, no_audiomae_mask=True, no_audiomae_average=False, ): super().__init__() self.device = None self.time_pooling_factors = time_pooling_factors self.freq_pooling_factors = freq_pooling_factors self.no_audiomae_mask = no_audiomae_mask self.no_audiomae_average = no_audiomae_average self.eval_freq_pooling = eval_freq_pooling self.eval_time_pooling = eval_time_pooling self.mask_ratio = mask_ratio 
self.use_reg = regularization self.audiomae = Vanilla_AudioMAE() self.audiomae.eval() for p in self.audiomae.parameters(): p.requires_grad = False # Required def get_unconditional_condition(self, batchsize): param = next(self.audiomae.parameters()) assert param.requires_grad == False device = param.device # time_pool, freq_pool = max(self.time_pooling_factors), max(self.freq_pooling_factors) time_pool, freq_pool = min(self.eval_time_pooling, 64), min( self.eval_freq_pooling, 8 ) # time_pool = self.time_pooling_factors[np.random.choice(list(range(len(self.time_pooling_factors))))] # freq_pool = self.freq_pooling_factors[np.random.choice(list(range(len(self.freq_pooling_factors))))] token_num = int(512 / (time_pool * freq_pool)) return [ torch.zeros((batchsize, token_num, 768)).to(device).float(), torch.ones((batchsize, token_num)).to(device).float(), ] def pool(self, representation, time_pool=None, freq_pool=None): assert representation.size(-1) == 768 representation = representation[:, 1:, :].transpose(1, 2) bs, embedding_dim, token_num = representation.size() representation = representation.reshape(bs, embedding_dim, 64, 8) if self.training: if time_pool is None and freq_pool is None: time_pool = min( 64, self.time_pooling_factors[ np.random.choice(list(range(len(self.time_pooling_factors)))) ], ) # freq_pool = self.freq_pooling_factors[np.random.choice(list(range(len(self.freq_pooling_factors))))] freq_pool = min(8, time_pool) # TODO here I make some modification. 
else: time_pool, freq_pool = min(self.eval_time_pooling, 64), min( self.eval_freq_pooling, 8 ) self.avgpooling = nn.AvgPool2d( kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool) ) self.maxpooling = nn.MaxPool2d( kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool) ) pooled = ( self.avgpooling(representation) + self.maxpooling(representation) ) / 2 # [bs, embedding_dim, time_token_num, freq_token_num] pooled = pooled.flatten(2).transpose(1, 2) return pooled # [bs, token_num, embedding_dim] def regularization(self, x): assert x.size(-1) == 768 x = F.normalize(x, p=2, dim=-1) return x # Required def forward(self, batch, time_pool=None, freq_pool=None): assert batch.size(-2) == 1024 and batch.size(-1) == 128 if self.device is None: self.device = batch.device batch = batch.unsqueeze(1) with torch.no_grad(): representation = self.audiomae( batch, mask_ratio=self.mask_ratio, no_mask=self.no_audiomae_mask, no_average=self.no_audiomae_average, ) representation = self.pool(representation, time_pool, freq_pool) if self.use_reg: representation = self.regularization(representation) return [ representation, torch.ones((representation.size(0), representation.size(1))) .to(representation.device) .float(), ] class ConditionalToken(nn.Module): def __init__(self, embedding_dim): super(ConditionalToken, self).__init__() self.embedding_dim = embedding_dim # Define the conditional tokens as fixed values self.pooling_factor_tokens = { 1: torch.Tensor([1.0, 0.0] * (embedding_dim // 2)), 2: torch.Tensor([0.0, 1.0] * (embedding_dim // 2)), 4: torch.Tensor([1.0, 1.0] * (embedding_dim // 2)), 8: torch.Tensor([-1.0, 0.0] * (embedding_dim // 2)), 16: torch.Tensor([0.0, -1.0] * (embedding_dim // 2)), 32: torch.Tensor([-1.0, -1.0] * (embedding_dim // 2)), 64: torch.Tensor([0.0, 0.0] * (embedding_dim // 2)), } for p in self.parameters(): p.requires_grad = False def forward(self, condition, batchsize): """ Returns the conditional token for the given condition. 
""" if condition not in self.pooling_factor_tokens.keys(): raise ValueError(f"Unsupported condition: {condition}") batched_token = self.pooling_factor_tokens[condition][None, None].expand( batchsize, 1, self.embedding_dim ) return batched_token class AudioMAEConditionCTPoolRandV2(nn.Module): """ audiomae = AudioMAEConditionCTPool2x2() data = torch.randn((4, 1024, 128)) output = audiomae(data) import ipdb;ipdb.set_trace() exit(0) """ def __init__( self, time_pooling_factors=[1, 2, 4, 8], freq_pooling_factors=[1, 2, 4, 8], eval_time_pooling=None, eval_freq_pooling=None, mask_ratio=0.0, regularization=False, no_audiomae_mask=True, no_audiomae_average=False, ): super().__init__() self.device = None self.time_pooling_factors = time_pooling_factors self.freq_pooling_factors = freq_pooling_factors self.no_audiomae_mask = no_audiomae_mask self.no_audiomae_average = no_audiomae_average self.eval_freq_pooling = eval_freq_pooling self.eval_time_pooling = eval_time_pooling self.mask_ratio = mask_ratio self.use_reg = regularization self.pooling_tokens = ConditionalToken(768) self.audiomae = Vanilla_AudioMAE() self.audiomae.eval() for p in self.audiomae.parameters(): p.requires_grad = False # Required def get_unconditional_condition(self, batchsize): param = next(self.audiomae.parameters()) assert param.requires_grad == False device = param.device # time_pool, freq_pool = max(self.time_pooling_factors), max(self.freq_pooling_factors) time_pool, freq_pool = min(self.eval_time_pooling, 64), min( self.eval_freq_pooling, 8 ) # time_pool = self.time_pooling_factors[np.random.choice(list(range(len(self.time_pooling_factors))))] # freq_pool = self.freq_pooling_factors[np.random.choice(list(range(len(self.freq_pooling_factors))))] pool_condition_token = self.pooling_tokens(time_pool, batchsize).to(device) token_num = int(512 / (time_pool * freq_pool)) rep = torch.zeros((batchsize, token_num, 768)).to(device).float() rep = torch.cat([rep, pool_condition_token], dim=1) return [rep, 
torch.ones((batchsize, token_num + 1)).to(device).float()] def pool(self, representation, time_pool=None, freq_pool=None): assert representation.size(-1) == 768 representation = representation[:, 1:, :].transpose(1, 2) bs, embedding_dim, token_num = representation.size() representation = representation.reshape(bs, embedding_dim, 64, 8) if self.training: if time_pool is None and freq_pool is None: time_pool = min( 64, self.time_pooling_factors[ np.random.choice(list(range(len(self.time_pooling_factors)))) ], ) # freq_pool = self.freq_pooling_factors[np.random.choice(list(range(len(self.freq_pooling_factors))))] freq_pool = min(8, time_pool) # TODO here I make some modification. else: time_pool, freq_pool = min(self.eval_time_pooling, 64), min( self.eval_freq_pooling, 8 ) self.avgpooling = nn.AvgPool2d( kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool) ) self.maxpooling = nn.MaxPool2d( kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool) ) pooled = ( self.avgpooling(representation) + self.maxpooling(representation) ) / 2 # [bs, embedding_dim, time_token_num, freq_token_num] pooled = pooled.flatten(2).transpose(1, 2) return pooled, time_pool, freq_pool # [bs, token_num, embedding_dim] def regularization(self, x): assert x.size(-1) == 768 x = F.normalize(x, p=2, dim=-1) return x # Required def forward(self, batch): assert batch.size(-2) == 1024 and batch.size(-1) == 128 if self.device is None: self.device = batch.device batch = batch.unsqueeze(1) with torch.no_grad(): representation = self.audiomae( batch, mask_ratio=self.mask_ratio, no_mask=self.no_audiomae_mask, no_average=self.no_audiomae_average, ) representation, time_pool, freq_pool = self.pool(representation) if self.use_reg: representation = self.regularization(representation) pool_condition_token = self.pooling_tokens( time_pool, representation.size(0) ).to(representation.device) representation = torch.cat([representation, pool_condition_token], dim=1) return [ representation, 
torch.ones((representation.size(0), representation.size(1))) .to(representation.device) .float(), ] class BeatDownbeatConditionConcat(nn.Module): def __init__(self, latent_t_size, latent_f_size): super().__init__() self.latent_t_size = latent_t_size self.latent_f_size = latent_f_size self.device = None # Required def get_unconditional_condition(self, batchsize): return torch.zeros((batchsize, self.latent_t_size, self.latent_f_size)).to( self.device ) # Required def forward(self, batch): if self.device is None: self.device = batch.device return batch class CLAPAudioEmbeddingClassifierFreev2(nn.Module): def __init__( self, pretrained_path, sampling_rate=16000, embed_mode="audio", amodel="HTSAT-base", unconditional_prob=0.1, random_mute=False, max_random_mute_portion=0.5, training_mode=True, ): super().__init__() self.device = "cpu" self.precision = "fp32" self.amodel = amodel # or 'PANN-14' self.tmodel = "roberta" # the best text encoder in our training self.enable_fusion = False # False if you do not want to use the fusion model self.fusion_type = "aff_2d" self.pretrained = pretrained_path self.embed_mode = embed_mode self.embed_mode_orig = embed_mode self.sampling_rate = sampling_rate self.unconditional_prob = unconditional_prob self.random_mute = random_mute self.tokenize = RobertaTokenizer.from_pretrained(config_data["roberta-base"]) self.max_random_mute_portion = max_random_mute_portion self.training_mode = training_mode self.model, self.model_cfg = create_model( self.amodel, self.tmodel, self.pretrained, precision=self.precision, device=self.device, enable_fusion=self.enable_fusion, fusion_type=self.fusion_type, ) audio_cfg = self.model_cfg["audio_cfg"] self.mel_transform = torchaudio.transforms.MelSpectrogram( sample_rate=audio_cfg["sample_rate"], n_fft=audio_cfg["window_size"], win_length=audio_cfg["window_size"], hop_length=audio_cfg["hop_size"], center=True, pad_mode="reflect", power=2.0, norm=None, onesided=True, n_mels=64, f_min=audio_cfg["fmin"], 
f_max=audio_cfg["fmax"], ) for p in self.model.parameters(): p.requires_grad = False self.unconditional_token = None self.model.eval() def get_unconditional_condition(self, batchsize): self.unconditional_token = self.model.get_text_embedding( self.tokenizer(["", ""]) )[0:1] return torch.cat([self.unconditional_token.unsqueeze(0)] * batchsize, dim=0) def batch_to_list(self, batch): ret = [] for i in range(batch.size(0)): ret.append(batch[i]) return ret def make_decision(self, probability): if float(torch.rand(1)) < probability: return True else: return False def random_uniform(self, start, end): val = torch.rand(1).item() return start + (end - start) * val def _random_mute(self, waveform): # waveform: [bs, t-steps] t_steps = waveform.size(-1) for i in range(waveform.size(0)): mute_size = int( self.random_uniform(0, end=int(t_steps * self.max_random_mute_portion)) ) mute_start = int(self.random_uniform(0, t_steps - mute_size)) waveform[i, mute_start : mute_start + mute_size] = 0 return waveform def cos_similarity(self, waveform, text): # waveform: [bs, t_steps] original_embed_mode = self.embed_mode with torch.no_grad(): self.embed_mode = "audio" audio_emb = self(waveform.cuda()) self.embed_mode = "text" text_emb = self(text) similarity = F.cosine_similarity(audio_emb, text_emb, dim=2) self.embed_mode = original_embed_mode return similarity.squeeze() def build_unconditional_emb(self): self.unconditional_token = self.model.get_text_embedding( self.tokenizer(["", ""]) )[0:1] def forward(self, batch): # If you want this conditioner to be unconditional, set self.unconditional_prob = 1.0 # If you want this conditioner to be fully conditional, set self.unconditional_prob = 0.0 if self.model.training == True and not self.training_mode: print( "The pretrained CLAP model should always be in eval mode. Reloading model just in case you change the parameters." 
) self.model, self.model_cfg = create_model( self.amodel, self.tmodel, self.pretrained, precision=self.precision, device="cuda", enable_fusion=self.enable_fusion, fusion_type=self.fusion_type, ) for p in self.model.parameters(): p.requires_grad = False self.model.eval() if self.unconditional_token is None: self.build_unconditional_emb() # if(self.training_mode): # assert self.model.training == True # else: # assert self.model.training == False # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode if self.embed_mode == "audio": if not self.training: print("INFO: clap model calculate the audio embedding as condition") with torch.no_grad(): # assert ( # self.sampling_rate == 16000 # ), "We only support 16000 sampling rate" # if self.random_mute: # batch = self._random_mute(batch) # batch: [bs, 1, t-samples] if self.sampling_rate != 48000: batch = torchaudio.functional.resample( batch, orig_freq=self.sampling_rate, new_freq=48000 ) audio_data = batch.squeeze(1) mel = self.mel_transform(audio_data) audio_dict = get_audio_features( audio_data, mel, 480000, data_truncating="fusion", data_filling="repeatpad", audio_cfg=self.model_cfg["audio_cfg"], ) # [bs, 512] embed = self.model.get_audio_embedding(audio_dict) elif self.embed_mode == "text": with torch.no_grad(): # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode text_data = self.tokenizer(batch) if isinstance(batch, str) or ( isinstance(batch, list) and len(batch) == 1 ): for key in text_data.keys(): text_data[key] = text_data[key].unsqueeze(0) embed = self.model.get_text_embedding(text_data) embed = embed.unsqueeze(1) for i in range(embed.size(0)): if self.make_decision(self.unconditional_prob): embed[i] = self.unconditional_token # embed = torch.randn((batch.size(0), 1, 512)).type_as(batch) return embed.detach() def tokenizer(self, text): result = self.tokenize( text, padding="max_length", truncation=True, max_length=512, return_tensors="pt", ) return {k: 
v.squeeze(0) for k, v in result.items()} if __name__ == "__main__": model = CLAPAudioEmbeddingClassifierFreev2( pretrained_path="/mnt/bn/lqhaoheliu/exps/checkpoints/audioldm/ckpt/CLAP.pt", embed_mode="text", amodel="HTSAT-tiny", ) # data = torch.randn((6, 1, int(16000*10.24))) data = ["text", "text"] res = model(data) import ipdb ipdb.set_trace()