jadechoghari's picture
add model
9b9e0ee verified
raw
history blame
No virus
5.37 kB
"""
Reference Repo: https://github.com/facebookresearch/AudioMAE
"""
import torch
import torch.nn as nn
from timm.models.layers import to_2tuple
import qa_mdt.audioldm_train.modules.audiomae.models_vit as models_vit
import qa_mdt.audioldm_train.modules.audiomae.models_mae as models_mae
# model = mae_vit_base_patch16(in_chans=1, audio_exp=True, img_size=(1024, 128))
class PatchEmbed_new(nn.Module):
    """Flexible image-to-patch embedding implemented as a strided convolution.

    The stride is independent of the patch (kernel) size, so a stride smaller
    than the patch size produces overlapping patches.

    Args:
        img_size: Nominal input size (int or pair), normalized with ``to_2tuple``.
        patch_size: Convolution kernel size (int or pair).
        in_chans: Number of input channels expected by the projection.
        embed_dim: Number of output channels of the projection (token width).
        stride: Convolution stride (int or pair).

    Attributes:
        img_size / patch_size: The normalized size pairs.
        proj: The ``nn.Conv2d`` that turns the input into a patch grid.
        patch_hw: ``(h, w)`` of the patch grid produced for ``img_size``.
        num_patches: ``h * w``.
    """

    def __init__(
        self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        stride = to_2tuple(stride)

        self.img_size = img_size
        self.patch_size = patch_size

        # Kernel = patch size, stride chosen separately -> patches may overlap.
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=stride
        )

        # The grid size is measured on the actual projection rather than
        # computed as img_size // patch_size, which is only right when
        # stride == patch_size.
        _, _, h, w = self.get_output_shape(img_size)  # n, emb_dim, h, w
        self.patch_hw = (h, w)
        self.num_patches = h * w

    def get_output_shape(self, img_size):
        """Return the shape ``(1, embed_dim, h, w)`` of ``self.proj`` applied to
        one dummy input of spatial size ``img_size`` (a pair)."""
        weight = self.proj.weight
        # Use the projection's own channel count: a hard-coded single channel
        # makes every in_chans != 1 configuration fail at construction time.
        probe = torch.randn(
            1,
            self.proj.in_channels,
            img_size[0],
            img_size[1],
            device=weight.device,
            dtype=weight.dtype,
        )
        # Only the shape is needed, so skip autograd bookkeeping.
        with torch.no_grad():
            return self.proj(probe).shape

    def forward(self, x):
        """Project ``x`` of shape ``(B, C, H, W)`` to tokens of shape
        ``(B, h * w, embed_dim)``.

        The input size is intentionally not validated against
        ``self.img_size`` (FIXME from the original: look at relaxing/adding
        size constraints).
        """
        # The unpacking doubles as a rank check: non-4-D input fails here.
        _batch, _chans, _height, _width = x.shape
        x = self.proj(x)
        # (B, E, h, w) -> (B, E, h*w) -> (B, h*w, E)
        return x.flatten(2).transpose(1, 2)
