|
import sys |
|
|
|
sys.path.append("src") |
|
import torch |
|
import logging |
|
import torch.nn as nn |
|
from qa_mdt.audioldm_train.modules.clap.open_clip import create_model |
|
from qa_mdt.audioldm_train.modules.clap.training.data import get_audio_features |
|
|
|
import torchaudio |
|
from transformers import ( |
|
RobertaTokenizer, |
|
AutoTokenizer, |
|
T5EncoderModel, |
|
MT5EncoderModel, |
|
) |
|
import torch.nn.functional as F |
|
from qa_mdt.audioldm_train.modules.audiomae.AudioMAE import Vanilla_AudioMAE |
|
from qa_mdt.audioldm_train.modules.phoneme_encoder.encoder import TextEncoder |
|
|
|
from transformers import SpeechT5Processor, AutoTokenizer, GPT2Model, GPT2Tokenizer |
|
from transformers.models.speecht5.modeling_speecht5 import SpeechT5EncoderWithTextPrenet |
|
|
|
from qa_mdt.audioldm_train.modules.audiomae.sequence_gen.model import CLAP2AudioMAE |
|
from qa_mdt.audioldm_train.modules.audiomae.sequence_gen.sequence_input import ( |
|
Sequence2AudioMAE, |
|
) |
|
import numpy as np |
|
from qa_mdt.audioldm_train.modules.audiomae.sequence_gen.model import Prenet |
|
import json |
|
# Pretrained-checkpoint configuration, read once at import time. The path is
# relative to the current working directory, so importing this module from a
# directory that does not contain ``qa_mdt/`` raises FileNotFoundError.
# ``config_data['flan_t5']`` is used as a default constructor argument below.
# JSON is UTF-8 by specification, so decode it explicitly rather than relying
# on the platform's locale encoding.
with open('./qa_mdt/offset_pretrained_checkpoints.json', 'r', encoding='utf-8') as config_file:
    config_data = json.load(config_file)
|
|
|
""" |
|
The model forward function can return three types of data: |
|
1. tensor: used directly as conditioning signal |
|
2. dict: a main key holds the condition; other keys can carry extra outputs such as loss terms or intermediate results.

3. list: of length 2, where the first element is the condition tensor and the second is its attention mask.
|
|
|
The output shape for the cross attention condition should be: |
|
x,x_mask = [bs, seq_len, emb_dim], [bs, seq_len] |
|
|
|
All returned data that will be used as diffusion input must be of float type.
|
""" |
|
|
|
|
|
class GPT2WordEmbedding(nn.Module):
    """Look up GPT-2 word-token embeddings (``wte``) for a list of strings.

    ``forward`` returns ``[embeddings, attention_mask]`` for the padded batch,
    both on the device of the embedding table.
    """

    def __init__(self):
        super().__init__()

        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        # GPT-2 ships without a pad token; reuse EOS so batches can be padded.
        tokenizer.pad_token = tokenizer.eos_token
        self.tokenizer = tokenizer
        # Only the word-token embedding table of the GPT-2 model is kept.
        self.model = GPT2Model.from_pretrained("gpt2").wte
        # Resolved lazily from the embedding weights on the first forward call.
        self.device = None

    def get_unconditional_condition(self, batchsize):
        """Embed the placeholder prompt ``"random"`` once per batch element."""
        return self(["random"] * batchsize)

    def forward(self, text):
        """Tokenize ``text`` (a list of strings) and return ``[embeddings, mask]``."""
        assert isinstance(text, list)
        if self.device is None:
            self.device = next(self.model.parameters()).device

        encoded = self.tokenizer(text, return_tensors="pt", padding=True)
        token_ids = encoded["input_ids"].to(self.device)
        mask = encoded["attention_mask"].to(self.device)
        return [self.model(token_ids.long()), mask]
|
|
|
|
|
class ConcateBandWidthCond(nn.Module):
    """Pass-through conditioning for a precomputed bandwidth channel.

    ``forward`` returns its input unchanged. ``get_unconditional_condition``
    returns zeros of shape ``(batchsize, latent_t_size, latent_f_size)``.
    """

    def __init__(self, latent_t_size, latent_f_size):
        super().__init__()
        # Not used by forward; presumably kept so the module owns a parameter.
        # Its weight also tells us which device the module lives on.
        self.placeholder = nn.Linear(1, 1)
        self.latent_t_size = latent_t_size
        self.latent_f_size = latent_f_size
        # Device of the first input seen by forward (None until then).
        self.device = None

    def get_unconditional_condition(self, batchsize):
        # Before the first forward ``self.device`` is None, which would leave
        # the zeros on the CPU; fall back to the module's own device instead.
        device = (
            self.device if self.device is not None else self.placeholder.weight.device
        )
        return torch.zeros((batchsize, self.latent_t_size, self.latent_f_size)).to(
            device
        )

    def forward(self, mel_spec_bandwidth_cond_extra_channel):
        if self.device is None:
            self.device = mel_spec_bandwidth_cond_extra_channel.device

        return mel_spec_bandwidth_cond_extra_channel
|
|
|
|
|
class BandwidthEncoder(nn.Module):
    """Embed a (lower, higher) cutoff pair into a 256-dim vector.

    Each cutoff is cast to ``long`` and looked up in a shared 1000-entry
    embedding table. The lookup is projected by a shared 128->128 linear layer,
    and the two 128-dim results are concatenated.
    """

    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(1000, 128)
        nn.init.normal_(self.emb.weight, 0.0, 128**-0.5)
        self.linear_bandwidth = nn.Linear(128, 128)
        # Plain tensor attribute (not a registered buffer), so ``module.to()``
        # does not move it; it is moved to the parameters' device on use.
        self.unconditional_condition = torch.zeros((1, 256))
        self.device = None

    def get_unconditional_condition(self, batchsize):
        """Return zeros of shape ``(batchsize, 256)`` on the module's device."""
        # May be called before the first forward, so align the cached tensor
        # with the parameters here rather than relying on forward to do it.
        device = next(self.linear_bandwidth.parameters()).device
        if self.unconditional_condition.device != device:
            self.unconditional_condition = self.unconditional_condition.to(device)
        return self.unconditional_condition.expand(batchsize, 256)

    def forward(self, bandwidth):
        """Embed ``bandwidth[..., 0]`` (lower) and ``bandwidth[..., 1]`` (higher).

        Cutoff values are used as embedding indices, so after the ``long`` cast
        they must lie in ``[0, 1000)``. Returns a tensor whose last dimension
        is 256.
        """
        if self.device is None:
            self.device = next(self.linear_bandwidth.parameters()).device
            self.unconditional_condition = self.unconditional_condition.to(self.device)

        lower_cutoff, higher_cutoff = bandwidth[..., 0], bandwidth[..., 1]

        lower_cutoff_emb = self.linear_bandwidth(self.emb(lower_cutoff.long()))
        higher_cutoff_emb = self.linear_bandwidth(self.emb(higher_cutoff.long()))
        cutoff_emb = torch.cat([lower_cutoff_emb, higher_cutoff_emb], dim=-1)

        return cutoff_emb
