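# NOTE: the following is file-viewer page residue captured with this file,
# kept as a comment so the module remains valid Python:
#   jadechoghari's picture | add model | 9b9e0ee verified | raw | history blame | 12.1 kB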
import torch
import torch.nn as nn
import pytorch_lightning as pl
from qa_mdt.audioldm_train.utilities.model_util import (
exists,
default,
mean_flat,
count_params,
instantiate_from_config,
)
from transformers import GPT2Config, GPT2Model
import torch.optim.lr_scheduler as lr_scheduler
class Prenet(nn.Module):
    """Stack of ``Linear -> ReLU -> Dropout`` layers applied in sequence.

    Args:
        in_dim: Size of the last dimension of the input tensor.
        sizes: Output width of each successive linear layer. Any sequence of
            ints is accepted; the final entry is the output width of the
            module.
        dropout_rate: Dropout probability applied after every ReLU.
    """

    def __init__(self, in_dim, sizes=(256, 128), dropout_rate=0.5):
        super().__init__()
        # Normalise to a fresh list: avoids a shared mutable default and lets
        # callers pass tuples (``[in_dim] + sizes[:-1]`` needs a list).
        sizes = list(sizes)
        in_sizes = [in_dim] + sizes[:-1]
        self.layers = nn.ModuleList(
            [
                nn.Linear(in_size, out_size)
                for in_size, out_size in zip(in_sizes, sizes)
            ]
        )
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, inputs):
        """Run ``inputs`` through every layer and return the final activation."""
        for linear in self.layers:
            inputs = self.dropout(self.relu(linear(inputs)))
        return inputs
class CLAP2AudioMAE(pl.LightningModule):
    """Autoregressive GPT-2 that predicts AudioMAE tokens from a CLAP embedding.

    The CLAP embedding (``film_clap_cond1``) is projected from 512 to 768
    features by ``linear_clap`` and used as the prefix of the GPT-2 input
    sequence. For the teacher-forced loss, the ground-truth AudioMAE sequence
    (``crossattn_audiomae_pooled``) is appended after the prefix and every
    output position is regressed (MSE) onto the next ground-truth token.

    Args:
        sequence_gen_length: Number of tokens produced by :meth:`generate`.
        base_learning_rate: Learning rate handed to AdamW.
        cond_stage_config: Mapping from conditioning-model key to a config that
            ``instantiate_from_config`` can build. Every entry must also
            provide ``cond_stage_key`` and ``conditioning_key``.
        use_audiomae_linear: Must be falsy. The extra AudioMAE projection has
            been removed and the flag is kept only for config compatibility.
        **kwargs: Accepted and ignored.

    Raises:
        AssertionError: If ``use_audiomae_linear`` is truthy.
    """

    def __init__(
        self,
        sequence_gen_length,
        base_learning_rate,
        cond_stage_config,
        use_audiomae_linear=False,
        **kwargs
    ):
        super().__init__()
        assert (
            not use_audiomae_linear
        ), "use_audiomae_linear=True is not supported (linear_audiomae was removed)"
        self.learning_rate = base_learning_rate
        self.cond_stage_config = cond_stage_config
        self.use_audiomae_linear = use_audiomae_linear

        self.mae_token_num = sequence_gen_length  # 4*4 pooling of the audiomae latent

        self.cond_stage_models = nn.ModuleList([])
        self.instantiate_cond_stage(cond_stage_config)

        self.model = GPT2Model.from_pretrained("gpt2")
        self.linear_clap = nn.Linear(512, 768)

        if use_audiomae_linear:
            # self.linear_audiomae = nn.Linear(768, 768) # TODO remove linear_audiomae
            self.linear_audiomae = None  # TODO remove linear_audiomae

        self.loss_fn = nn.MSELoss()

        self.logger_save_dir = None
        self.logger_exp_name = None
        self.logger_exp_group_name = None
        self.logger_version = None

    def set_log_dir(self, save_dir, exp_group_name, exp_name):
        """Record where logs for this run should be written."""
        self.logger_save_dir = save_dir
        self.logger_exp_group_name = exp_group_name
        self.logger_exp_name = exp_name

    def cfg_uncond(self, batch_size):
        """Build the unconditional conditioning dict for classifier-free guidance.

        Each conditioning model contributes its
        ``get_unconditional_condition(batch_size)`` output under its own key.
        The AudioMAE entry is additionally exposed under
        ``"crossattn_clap_to_audiomae_feature"``.

        Raises:
            AssertionError: If no ``"crossattn_audiomae_pooled"`` conditioning
                model is configured.
        """
        unconditional_conditioning = {}
        for key, metadata in self.cond_stage_model_metadata.items():
            cond_model = self.cond_stage_models[metadata["model_idx"]]
            unconditional_conditioning[key] = cond_model.get_unconditional_condition(
                batch_size
            )

        assert (
            "crossattn_audiomae_pooled" in unconditional_conditioning
        ), "The module is not initialized with AudioMAE"
        unconditional_conditioning[
            "crossattn_clap_to_audiomae_feature"
        ] = unconditional_conditioning["crossattn_audiomae_pooled"]
        return unconditional_conditioning

    def configure_optimizers(self):
        """Return AdamW over GPT-2 and the CLAP projection, plus a StepLR scheduler.

        The scheduler multiplies the learning rate by 0.9 on every scheduler
        step (``step_size=1``).
        """
        lr = float(self.learning_rate)
        params = list(self.model.parameters()) + list(self.linear_clap.parameters())
        if self.use_audiomae_linear:
            params += list(self.linear_audiomae.parameters())
        opt = torch.optim.AdamW(params, lr=lr)
        scheduler = lr_scheduler.StepLR(opt, step_size=1, gamma=0.9)
        return [opt], [scheduler]

    def _teacher_forced_loss(self, cond_dict):
        """Compute the teacher-forced next-token MSE loss.

        Args:
            cond_dict: Conditioning dict as returned by :meth:`get_input`.

        Returns:
            Tuple ``(loss, target)`` where ``target`` is the ground-truth
            AudioMAE sequence the loss was computed against.
        """
        clap_embeds = cond_dict["film_clap_cond1"]
        # Element 0 of the conditioning entry holds the embedding sequence.
        target = cond_dict["crossattn_audiomae_pooled"][0]

        # With a random pooling factor the length of crossattn_audiomae_pooled
        # is not fixed, so a separately computed entry is preferred when the
        # conditioning stage provides one.
        if "crossattn_audiomae_pooled_44" in cond_dict:
            target = cond_dict["crossattn_audiomae_pooled_44"][0]

        prefix = self.linear_clap(clap_embeds)
        if self.use_audiomae_linear:
            continuation = self.linear_audiomae(target)
        else:
            continuation = target
        input_embeds = torch.cat([prefix, continuation], dim=1)

        output_embeds = self.model(inputs_embeds=input_embeds)["last_hidden_state"]

        # Position i predicts token i of the target; the last output position
        # has no ground-truth successor and is dropped.
        loss = self.loss_fn(output_embeds[:, :-1], target)
        return loss, target

    def training_step(self, batch, batch_idx=None, cond_dict=None):
        """Run one teacher-forced training step and log the loss.

        Args:
            batch: Raw data batch; only used when ``cond_dict`` is ``None``.
            batch_idx: Unused.
            cond_dict: Optional pre-computed conditioning dict.

        Returns:
            The scalar MSE loss.
        """
        if cond_dict is None:
            cond_dict = self.get_input(batch)

        loss, _ = self._teacher_forced_loss(cond_dict)

        self.log(
            "train/loss_clap_2_audiomae",
            loss,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=False,
            sync_dist=True,
        )

        self.log(
            "global_step_audiomae",
            float(self.global_step),
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=False,
            sync_dist=True,
        )

        return loss

    def _autoregress(self, clap_embeds, steps):
        """Generate ``steps`` tokens after the projected CLAP prefix.

        Returns the full sequence, i.e. the prefix followed by the generated
        tokens.
        """
        model_input = self.linear_clap(clap_embeds)
        for _ in range(steps):
            output = self.model(inputs_embeds=model_input)["last_hidden_state"]
            # Feed the newest prediction back in as the next input token.
            model_input = torch.cat([model_input, output[:, -1:, :]], dim=1)
        return model_input

    def generate(self, batch, cond_dict=None, no_grad=False):
        """Autoregressively generate ``self.mae_token_num`` tokens.

        Args:
            batch: Raw data batch; only used when ``cond_dict`` is ``None``.
            cond_dict: Optional pre-computed conditioning dict.
            no_grad: When true, generation runs under ``torch.no_grad()``.
                When false the ambient gradient mode is left untouched.

        Returns:
            Tuple ``(tokens, cond_dict)``. ``tokens`` is the generated sequence
            with its first position (the CLAP prefix, assuming the CLAP
            embedding has sequence length 1) removed.
        """
        if cond_dict is None:
            cond_dict = self.get_input(batch)
        input_embeds = cond_dict["film_clap_cond1"]
        steps = self.mae_token_num

        if no_grad:
            with torch.no_grad():
                model_input = self._autoregress(input_embeds, steps)
        else:
            model_input = self._autoregress(input_embeds, steps)

        return model_input[:, 1:], cond_dict

    # def on_validation_epoch_start(self) -> None:
    #     # Use text as condition during validation
    #     for key in self.cond_stage_model_metadata.keys():
    #         metadata = self.cond_stage_model_metadata[key]
    #         model_idx, cond_stage_key, conditioning_key = metadata["model_idx"], metadata["cond_stage_key"], metadata["conditioning_key"]
    #         # If we use CLAP as condition, we might use audio for training, but we also must use text for evaluation
    #         # if(isinstance(self.cond_stage_models[model_idx], CLAPAudioEmbeddingClassifierFreev2)):
    #         #     self.cond_stage_model_metadata[key]["cond_stage_key_orig"] = self.cond_stage_model_metadata[key]["cond_stage_key"]
    #         #     self.cond_stage_model_metadata[key]["embed_mode_orig"] = self.cond_stage_models[model_idx].embed_mode
    #         #     print("Change the model original cond_keyand embed_mode %s, %s to text during evaluation" % (self.cond_stage_model_metadata[key]["cond_stage_key_orig"], self.cond_stage_model_metadata[key]["embed_mode_orig"]))
    #         #     self.cond_stage_model_metadata[key]["cond_stage_key"] = "text"
    #         #     self.cond_stage_models[model_idx].embed_mode = "text"
    #     return super().on_validation_epoch_start()

    def validation_step(self, batch, batch_idx):
        """Log the teacher-forced loss and the free-running generation loss.

        Returns:
            Dict with ``"loss"`` (teacher-forced MSE) and ``"ar_gen_loss"``
            (MSE between the autoregressively generated sequence and the
            target).
        """
        cond_dict = self.get_input(batch)
        # cond_dict['film_clap_cond1']: [2,1,512]
        # cond_dict['crossattn_audiomae_pooled']: [2, 128, 768]

        loss, target = self._teacher_forced_loss(cond_dict)

        self.log(
            "val/loss",
            loss,
            prog_bar=True,
            logger=True,
            on_step=True,
            sync_dist=True,
            on_epoch=True,
        )

        # Reuse the conditioning computed above: this avoids running every
        # conditioning encoder a second time and guarantees the generation is
        # scored against the target it was actually conditioned with.
        generation_output, _ = self.generate(batch, cond_dict=cond_dict)
        ar_gen_loss = self.loss_fn(generation_output, target)

        self.log(
            "val/ar_gen_loss",
            ar_gen_loss,
            prog_bar=True,
            logger=True,
            on_step=True,
            sync_dist=True,
            on_epoch=True,
        )

        return {"loss": loss, "ar_gen_loss": ar_gen_loss}

    def get_input_item(self, batch, k):
        """Return entry ``k`` of ``batch`` after normalising the core fields.

        ``"fbank"`` is derived from ``batch["log_mel_spec"]`` with a channel
        dimension inserted at position 1; ``"stft"`` and ``"waveform"`` are
        converted to contiguous float tensors; ``"text"`` becomes a list. Any
        other key is passed through from ``batch`` unchanged.

        The batch must contain ``fname``, ``text``, ``waveform``, ``stft`` and
        ``log_mel_spec`` regardless of which ``k`` is requested.
        """
        ret = {
            "fbank": batch["log_mel_spec"]
            .unsqueeze(1)
            .to(memory_format=torch.contiguous_format)
            .float(),
            "stft": batch["stft"].to(memory_format=torch.contiguous_format).float(),
            # ret["clip_label"] = clip_label.to(memory_format=torch.contiguous_format).float()
            "waveform": batch["waveform"]
            .to(memory_format=torch.contiguous_format)
            .float(),
            "text": list(batch["text"]),
            "fname": batch["fname"],
        }

        # Pass through everything else (never overriding the processed fields).
        for key in batch.keys():
            if key not in ret:
                ret[key] = batch[key]

        return ret[k]

    def get_input(self, batch):
        """Run every conditioning model on ``batch``.

        Returns:
            Dict mapping each conditioning-model key to its learned
            conditioning. Empty when no conditioning model is configured.
        """
        cond_dict = {}
        for cond_model_key, metadata in self.cond_stage_model_metadata.items():
            cond_stage_key = metadata["cond_stage_key"]

            # if(not self.training):
            #     if(isinstance(self.cond_stage_models[self.cond_stage_model_metadata[cond_model_key]["model_idx"]], CLAPAudioEmbeddingClassifierFreev2)):
            #         assert cond_stage_key == "text" # CLAP model should use text for evaluation

            # The original data for conditioning
            xc = self.get_input_item(batch, cond_stage_key)
            # isinstance also covers Tensor subclasses, which an exact type
            # comparison would leave on the wrong device.
            if isinstance(xc, torch.Tensor):
                xc = xc.to(self.device)

            cond_dict[cond_model_key] = self.get_learned_conditioning(
                xc, key=cond_model_key, unconditional_cfg=False
            )
        return cond_dict

    def instantiate_cond_stage(self, config):
        """Instantiate every conditioning model described by ``config``.

        Each model is appended to ``self.cond_stage_models`` and its index,
        ``cond_stage_key`` and ``conditioning_key`` are recorded in
        ``self.cond_stage_model_metadata`` under the same key as in ``config``.
        """
        self.cond_stage_model_metadata = {}

        for i, cond_model_key in enumerate(config.keys()):
            model_config = config[cond_model_key]
            self.cond_stage_models.append(instantiate_from_config(model_config))
            self.cond_stage_model_metadata[cond_model_key] = {
                "model_idx": i,
                "cond_stage_key": model_config["cond_stage_key"],
                "conditioning_key": model_config["conditioning_key"],
            }

    def get_learned_conditioning(self, c, key, unconditional_cfg):
        """Encode ``c`` with the conditioning model registered under ``key``.

        Args:
            c: Raw conditioning data (tensor or list).
            key: Conditioning-model key; must be registered.
            unconditional_cfg: When true, ``c`` is only used to infer the batch
                size and the model's unconditional condition is returned
                (classifier-free guidance).

        Raises:
            AssertionError: If ``key`` is not a registered conditioning model.
            NotImplementedError: If ``unconditional_cfg`` is true and ``c`` is
                neither a tensor nor a list.
        """
        assert key in self.cond_stage_model_metadata

        cond_model = self.cond_stage_models[
            self.cond_stage_model_metadata[key]["model_idx"]
        ]

        if not unconditional_cfg:
            return cond_model(c)

        # Classifier-free guidance
        if isinstance(c, torch.Tensor):
            batchsize = c.size(0)
        elif isinstance(c, list):
            batchsize = len(c)
        else:
            raise NotImplementedError()
        return cond_model.get_unconditional_condition(batchsize)