class AudioMAE(nn.Module):
    """Audio Masked Autoencoder (MAE) pre-trained and finetuned on AudioSet (for SoundCLIP).

    Builds the ``vit_base_patch16`` model from ``models_vit``, swaps in a
    single-channel ``PatchEmbed_new`` sized for a (1024, 128) input, and loads
    finetuned weights from a checkpoint.

    Args:
        checkpoint_path: File passed to ``torch.load``. Defaults to the
            location that used to be hard-coded, so ``AudioMAE()`` behaves
            exactly as before.
    """

    def __init__(
        self,
        checkpoint_path="/mnt/bn/data-xubo/project/Masked_AudioEncoder/checkpoint/finetuned.pth",
    ):
        super().__init__()
        model = models_vit.__dict__["vit_base_patch16"](
            num_classes=527,
            drop_path_rate=0.1,
            global_pool=True,
            mask_2d=True,
            use_custom_patch=False,
        )

        img_size = (1024, 128)
        emb_dim = 768

        # Replace the stock patch embedding with a 1-channel, non-overlapping
        # (stride == patch size) one for spectrogram input.
        model.patch_embed = PatchEmbed_new(
            img_size=img_size,
            patch_size=(16, 16),
            in_chans=1,
            embed_dim=emb_dim,
            stride=16,
        )
        # For (1024, 128) with 16x16 patches: 64 * 8 = 512 patches.
        num_patches = model.patch_embed.num_patches

        # Frozen positional embedding resized to the new patch grid; the extra
        # slot is presumably the class token (confirm against models_vit).
        # Original note: fixed sin-cos embedding.
        model.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, emb_dim), requires_grad=False
        )

        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        # strict=False: missing/unexpected keys are tolerated instead of raising.
        model.load_state_dict(checkpoint["model"], strict=False)

        self.model = model

    def forward(self, x, mask_t_prob=0.0, mask_f_prob=0.0):
        """
        x: mel fbank [Batch, 1, T, F]
        mask_t_prob: 'T masking ratio (percentage of removed patches).'
        mask_f_prob: 'F masking ratio (percentage of removed patches).'

        Returns whatever the wrapped model returns for these arguments.
        """
        return self.model(x=x, mask_t_prob=mask_t_prob, mask_f_prob=mask_f_prob)
class Vanilla_AudioMAE(nn.Module):
    """Audio Masked Autoencoder (MAE) pre-trained on AudioSet (for AudioLDM).

    Wraps the ``mae_vit_base_patch16`` model from ``models_mae`` in eval mode
    and exposes its unmasked encoder as a gradient-free feature extractor.

    Args:
        checkpoint_path: File passed to ``torch.load``. Defaults to the
            location that used to be hard-coded, so ``Vanilla_AudioMAE()``
            behaves exactly as before.
    """

    def __init__(
        self,
        checkpoint_path="data/checkpoints/audiomae_16k_128bins.ckpt",
    ):
        super().__init__()
        model = models_mae.__dict__["mae_vit_base_patch16"](
            in_chans=1, audio_exp=True, img_size=(1024, 128)
        )

        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        # strict=False skips the missing keys of decoder modules (not required).
        model.load_state_dict(checkpoint["model"], strict=False)

        self.model = model.eval()

    def forward(self, x, mask_ratio=0.0, no_mask=False, no_average=False):
        """
        x: mel fbank [Batch, 1, 1024 (T), 128 (F)]
        mask_ratio: 'masking ratio (percentage of removed patches).' Kept for
            interface compatibility; the only supported path ignores it.
        no_mask: must be True; the randomly masked encoder path is deprecated.
        no_average: must be False; the no-average path is deprecated.

        Returns the output of ``self.model.forward_encoder_no_mask(x)``,
        computed without gradient tracking.

        Raises:
            RuntimeError: if ``no_mask`` is False or ``no_average`` is True.
        """
        # Every combination except (no_mask=True, no_average=False) is
        # deprecated; fail before doing any work.
        if not no_mask or no_average:
            raise RuntimeError("This function is deprecated")

        with torch.no_grad():
            # Original note: embed is [B, 513, 768] for mask_ratio=0.0
            return self.model.forward_encoder_no_mask(x)
if __name__ == "__main__":
    # Smoke test: run the supported (unmasked) encoder path twice on the same
    # random batch and print both embeddings.
    # Use the GPU when there is one instead of failing on CPU-only machines.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Vanilla_AudioMAE().to(device)
    fbank = torch.randn(4, 1, 1024, 128).to(device)

    print("The first run")
    embed = model(fbank, mask_ratio=0.0, no_mask=True)
    print(embed)

    print("The second run")
    # no_mask=True is required: the masked path raises
    # RuntimeError("This function is deprecated").
    embed = model(fbank, mask_ratio=0.0, no_mask=True)
    print(embed)