|
|
|
|
|
class SpeechT5TextEncoder(nn.Module):
    """Frozen SpeechT5 text encoder (``microsoft/speecht5_tts``).

    ``forward(text)`` returns ``[last_hidden_state, attention_mask]`` as float
    tensors. The unconditional condition is a single all-zero 768-dim token
    with a mask of ones.
    """

    def __init__(self):
        super().__init__()
        checkpoint = "microsoft/speecht5_tts"
        self.processor = SpeechT5Processor.from_pretrained(checkpoint)
        self.model = SpeechT5EncoderWithTextPrenet.from_pretrained(checkpoint)
        # The encoder is never trained here.
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.eval()

    def get_unconditional_condition(self, batchsize):
        device = self.model.device
        hidden_state = torch.zeros((batchsize, 1, 768), device=device).float()
        attention_mask = torch.ones((batchsize, 1), device=device).float()
        return [hidden_state, attention_mask]

    def forward(self, text):
        with torch.no_grad():
            device = self.model.device
            encoded = self.processor(text=text, return_tensors="pt", padding=True)
            input_ids = encoded["input_ids"].to(device)
            attention_mask = encoded["attention_mask"].to(device)
            hidden = self.model(input_ids, attention_mask).last_hidden_state.detach()
        return [hidden.float(), attention_mask.float()]
|
|
|
|
|
class PhonemeEncoder(nn.Module):
    """Encode padded phoneme-index sequences with a ``TextEncoder``.

    A learnable positional embedding of shape ``(1, 192, pad_length)`` is
    added to the encoder output, so inputs are expected to be padded to
    ``pad_length`` with ``pad_token_id``.

    Example::

        encoder = PhonemeEncoder(vocabs_size=41, pad_length=250, pad_token_id=0)
        emb, mask = encoder(torch.randint(1, 41, (2, 250)))
    """

    def __init__(self, vocabs_size=41, pad_length=250, pad_token_id=None):
        super().__init__()
        assert pad_token_id is not None

        self.device = None
        self.PAD_LENGTH = int(pad_length)
        self.pad_token_id = pad_token_id
        # Plain tensor (not a buffer): it is moved to the model device on use.
        self.pad_token_sequence = torch.tensor([self.pad_token_id] * self.PAD_LENGTH)

        self.text_encoder = TextEncoder(
            n_vocab=vocabs_size,
            out_channels=192,
            hidden_channels=192,
            filter_channels=768,
            n_heads=2,
            n_layers=6,
            kernel_size=3,
            p_dropout=0.1,
        )

        self.learnable_positional_embedding = torch.nn.Parameter(
            torch.zeros((1, 192, self.PAD_LENGTH))
        )
        self.learnable_positional_embedding.requires_grad = True

    def get_unconditional_condition(self, batchsize):
        """Encode a batch made entirely of padding tokens."""
        # forward() only moves ``pad_token_sequence`` on its first call; this
        # method may run before that, so align it with the model device here.
        device = self.learnable_positional_embedding.device
        if self.pad_token_sequence.device != device:
            self.pad_token_sequence = self.pad_token_sequence.to(device)
        unconditional_tokens = self.pad_token_sequence.expand(
            batchsize, self.PAD_LENGTH
        )
        return self(unconditional_tokens)

    def _get_src_mask(self, phoneme):
        """Boolean mask that is True at non-padding positions."""
        src_mask = phoneme != self.pad_token_id
        return src_mask

    def _get_src_length(self, phoneme):
        """Number of non-padding tokens in each sequence (sum over the last dim)."""
        src_mask = self._get_src_mask(phoneme)
        length = torch.sum(src_mask, dim=-1)
        return length

    def forward(self, phoneme_idx):
        """Return ``[embedding, mask]`` for a batch of phoneme indices.

        ``embedding`` is the encoder output (plus positional embedding) with
        its last two axes swapped. ``mask`` is the encoder mask with dim 1
        squeezed. Assumes the encoder output is ``(bs, 192, pad_length)`` so
        the positional embedding broadcasts (TODO confirm against TextEncoder).
        """
        if self.device is None:
            self.device = self.learnable_positional_embedding.device
            self.pad_token_sequence = self.pad_token_sequence.to(self.device)

        src_length = self._get_src_length(phoneme_idx)
        text_emb, _m, _logs, text_emb_mask = self.text_encoder(phoneme_idx, src_length)
        text_emb = text_emb + self.learnable_positional_embedding

        return [
            text_emb.permute(0, 2, 1),
            text_emb_mask.squeeze(1),
        ]
|
|
|
|
|
class FlanT5HiddenState(nn.Module):
    """Text conditioning from the FLAN-T5-large encoder's last hidden state.

    ``forward(prompts)`` tokenizes a list of strings and returns
    ``[hidden_states, attention_mask]``, with the mask cast to float. Prompts
    are padded to the longest prompt in the batch and truncated to 128 tokens.

    Example::

        encoder = FlanT5HiddenState()
        hidden, mask = encoder(["", "this is not an empty sentence"])
    """

    def __init__(
        self, text_encoder_name=config_data['flan_t5'], freeze_text_encoder=True
    ):
        super().__init__()
        self.freeze_text_encoder = freeze_text_encoder

        # NOTE(review): ``text_encoder_name`` is accepted but unused; the
        # tokenizer and encoder are always loaded from "google/flan-t5-large".
        self.tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
        self.model = T5EncoderModel.from_pretrained("google/flan-t5-large")
        if freeze_text_encoder:
            self.model.eval()
            for p in self.model.parameters():
                p.requires_grad = False
        else:
            print("=> The text encoder is learnable")

        # Lazily computed embedding of the empty prompt.
        self.empty_hidden_state_cfg = None
        self.device = None

    def get_unconditional_condition(self, batchsize):
        """Return the empty-prompt embedding repeated ``batchsize`` times and an all-ones mask."""
        param = next(self.model.parameters())
        if self.freeze_text_encoder:
            assert param.requires_grad == False

        if self.empty_hidden_state_cfg is None:
            empty_hidden_state, _ = self([""])
            # Cache a graph-free copy: reusing a tensor that still references an
            # earlier autograd graph would break later backward passes when the
            # encoder is learnable.
            self.empty_hidden_state_cfg = empty_hidden_state.detach()

        hidden_state = torch.cat([self.empty_hidden_state_cfg] * batchsize).float()
        attention_mask = (
            torch.ones((batchsize, hidden_state.size(1)))
            .to(hidden_state.device)
            .float()
        )
        return [hidden_state, attention_mask]

    def forward(self, batch):
        """Encode a list of prompts.

        If encoding raises, the error and the batch are printed and logged,
        and ``None`` is returned.
        """
        param = next(self.model.parameters())
        if self.freeze_text_encoder:
            assert param.requires_grad == False

        if self.device is None:
            self.device = param.device

        try:
            return self.encode_text(batch)
        except Exception as e:
            print(e, batch)
            logging.exception("An error occurred: %s", str(e))
            return None

    def encode_text(self, prompt):
        """Tokenize ``prompt`` and run the T5 encoder; returns ``[hidden_states, float_mask]``."""
        device = self.model.device
        batch = self.tokenizer(
            prompt,
            max_length=128,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(
            device
        )

        if self.freeze_text_encoder:
            with torch.no_grad():
                encoder_hidden_states = self.model(
                    input_ids=input_ids, attention_mask=attention_mask
                )[0]
            encoder_hidden_states = encoder_hidden_states.detach()
        else:
            # Do not detach: a learnable encoder must receive gradients.
            encoder_hidden_states = self.model(
                input_ids=input_ids, attention_mask=attention_mask
            )[0]
        return [
            encoder_hidden_states,
            attention_mask.float(),
        ]
|
|
|
|
|
class FlanT5HiddenStatePaddedSameLength(nn.Module):
    """FLAN-T5-large text conditioning with every prompt padded to 128 tokens.

    Same interface as the batch-padded FLAN-T5 encoder, but the tokenizer uses
    ``padding="max_length"`` so outputs always have sequence length 128.

    Example::

        encoder = FlanT5HiddenStatePaddedSameLength()
        hidden, mask = encoder(["", "this is not an empty sentence"])
    """

    def __init__(
        self, text_encoder_name="google/flan-t5-large", freeze_text_encoder=True
    ):
        super().__init__()
        self.freeze_text_encoder = freeze_text_encoder
        # NOTE(review): ``text_encoder_name`` is accepted but unused; the
        # tokenizer and encoder are always loaded from "google/flan-t5-large".
        self.tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
        self.model = T5EncoderModel.from_pretrained("google/flan-t5-large")
        if freeze_text_encoder:
            self.model.eval()
            for p in self.model.parameters():
                p.requires_grad = False
        else:
            print("=> The text encoder is learnable")

        # Lazily computed embedding of the empty prompt.
        self.empty_hidden_state_cfg = None
        self.device = None

    def get_unconditional_condition(self, batchsize):
        """Return the empty-prompt embedding repeated ``batchsize`` times and an all-ones mask."""
        param = next(self.model.parameters())
        if self.freeze_text_encoder:
            assert param.requires_grad == False

        if self.empty_hidden_state_cfg is None:
            empty_hidden_state, _ = self([""])
            # Cache a graph-free copy: reusing a tensor that still references an
            # earlier autograd graph would break later backward passes when the
            # encoder is learnable.
            self.empty_hidden_state_cfg = empty_hidden_state.detach()

        hidden_state = torch.cat([self.empty_hidden_state_cfg] * batchsize).float()
        attention_mask = (
            torch.ones((batchsize, hidden_state.size(1)))
            .to(hidden_state.device)
            .float()
        )
        return [hidden_state, attention_mask]

    def forward(self, batch):
        """Encode a list of prompts.

        If encoding raises, the error and the batch are printed and logged,
        and ``None`` is returned.
        """
        param = next(self.model.parameters())
        if self.freeze_text_encoder:
            assert param.requires_grad == False

        if self.device is None:
            self.device = param.device

        try:
            text_embed = self.encode_text(batch)
            return text_embed
        except Exception as e:
            print(e, batch)
            logging.exception("An error occurred: %s", str(e))
            return None

    def encode_text(self, prompt):
        """Tokenize ``prompt`` to exactly 128 tokens and run the T5 encoder."""
        device = self.model.device
        batch = self.tokenizer(
            prompt,
            max_length=128,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(
            device
        )

        if self.freeze_text_encoder:
            with torch.no_grad():
                encoder_hidden_states = self.model(
                    input_ids=input_ids, attention_mask=attention_mask
                )[0]
            encoder_hidden_states = encoder_hidden_states.detach()
        else:
            # Do not detach: a learnable encoder must receive gradients.
            encoder_hidden_states = self.model(
                input_ids=input_ids, attention_mask=attention_mask
            )[0]
        return [
            encoder_hidden_states,
            attention_mask.float(),
        ]
|
|
|
|
|
class CLAPGenAudioMAECond(CLAP2AudioMAE):
    """Conditioning wrapper around ``CLAP2AudioMAE``.

    ``forward`` returns a dict with:

    * ``crossattn_clap_to_audiomae_feature``: ``[features, mask]``. The features
      are the ground-truth pooled AudioMAE features with probability
      ``use_gt_mae_prob`` (when ``use_gt_mae_output`` is truthy); otherwise
      they are the output of ``self.generate(batch)``. The mask is all ones
      in both cases.
    * ``film_clap_cond1`` and ``crossattn_audiomae_pooled``, copied from the
      parent's condition dict.
    * ``noncond_loss_clap2audiomae``, only when learnable and in training mode.
    """

    def __init__(
        self,
        cond_stage_config,
        learnable=True,
        pretrained_path=None,
        use_gt_mae_output=None,
        use_gt_mae_prob=None,
    ):
        super().__init__(base_learning_rate=1e-5, cond_stage_config=cond_stage_config)
        assert use_gt_mae_output is not None and use_gt_mae_prob is not None

        if pretrained_path is not None:
            print("Reload CLAPGenAudioMAECond from %s" % pretrained_path)
            # Deserialize onto the CPU so checkpoints saved on a GPU load on any
            # machine; load_state_dict copies the values into the existing
            # parameters wherever they live.
            state_dict = torch.load(pretrained_path, map_location="cpu")["state_dict"]
            self.load_state_dict(state_dict)

        self.use_gt_mae_output = use_gt_mae_output
        self.use_gt_mae_prob = use_gt_mae_prob
        self.learnable = learnable

        if not learnable:
            # ``self.model`` is provided by the parent class.
            for p in self.model.parameters():
                p.requires_grad = False
            self.eval()

    def get_unconditional_condition(self, batchsize):
        return_dict = self.cfg_uncond(batchsize)
        return return_dict

    def forward(self, batch):
        ret_dict = {}
        if self.use_gt_mae_output and torch.rand(1).item() < self.use_gt_mae_prob:
            cond_dict = self.get_input(batch)
            # Use the ground-truth pooled AudioMAE features directly.
            ret_dict["crossattn_clap_to_audiomae_feature"] = [
                cond_dict["crossattn_audiomae_pooled"][0],
                torch.ones_like(cond_dict["crossattn_audiomae_pooled"][1]).float(),
            ]
        else:
            input_embeds, cond_dict = self.generate(batch)
            input_embeds_mask = (
                torch.ones((input_embeds.size(0), input_embeds.size(1)))
                .to(input_embeds.device)
                .float()
            )
            ret_dict["crossattn_clap_to_audiomae_feature"] = [
                input_embeds,
                input_embeds_mask,
            ]

        ret_dict["film_clap_cond1"] = cond_dict["film_clap_cond1"]
        ret_dict["crossattn_audiomae_pooled"] = cond_dict["crossattn_audiomae_pooled"]

        if self.learnable and self.training:
            loss = self.training_step(batch, cond_dict=cond_dict)
            ret_dict["noncond_loss_clap2audiomae"] = loss

        return ret_dict
|
|
|
|
|
class SequenceGenAudioMAECond(Sequence2AudioMAE):
    """Conditioning wrapper around ``Sequence2AudioMAE``.

    ``forward`` returns the parent's condition dict extended with:

    * ``crossattn_audiomae_generated``: ``[features, all-ones mask]``. The
      features are the ground-truth pooled AudioMAE features when
      ``always_output_audiomae_gt`` is set, or with probability
      ``use_gt_mae_prob`` when ``use_gt_mae_output`` is truthy. Otherwise
      they are the output of ``self.generate(batch)``.
    * ``noncond_loss_clap2audiomae``, only when learnable and in training mode.
    """

    def __init__(
        self,
        cond_stage_config,
        base_learning_rate,
        sequence_gen_length,
        sequence_input_key,
        sequence_input_embed_dim,
        batchsize,
        always_output_audiomae_gt=False,
        pretrained_path=None,
        force_reload_pretrain_avoid_overwrite=False,
        learnable=True,
        use_warmup=True,
        use_gt_mae_output=None,
        use_gt_mae_prob=None,
    ):
        # Warmup is always disabled for this standalone module (see message).
        if use_warmup:
            print(
                "Warning: You didn't initialize sequence prediction module with trainer. Set warmup to False. You can still use the warmup scheme from the latent diffusion model."
            )
            use_warmup = False

        super().__init__(
            base_learning_rate=base_learning_rate,
            cond_stage_config=cond_stage_config,
            sequence_gen_length=sequence_gen_length,
            sequence_input_key=sequence_input_key,
            use_warmup=use_warmup,
            sequence_input_embed_dim=sequence_input_embed_dim,
            batchsize=batchsize,
        )

        assert use_gt_mae_output is not None and use_gt_mae_prob is not None
        self.always_output_audiomae_gt = always_output_audiomae_gt
        self.force_reload_pretrain_avoid_overwrite = (
            force_reload_pretrain_avoid_overwrite
        )
        self.pretrained_path = pretrained_path
        # With force-reload enabled, the pretrained weights are loaded again on
        # the first forward call (is_reload=False until then).
        if self.force_reload_pretrain_avoid_overwrite:
            self.is_reload = False
        else:
            self.is_reload = True

        self.load_pretrain_model()

        self.use_gt_mae_output = use_gt_mae_output
        self.use_gt_mae_prob = use_gt_mae_prob
        self.learnable = learnable

        if not learnable:
            # ``self.model`` is provided by the parent class.
            for p in self.model.parameters():
                p.requires_grad = False
            self.eval()

    def load_pretrain_model(self):
        """Load ``pretrained_path``'s ``state_dict`` into this module, if a path was given."""
        if self.pretrained_path is not None:
            print("Reload SequenceGenAudioMAECond from %s" % self.pretrained_path)
            # Deserialize onto the CPU so checkpoints saved on a GPU load on any
            # machine; load_state_dict copies the values into the existing
            # parameters wherever they live.
            state_dict = torch.load(self.pretrained_path, map_location="cpu")[
                "state_dict"
            ]
            self.load_state_dict(state_dict)

    def get_unconditional_condition(self, batchsize):
        return_dict = self.cfg_uncond(batchsize)
        return_dict["crossattn_audiomae_generated"] = [
            return_dict["crossattn_audiomae_pooled"][0],
            torch.ones_like(return_dict["crossattn_audiomae_pooled"][1]).float(),
        ]
        return return_dict

    def forward(self, batch):
        ret_dict = {}

        if self.force_reload_pretrain_avoid_overwrite and not self.is_reload:
            self.load_pretrain_model()
            self.is_reload = True

        self.check_module_param_update()

        if self.always_output_audiomae_gt or (
            self.use_gt_mae_output and torch.rand(1).item() < self.use_gt_mae_prob
        ):
            cond_dict = self.get_input(batch)
            ret_dict["crossattn_audiomae_generated"] = [
                cond_dict["crossattn_audiomae_pooled"][0],
                torch.ones_like(cond_dict["crossattn_audiomae_pooled"][1]).float(),
            ]
        else:
            if not self.training:
                print("--------------> Generate !!!!!!!!!!!!")
            input_embeds, cond_dict = self.generate(batch)
            input_embeds_mask = (
                torch.ones((input_embeds.size(0), input_embeds.size(1)))
                .to(input_embeds.device)
                .float()
            )
            ret_dict["crossattn_audiomae_generated"] = [
                input_embeds,
                input_embeds_mask,
            ]

        for key in cond_dict.keys():
            ret_dict[key] = cond_dict[key]

        if self.learnable and self.training:
            loss = self.training_step(batch, cond_dict=cond_dict)
            ret_dict["noncond_loss_clap2audiomae"] = loss

        return ret_dict
|
|
|
|
|
class SequenceGenAudioMAECond_AudioMAE_PostNet(Sequence2AudioMAE):
    """``Sequence2AudioMAE`` conditioning with a ``Prenet`` applied to the AudioMAE features.

    The structure is the same as ``SequenceGenAudioMAECond``, except that the
    features placed in ``crossattn_audiomae_generated`` are passed through
    ``self.prenet`` first. This applies both to the ground-truth pooled
    features and to the generated ones.
    """

    def __init__(
        self,
        cond_stage_config,
        base_learning_rate,
        sequence_gen_length,
        sequence_input_key,
        sequence_input_embed_dim,
        batchsize,
        always_output_audiomae_gt=False,
        pretrained_path=None,
        use_ar_gen_loss=False,
        force_reload_pretrain_avoid_overwrite=False,
        learnable=True,
        use_warmup=True,
        use_gt_mae_output=None,
        use_gt_mae_prob=None,
    ):
        # Warmup is always disabled for this standalone module (see message).
        if use_warmup:
            print(
                "Warning: You didn't initialize sequence prediction module with trainer. Set warmup to False. You can still use the warmup scheme from the latent diffusion model."
            )
            use_warmup = False

        super().__init__(
            base_learning_rate=base_learning_rate,
            cond_stage_config=cond_stage_config,
            sequence_gen_length=sequence_gen_length,
            sequence_input_key=sequence_input_key,
            use_ar_gen_loss=use_ar_gen_loss,
            use_warmup=use_warmup,
            sequence_input_embed_dim=sequence_input_embed_dim,
            batchsize=batchsize,
        )

        assert use_gt_mae_output is not None and use_gt_mae_prob is not None
        self.always_output_audiomae_gt = always_output_audiomae_gt
        self.force_reload_pretrain_avoid_overwrite = (
            force_reload_pretrain_avoid_overwrite
        )
        self.pretrained_path = pretrained_path
        # With force-reload enabled, the pretrained weights are loaded again on
        # the first forward call (is_reload=False until then).
        if self.force_reload_pretrain_avoid_overwrite:
            self.is_reload = False
        else:
            self.is_reload = True

        # NOTE(review): this runs before ``self.prenet`` exists, so a checkpoint
        # containing prenet weights would fail the strict load here.
        self.load_pretrain_model()

        self.prenet = Prenet(in_dim=768, sizes=[768, 768, 768], dropout_rate=0.5)

        self.use_gt_mae_output = use_gt_mae_output
        self.use_gt_mae_prob = use_gt_mae_prob
        self.learnable = learnable

        if not learnable:
            # ``self.model`` is provided by the parent class; the prenet is not frozen here.
            for p in self.model.parameters():
                p.requires_grad = False
            self.eval()

    def load_pretrain_model(self):
        """Load ``pretrained_path``'s ``state_dict`` into this module, if a path was given."""
        if self.pretrained_path is not None:
            print(
                "Reload SequenceGenAudioMAECond_AudioMAE_PostNet from %s"
                % self.pretrained_path
            )
            # Deserialize onto the CPU so checkpoints saved on a GPU load on any
            # machine; load_state_dict copies the values into the existing
            # parameters wherever they live.
            state_dict = torch.load(self.pretrained_path, map_location="cpu")[
                "state_dict"
            ]
            self.load_state_dict(state_dict)

    def get_unconditional_condition(self, batchsize):
        return_dict = self.cfg_uncond(batchsize)
        return_dict["crossattn_audiomae_generated"] = [
            return_dict["crossattn_audiomae_pooled"][0],
            torch.ones_like(return_dict["crossattn_audiomae_pooled"][1]).float(),
        ]
        return return_dict

    def forward(self, batch):
        ret_dict = {}

        if self.force_reload_pretrain_avoid_overwrite and not self.is_reload:
            self.load_pretrain_model()
            self.is_reload = True

        self.check_module_param_update()

        if self.always_output_audiomae_gt or (
            self.use_gt_mae_output and torch.rand(1).item() < self.use_gt_mae_prob
        ):
            cond_dict = self.get_input(batch)
            gt_audiomae = self.prenet(cond_dict["crossattn_audiomae_pooled"][0])
            ret_dict["crossattn_audiomae_generated"] = [
                gt_audiomae,
                torch.ones_like(cond_dict["crossattn_audiomae_pooled"][1]).float(),
            ]
        else:
            # Only announce generation outside training, so training logs are
            # not flooded.
            if not self.training:
                print("--------------> Generate!!!!!!!!!!!!")
            input_embeds, cond_dict = self.generate(batch)
            input_embeds = self.prenet(input_embeds)
            input_embeds_mask = (
                torch.ones((input_embeds.size(0), input_embeds.size(1)))
                .to(input_embeds.device)
                .float()
            )
            ret_dict["crossattn_audiomae_generated"] = [
                input_embeds,
                input_embeds_mask,
            ]

        for key in cond_dict.keys():
            ret_dict[key] = cond_dict[key]

        if self.learnable and self.training:
            loss = self.training_step(batch, cond_dict=cond_dict)
            ret_dict["noncond_loss_clap2audiomae"] = loss

        return ret_dict
|
|
|
|
|
class AudioMAEConditionCTPoolRandTFSeparated(nn.Module):
    """Frozen AudioMAE features pooled over the time/frequency patch grid.

    In training mode the time and frequency pooling factors are drawn
    independently from ``time_pooling_factors`` / ``freq_pooling_factors``.
    In eval mode they are fixed to ``eval_time_pooling`` / ``eval_freq_pooling``.

    Example::

        audiomae = AudioMAEConditionCTPoolRandTFSeparated(
            eval_time_pooling=8, eval_freq_pooling=8
        )
        representation, mask = audiomae(torch.randn((4, 1024, 128)))
    """

    def __init__(
        self,
        time_pooling_factors=[1, 2, 4, 8],
        freq_pooling_factors=[1, 2, 4, 8],
        eval_time_pooling=None,
        eval_freq_pooling=None,
        mask_ratio=0.0,
        regularization=False,
        no_audiomae_mask=True,
        no_audiomae_average=False,
    ):
        super().__init__()
        self.device = None
        self.time_pooling_factors = time_pooling_factors
        self.freq_pooling_factors = freq_pooling_factors
        self.no_audiomae_mask = no_audiomae_mask
        self.no_audiomae_average = no_audiomae_average

        self.eval_freq_pooling = eval_freq_pooling
        self.eval_time_pooling = eval_time_pooling
        self.mask_ratio = mask_ratio
        self.use_reg = regularization

        # AudioMAE is a frozen feature extractor.
        self.audiomae = Vanilla_AudioMAE()
        self.audiomae.eval()
        for p in self.audiomae.parameters():
            p.requires_grad = False

    def get_unconditional_condition(self, batchsize):
        """All-zero tokens (count matching eval pooling) with an all-ones mask."""
        param = next(self.audiomae.parameters())
        assert param.requires_grad == False
        device = param.device

        time_pool, freq_pool = min(self.eval_time_pooling, 64), min(
            self.eval_freq_pooling, 8
        )
        # 512 = 64 x 8 patches in the AudioMAE grid.
        token_num = int(512 / (time_pool * freq_pool))
        return [
            torch.zeros((batchsize, token_num, 768)).to(device).float(),
            torch.ones((batchsize, token_num)).to(device).float(),
        ]

    def pool(self, representation, time_pool=None, freq_pool=None):
        """Average of avg- and max-pooling over the 64 x 8 AudioMAE patch grid.

        ``representation`` is ``(bs, 513, 768)``: the first token is dropped
        and the remaining 512 are reshaped to a 64 (time) x 8 (freq) grid. In
        training mode, any factor left as ``None`` is drawn at random from the
        configured choices (capped at 64 / 8). In eval mode the arguments are
        ignored and the eval factors are used. Returns ``(bs, n_tokens, 768)``.
        """
        assert representation.size(-1) == 768
        representation = representation[:, 1:, :].transpose(1, 2)
        bs, embedding_dim, _ = representation.size()
        representation = representation.reshape(bs, embedding_dim, 64, 8)

        if self.training:
            # Fill each missing factor on its own so passing only one of them
            # works; the RNG is consumed time-first, as before.
            if time_pool is None:
                time_pool = min(
                    64,
                    self.time_pooling_factors[
                        np.random.choice(len(self.time_pooling_factors))
                    ],
                )
            if freq_pool is None:
                freq_pool = min(
                    8,
                    self.freq_pooling_factors[
                        np.random.choice(len(self.freq_pooling_factors))
                    ],
                )
        else:
            time_pool, freq_pool = min(self.eval_time_pooling, 64), min(
                self.eval_freq_pooling, 8
            )

        kernel = (time_pool, freq_pool)
        pooled = (
            F.avg_pool2d(representation, kernel_size=kernel, stride=kernel)
            + F.max_pool2d(representation, kernel_size=kernel, stride=kernel)
        ) / 2
        pooled = pooled.flatten(2).transpose(1, 2)
        return pooled

    def regularization(self, x):
        """L2-normalize the 768-dim feature vectors."""
        assert x.size(-1) == 768
        x = F.normalize(x, p=2, dim=-1)
        return x

    def forward(self, batch, time_pool=None, freq_pool=None):
        """Encode ``batch`` of shape ``(bs, 1024, 128)``; returns ``[tokens, all-ones mask]``."""
        assert batch.size(-2) == 1024 and batch.size(-1) == 128

        if self.device is None:
            self.device = batch.device

        batch = batch.unsqueeze(1)
        with torch.no_grad():
            representation = self.audiomae(
                batch,
                mask_ratio=self.mask_ratio,
                no_mask=self.no_audiomae_mask,
                no_average=self.no_audiomae_average,
            )
            representation = self.pool(representation, time_pool, freq_pool)
            if self.use_reg:
                representation = self.regularization(representation)
            return [
                representation,
                torch.ones((representation.size(0), representation.size(1)))
                .to(representation.device)
                .float(),
            ]
|
|
|
|
|
class AudioMAEConditionCTPoolRand(nn.Module):
    """Frozen AudioMAE features pooled over the time/frequency patch grid.

    In training mode the time pooling factor is drawn from
    ``time_pooling_factors`` and the frequency factor follows it
    (``min(8, time_pool)``). In eval mode ``eval_time_pooling`` /
    ``eval_freq_pooling`` are used. ``freq_pooling_factors`` is stored but not
    used for sampling.

    Example::

        audiomae = AudioMAEConditionCTPoolRand(eval_time_pooling=8, eval_freq_pooling=8)
        representation, mask = audiomae(torch.randn((4, 1024, 128)))
    """

    def __init__(
        self,
        time_pooling_factors=[1, 2, 4, 8],
        freq_pooling_factors=[1, 2, 4, 8],
        eval_time_pooling=None,
        eval_freq_pooling=None,
        mask_ratio=0.0,
        regularization=False,
        no_audiomae_mask=True,
        no_audiomae_average=False,
    ):
        super().__init__()
        self.device = None
        self.time_pooling_factors = time_pooling_factors
        self.freq_pooling_factors = freq_pooling_factors
        self.no_audiomae_mask = no_audiomae_mask
        self.no_audiomae_average = no_audiomae_average

        self.eval_freq_pooling = eval_freq_pooling
        self.eval_time_pooling = eval_time_pooling
        self.mask_ratio = mask_ratio
        self.use_reg = regularization

        # AudioMAE is a frozen feature extractor.
        self.audiomae = Vanilla_AudioMAE()
        self.audiomae.eval()
        for p in self.audiomae.parameters():
            p.requires_grad = False

    def get_unconditional_condition(self, batchsize):
        """All-zero tokens (count matching eval pooling) with an all-ones mask."""
        param = next(self.audiomae.parameters())
        assert param.requires_grad == False
        device = param.device

        time_pool, freq_pool = min(self.eval_time_pooling, 64), min(
            self.eval_freq_pooling, 8
        )
        # 512 = 64 x 8 patches in the AudioMAE grid.
        token_num = int(512 / (time_pool * freq_pool))
        return [
            torch.zeros((batchsize, token_num, 768)).to(device).float(),
            torch.ones((batchsize, token_num)).to(device).float(),
        ]

    def pool(self, representation, time_pool=None, freq_pool=None):
        """Average of avg- and max-pooling over the 64 x 8 AudioMAE patch grid.

        ``representation`` is ``(bs, 513, 768)``: the first token is dropped
        and the remaining 512 are reshaped to a 64 (time) x 8 (freq) grid. In
        training mode a missing ``time_pool`` is drawn at random (capped at
        64), and a missing ``freq_pool`` defaults to ``min(8, time_pool)``. In
        eval mode the arguments are ignored and the eval factors are used.
        Returns ``(bs, n_tokens, 768)``.
        """
        assert representation.size(-1) == 768
        representation = representation[:, 1:, :].transpose(1, 2)
        bs, embedding_dim, _ = representation.size()
        representation = representation.reshape(bs, embedding_dim, 64, 8)

        if self.training:
            # Fill each missing factor on its own so passing only one works.
            if time_pool is None:
                time_pool = min(
                    64,
                    self.time_pooling_factors[
                        np.random.choice(len(self.time_pooling_factors))
                    ],
                )
            if freq_pool is None:
                freq_pool = min(8, time_pool)
        else:
            time_pool, freq_pool = min(self.eval_time_pooling, 64), min(
                self.eval_freq_pooling, 8
            )

        kernel = (time_pool, freq_pool)
        pooled = (
            F.avg_pool2d(representation, kernel_size=kernel, stride=kernel)
            + F.max_pool2d(representation, kernel_size=kernel, stride=kernel)
        ) / 2
        pooled = pooled.flatten(2).transpose(1, 2)
        return pooled

    def regularization(self, x):
        """L2-normalize the 768-dim feature vectors."""
        assert x.size(-1) == 768
        x = F.normalize(x, p=2, dim=-1)
        return x

    def forward(self, batch, time_pool=None, freq_pool=None):
        """Encode ``batch`` of shape ``(bs, 1024, 128)``; returns ``[tokens, all-ones mask]``."""
        assert batch.size(-2) == 1024 and batch.size(-1) == 128

        if self.device is None:
            self.device = batch.device

        batch = batch.unsqueeze(1)
        with torch.no_grad():
            representation = self.audiomae(
                batch,
                mask_ratio=self.mask_ratio,
                no_mask=self.no_audiomae_mask,
                no_average=self.no_audiomae_average,
            )
            representation = self.pool(representation, time_pool, freq_pool)
            if self.use_reg:
                representation = self.regularization(representation)
            return [
                representation,
                torch.ones((representation.size(0), representation.size(1)))
                .to(representation.device)
                .float(),
            ]
|
|
|
|
|
class ConditionalToken(nn.Module):
    """Constant tokens identifying which time-pooling factor was applied.

    Each supported factor maps to a fixed vector built by tiling a two-value
    pattern ``embedding_dim // 2`` times. The tokens are stored in a plain
    dict (not as parameters or buffers), so they are never trained and are
    not moved by ``module.to()``. Callers move the result themselves.
    """

    def __init__(self, embedding_dim):
        super(ConditionalToken, self).__init__()
        self.embedding_dim = embedding_dim

        half = embedding_dim // 2
        patterns = {
            1: [1.0, 0.0],
            2: [0.0, 1.0],
            4: [1.0, 1.0],
            8: [-1.0, 0.0],
            16: [0.0, -1.0],
            32: [-1.0, -1.0],
            64: [0.0, 0.0],
        }
        self.pooling_factor_tokens = {
            factor: torch.Tensor(pair * half) for factor, pair in patterns.items()
        }

    def forward(self, condition, batchsize):
        """Return the token for pooling factor ``condition`` as ``(batchsize, 1, embedding_dim)``.

        Raises:
            ValueError: if ``condition`` is not one of the supported factors.
        """
        token = self.pooling_factor_tokens.get(condition)
        if token is None:
            raise ValueError(f"Unsupported condition: {condition}")
        return token.unsqueeze(0).unsqueeze(0).expand(batchsize, 1, self.embedding_dim)
|
|
|
|
|
class AudioMAEConditionCTPoolRandV2(nn.Module): |
|
""" |
|
audiomae = AudioMAEConditionCTPool2x2() |
|
data = torch.randn((4, 1024, 128)) |
|
output = audiomae(data) |
|
import ipdb;ipdb.set_trace() |
|
exit(0) |
|
""" |
|
|
|
def __init__( |
|
self, |
|
time_pooling_factors=[1, 2, 4, 8], |
|
freq_pooling_factors=[1, 2, 4, 8], |
|
eval_time_pooling=None, |
|
eval_freq_pooling=None, |
|
mask_ratio=0.0, |
|
regularization=False, |
|
no_audiomae_mask=True, |
|
no_audiomae_average=False, |
|
): |
|
super().__init__() |
|
self.device = None |
|
self.time_pooling_factors = time_pooling_factors |
|
self.freq_pooling_factors = freq_pooling_factors |
|
self.no_audiomae_mask = no_audiomae_mask |
|
self.no_audiomae_average = no_audiomae_average |
|
|
|
self.eval_freq_pooling = eval_freq_pooling |
|
self.eval_time_pooling = eval_time_pooling |
|
self.mask_ratio = mask_ratio |
|
self.use_reg = regularization |
|
|
|
self.pooling_tokens = ConditionalToken(768) |
|
|
|
self.audiomae = Vanilla_AudioMAE() |
|
self.audiomae.eval() |
|
|
|
for p in self.audiomae.parameters(): |
|
p.requires_grad = False |
|
|
|
|
|
def get_unconditional_condition(self, batchsize): |
|
param = next(self.audiomae.parameters()) |
|
assert param.requires_grad == False |
|
device = param.device |
|
|
|
time_pool, freq_pool = min(self.eval_time_pooling, 64), min( |
|
self.eval_freq_pooling, 8 |
|
) |
|
|
|
|
|
pool_condition_token = self.pooling_tokens(time_pool, batchsize).to(device) |
|
token_num = int(512 / (time_pool * freq_pool)) |
|
|
|
rep = torch.zeros((batchsize, token_num, 768)).to(device).float() |
|
rep = torch.cat([rep, pool_condition_token], dim=1) |
|
|
|
return [rep, torch.ones((batchsize, token_num + 1)).to(device).float()] |
|
|
|
def pool(self, representation, time_pool=None, freq_pool=None): |
|
assert representation.size(-1) == 768 |
|
representation = representation[:, 1:, :].transpose(1, 2) |
|
bs, embedding_dim, token_num = representation.size() |
|
representation = representation.reshape(bs, embedding_dim, 64, 8) |
|
|
|
if self.training: |
|
if time_pool is None and freq_pool is None: |
|
time_pool = min( |
|
64, |
|
self.time_pooling_factors[ |
|
np.random.choice(list(range(len(self.time_pooling_factors)))) |
|
], |
|
) |
|
|
|
freq_pool = min(8, time_pool) |
|
else: |
|
time_pool, freq_pool = min(self.eval_time_pooling, 64), min( |
|
self.eval_freq_pooling, 8 |
|
) |
|
|
|
self.avgpooling = nn.AvgPool2d( |
|
kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool) |
|
) |
|
self.maxpooling = nn.MaxPool2d( |
|
kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool) |
|
) |
|
pooled = ( |
|
self.avgpooling(representation) + self.maxpooling(representation) |
|
) / 2 |
|
pooled = pooled.flatten(2).transpose(1, 2) |
|
return pooled, time_pool, freq_pool |
|
|
|
def regularization(self, x):
    """Return ``x`` with every 768-d vector L2-normalised along the last axis."""
    assert x.size(-1) == 768
    return F.normalize(x, p=2, dim=-1)
|
|
|
|
|
def forward(self, batch):
    """Encode spectrograms into a pooled AudioMAE condition sequence.

    Args:
        batch: tensor whose last two dims are (1024, 128). A channel axis is
            inserted at dim 1 before it goes through the frozen AudioMAE.

    Returns:
        ``[representation, mask]``:
          * ``representation``: the pooled (and, if ``self.use_reg`` is set,
            L2-normalised) tokens, with the pooling-rate token appended on
            dim 1.
          * ``mask``: an all-ones float32 tensor of shape ``(bs, seq_len)``.
    """
    assert batch.size(-2) == 1024 and batch.size(-1) == 128

    # Remember the device of the first batch that is seen.
    if self.device is None:
        self.device = batch.device

    spec = batch.unsqueeze(1)

    with torch.no_grad():
        feats = self.audiomae(
            spec,
            mask_ratio=self.mask_ratio,
            no_mask=self.no_audiomae_mask,
            no_average=self.no_audiomae_average,
        )
        feats, time_pool, freq_pool = self.pool(feats)
        if self.use_reg:
            feats = self.regularization(feats)
        n_batch = feats.size(0)
        rate_token = self.pooling_tokens(time_pool, n_batch).to(feats.device)
        feats = torch.cat([feats, rate_token], dim=1)

    seq_shape = (feats.size(0), feats.size(1))
    mask = torch.ones(seq_shape).to(feats.device).float()
    return [feats, mask]
|
|
|
|
|
class BeatDownbeatConditionConcat(nn.Module):
    """Pass-through conditioner (by its name, for beat/downbeat features).

    ``forward`` returns its input unchanged. The unconditional condition is an
    all-zero tensor of shape ``(batchsize, latent_t_size, latent_f_size)``.
    """

    def __init__(self, latent_t_size, latent_f_size):
        super().__init__()
        self.latent_t_size = latent_t_size
        self.latent_f_size = latent_f_size
        # Set from the first batch passed to ``forward``.
        self.device = None

    def get_unconditional_condition(self, batchsize):
        """Return zeros on ``self.device`` shaped like one batch of conditions."""
        shape = (batchsize, self.latent_t_size, self.latent_f_size)
        blank = torch.zeros(shape)
        return blank.to(self.device)

    def forward(self, batch):
        """Record the batch's device on first use and return the batch as-is."""
        if self.device is not None:
            return batch
        self.device = batch.device
        return batch
|
|
|
|
|
class CLAPAudioEmbeddingClassifierFreev2(nn.Module):
    """Frozen CLAP encoder used as a conditioner (audio or text embeddings).

    ``forward`` returns a detached embedding of shape ``(bs, 1, D)``. It comes
    from the audio branch when ``embed_mode == "audio"`` and from the text
    branch when ``embed_mode == "text"``. Each sample is independently
    replaced by the empty-string embedding with probability
    ``unconditional_prob``.
    """

    def __init__(
        self,
        pretrained_path,
        sampling_rate=16000,
        embed_mode="audio",
        amodel="HTSAT-base",
        unconditional_prob=0.1,
        random_mute=False,
        max_random_mute_portion=0.5,
        training_mode=True,
    ):
        super().__init__()
        self.device = "cpu"
        self.precision = "fp32"
        self.amodel = amodel
        self.tmodel = "roberta"
        self.enable_fusion = False
        self.fusion_type = "aff_2d"
        self.pretrained = pretrained_path
        self.embed_mode = embed_mode
        self.embed_mode_orig = embed_mode
        self.sampling_rate = sampling_rate
        self.unconditional_prob = unconditional_prob
        # NOTE(review): stored but not applied anywhere in ``forward``.
        self.random_mute = random_mute
        self.tokenize = RobertaTokenizer.from_pretrained(config_data["roberta-base"])
        self.max_random_mute_portion = max_random_mute_portion
        self.training_mode = training_mode
        self.model, self.model_cfg = create_model(
            self.amodel,
            self.tmodel,
            self.pretrained,
            precision=self.precision,
            device=self.device,
            enable_fusion=self.enable_fusion,
            fusion_type=self.fusion_type,
        )
        audio_cfg = self.model_cfg["audio_cfg"]
        # Mel front-end configured from the CLAP checkpoint's audio config.
        # ``forward`` applies it after resampling to 48 kHz.
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=audio_cfg["sample_rate"],
            n_fft=audio_cfg["window_size"],
            win_length=audio_cfg["window_size"],
            hop_length=audio_cfg["hop_size"],
            center=True,
            pad_mode="reflect",
            power=2.0,
            norm=None,
            onesided=True,
            n_mels=64,
            f_min=audio_cfg["fmin"],
            f_max=audio_cfg["fmax"],
        )
        # The CLAP weights are never trained here.
        for p in self.model.parameters():
            p.requires_grad = False
        # Built lazily on first use, because it needs the tokenizer and model.
        self.unconditional_token = None
        self.model.eval()

    def _model_device(self):
        """Return the device that currently holds the CLAP weights."""
        return next(self.model.parameters()).device

    def get_unconditional_condition(self, batchsize):
        """Return the empty-text embedding repeated to ``(batchsize, 1, D)``.

        Side effect: rebuilds ``self.unconditional_token``.
        """
        self.build_unconditional_emb()
        return torch.cat([self.unconditional_token.unsqueeze(0)] * batchsize, dim=0)

    def batch_to_list(self, batch):
        """Split ``batch`` along dim 0 into a list of per-sample tensors."""
        return [batch[i] for i in range(batch.size(0))]

    def make_decision(self, probability):
        """Return True with the given probability (uses torch's RNG)."""
        return float(torch.rand(1)) < probability

    def random_uniform(self, start, end):
        """Sample a float uniformly from ``[start, end)`` using torch's RNG."""
        val = torch.rand(1).item()
        return start + (end - start) * val

    def _random_mute(self, waveform):
        """Zero one random contiguous span per row of ``waveform``, in place.

        Each span covers fewer than ``max_random_mute_portion`` of the last
        axis. The slice is applied to dim 1, so this assumes a 2-D
        ``(bs, t_steps)`` waveform (TODO confirm).
        """
        t_steps = waveform.size(-1)
        for i in range(waveform.size(0)):
            mute_size = int(
                self.random_uniform(0, end=int(t_steps * self.max_random_mute_portion))
            )
            mute_start = int(self.random_uniform(0, t_steps - mute_size))
            waveform[i, mute_start : mute_start + mute_size] = 0
        return waveform

    def cos_similarity(self, waveform, text):
        """Return the cosine similarity of CLAP audio and text embeddings.

        ``embed_mode`` is switched temporarily and always restored, even if
        embedding raises. The waveform is moved to the CLAP model's device
        instead of being forced onto CUDA.
        """
        original_embed_mode = self.embed_mode
        device = self._model_device()
        try:
            with torch.no_grad():
                self.embed_mode = "audio"
                audio_emb = self(waveform.to(device))
                self.embed_mode = "text"
                text_emb = self(text)
                similarity = F.cosine_similarity(audio_emb, text_emb, dim=2)
        finally:
            self.embed_mode = original_embed_mode
        return similarity.squeeze()

    def build_unconditional_emb(self):
        """Cache the first row of the CLAP text embedding of the empty string."""
        # Two prompts are tokenized, presumably so that ``tokenizer``'s
        # squeeze(0) keeps the batch axis. Only the first row is kept.
        self.unconditional_token = self.model.get_text_embedding(
            self.tokenizer(["", ""])
        )[0:1]

    def _embed_audio(self, batch):
        """Compute CLAP audio embeddings for a waveform batch without grad."""
        if not self.training:
            print("INFO: clap model calculate the audio embedding as condition")
        with torch.no_grad():
            if self.sampling_rate != 48000:
                batch = torchaudio.functional.resample(
                    batch, orig_freq=self.sampling_rate, new_freq=48000
                )
            audio_data = batch.squeeze(1)
            mel = self.mel_transform(audio_data)
            audio_dict = get_audio_features(
                audio_data,
                mel,
                480000,
                data_truncating="fusion",
                data_filling="repeatpad",
                audio_cfg=self.model_cfg["audio_cfg"],
            )
            return self.model.get_audio_embedding(audio_dict)

    def _embed_text(self, batch):
        """Compute CLAP text embeddings for a string or a list of strings."""
        with torch.no_grad():
            text_data = self.tokenizer(batch)
            # ``tokenizer`` squeezes dim 0, which drops the batch axis for a
            # single prompt, so restore it here.
            if isinstance(batch, str) or (isinstance(batch, list) and len(batch) == 1):
                for key in text_data.keys():
                    text_data[key] = text_data[key].unsqueeze(0)
            return self.model.get_text_embedding(text_data)

    def forward(self, batch):
        """Embed ``batch`` with the frozen CLAP model.

        Args:
            batch: waveform tensor when ``embed_mode == "audio"``, or a string
                or list of strings when ``embed_mode == "text"``.

        Returns:
            A detached embedding with a singleton axis inserted at dim 1. Each
            row is replaced by ``self.unconditional_token`` with probability
            ``self.unconditional_prob``.

        Raises:
            ValueError: if ``embed_mode`` is neither "audio" nor "text".
        """
        if self.model.training and not self.training_mode:
            print(
                "The pretrained CLAP model should always be in eval mode. Reloading model just in case you change the parameters."
            )
            # Reload onto the device the current weights live on, instead of
            # assuming CUDA is available.
            device = self._model_device()
            self.model, self.model_cfg = create_model(
                self.amodel,
                self.tmodel,
                self.pretrained,
                precision=self.precision,
                device=device,
                enable_fusion=self.enable_fusion,
                fusion_type=self.fusion_type,
            )
            for p in self.model.parameters():
                p.requires_grad = False
            self.model.eval()

        if self.unconditional_token is None:
            self.build_unconditional_emb()

        if self.embed_mode == "audio":
            embed = self._embed_audio(batch)
        elif self.embed_mode == "text":
            embed = self._embed_text(batch)
        else:
            raise ValueError(
                f"Unsupported embed_mode {self.embed_mode!r}; expected 'audio' or 'text'"
            )

        embed = embed.unsqueeze(1)
        # Classifier-free-guidance style dropout to the unconditional token.
        for i in range(embed.size(0)):
            if self.make_decision(self.unconditional_prob):
                embed[i] = self.unconditional_token
        return embed.detach()

    def tokenizer(self, text):
        """Tokenize ``text`` with RoBERTa, padded and truncated to 512 tokens.

        ``squeeze(0)`` is applied to every returned tensor. This removes the
        batch axis when the batch has a single prompt.
        """
        result = self.tokenize(
            text,
            padding="max_length",
            truncation=True,
            max_length=512,
            return_tensors="pt",
        )
        return {k: v.squeeze(0) for k, v in result.items()}
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: embed two prompts with the text branch of CLAP.
    # The checkpoint path can be passed as the first command-line argument.
    ckpt_path = (
        sys.argv[1]
        if len(sys.argv) > 1
        else "/mnt/bn/lqhaoheliu/exps/checkpoints/audioldm/ckpt/CLAP.pt"
    )
    model = CLAPAudioEmbeddingClassifierFreev2(
        pretrained_path=ckpt_path,
        embed_mode="text",
        amodel="HTSAT-tiny",
    )

    data = ["text", "text"]
    res = model(data)
    # Report the result instead of opening an interactive debugger, so the
    # script runs unattended and does not need ipdb installed.
    print(res.shape)
|
|