Spaces:
Build error
Build error
| # PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition | |
| # Reference from https://github.com/qiuqiangkong/audioset_tagging_cnn | |
| # Some layers are re-designed for CLAP | |
| import os | |
| os.environ["NUMBA_CACHE_DIR"] = "/tmp/" | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torchlibrosa.stft import Spectrogram, LogmelFilterBank | |
| from torchlibrosa.augmentation import SpecAugmentation | |
| from .utils import do_mixup, interpolate, pad_framewise_output | |
| from .feature_fusion import iAFF, AFF, DAF | |
def init_layer(layer):
    """Initialize a Linear or Convolutional layer.

    The weight is filled with Xavier-uniform values; the bias, when the layer
    has one, is zeroed in place.
    """
    nn.init.xavier_uniform_(layer.weight)

    # Layers built with ``bias=False`` expose ``bias`` as None.
    bias = getattr(layer, "bias", None)
    if bias is not None:
        bias.data.fill_(0.0)
def init_bn(bn):
    """Initialize a Batchnorm layer.

    Resets the affine parameters to the identity transform: scale of one and
    shift of zero.
    """
    bn.weight.data.fill_(1.0)
    bn.bias.data.fill_(0.0)
class ConvBlock(nn.Module):
    """Two 3x3 conv -> batch-norm -> ReLU stages followed by 2-D pooling.

    Both convolutions use stride 1 and padding 1, so the spatial size is only
    changed by the pooling step in ``forward``.
    """

    # Pooling modes accepted by ``forward``.
    _POOL_TYPES = ("max", "avg", "avg+max")

    def __init__(self, in_channels, out_channels):
        """Build the block.

        Args:
            in_channels: number of channels of the input feature map.
            out_channels: number of channels produced by both convolutions.
        """
        super(ConvBlock, self).__init__()

        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=(1, 1),
            bias=False,
        )
        self.conv2 = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=(1, 1),
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.init_weight()

    def init_weight(self):
        """Apply the module-level initializers to the convs and batch norms."""
        init_layer(self.conv1)
        init_layer(self.conv2)
        init_bn(self.bn1)
        init_bn(self.bn2)

    def forward(self, input, pool_size=(2, 2), pool_type="avg"):
        """Run both conv stages, then pool.

        Args:
            input: 4-D tensor accepted by ``nn.Conv2d`` with ``in_channels``
                channels.
            pool_size: kernel size handed to the pooling function.
            pool_type: ``"max"``, ``"avg"`` or ``"avg+max"`` (sum of both).

        Returns:
            The pooled feature map.

        Raises:
            ValueError: if ``pool_type`` is not one of the supported modes.
        """
        # Reject a bad mode before spending time on the convolutions.
        if pool_type not in self._POOL_TYPES:
            raise ValueError(
                f"Incorrect argument! pool_type must be one of "
                f"{self._POOL_TYPES}, got {pool_type!r}"
            )

        x = F.relu_(self.bn1(self.conv1(input)))
        x = F.relu_(self.bn2(self.conv2(x)))

        if pool_type == "max":
            return F.max_pool2d(x, kernel_size=pool_size)
        if pool_type == "avg":
            return F.avg_pool2d(x, kernel_size=pool_size)
        # "avg+max": element-wise sum of the two pooled maps.
        return F.avg_pool2d(x, kernel_size=pool_size) + F.max_pool2d(
            x, kernel_size=pool_size
        )
class ConvBlock5x5(nn.Module):
    """Single 5x5 conv -> batch-norm -> ReLU stage followed by 2-D pooling.

    The convolution uses stride 1 and padding 2, so the spatial size is only
    changed by the pooling step in ``forward``.
    """

    # Pooling modes accepted by ``forward``.
    _POOL_TYPES = ("max", "avg", "avg+max")

    def __init__(self, in_channels, out_channels):
        """Build the block.

        Args:
            in_channels: number of channels of the input feature map.
            out_channels: number of channels produced by the convolution.
        """
        super(ConvBlock5x5, self).__init__()

        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(5, 5),
            stride=(1, 1),
            padding=(2, 2),
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)

        self.init_weight()

    def init_weight(self):
        """Apply the module-level initializers to the conv and batch norm."""
        init_layer(self.conv1)
        init_bn(self.bn1)

    def forward(self, input, pool_size=(2, 2), pool_type="avg"):
        """Run the conv stage, then pool.

        Args:
            input: 4-D tensor accepted by ``nn.Conv2d`` with ``in_channels``
                channels.
            pool_size: kernel size handed to the pooling function.
            pool_type: ``"max"``, ``"avg"`` or ``"avg+max"`` (sum of both).

        Returns:
            The pooled feature map.

        Raises:
            ValueError: if ``pool_type`` is not one of the supported modes.
        """
        # Reject a bad mode before spending time on the convolution.
        if pool_type not in self._POOL_TYPES:
            raise ValueError(
                f"Incorrect argument! pool_type must be one of "
                f"{self._POOL_TYPES}, got {pool_type!r}"
            )

        x = F.relu_(self.bn1(self.conv1(input)))

        if pool_type == "max":
            return F.max_pool2d(x, kernel_size=pool_size)
        if pool_type == "avg":
            return F.avg_pool2d(x, kernel_size=pool_size)
        # "avg+max": element-wise sum of the two pooled maps.
        return F.avg_pool2d(x, kernel_size=pool_size) + F.max_pool2d(
            x, kernel_size=pool_size
        )
class AttBlock(nn.Module):
    """Attention pooling over the last axis using two 1x1 ``Conv1d`` heads.

    ``att`` produces attention logits and ``cla`` produces per-step outputs;
    ``forward`` returns their attention-weighted sum over the last axis.
    """

    def __init__(self, n_in, n_out, activation="linear", temperature=1.0):
        """Build the block.

        Args:
            n_in: number of input channels.
            n_out: number of output channels of both heads.
            activation: ``"linear"`` or ``"sigmoid"``; applied to the ``cla``
                head output.
            temperature: stored on the module; not read by ``forward``.
        """
        super(AttBlock, self).__init__()

        self.activation = activation
        self.temperature = temperature
        self.att = nn.Conv1d(
            in_channels=n_in,
            out_channels=n_out,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )
        self.cla = nn.Conv1d(
            in_channels=n_in,
            out_channels=n_out,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )

        # Registered (and initialized) but not applied in ``forward``; kept so
        # the module's parameter set stays unchanged.
        self.bn_att = nn.BatchNorm1d(n_out)
        self.init_weights()

    def init_weights(self):
        """Apply the module-level initializers to both heads and the batch norm."""
        init_layer(self.att)
        init_layer(self.cla)
        init_bn(self.bn_att)

    def forward(self, x):
        """Pool ``x`` over its last axis with learned attention weights.

        Args:
            x: tensor of shape (n_samples, n_in, n_time).

        Returns:
            Tuple ``(pooled, norm_att, cla)``: the attention-weighted sum over
            the last axis, the softmax-normalized attention weights, and the
            activated ``cla`` head output.
        """
        # Logits are clamped to [-10, 10] before the softmax over the last axis.
        norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1)
        cla = self.nonlinear_transform(self.cla(x))
        x = torch.sum(norm_att * cla, dim=2)
        return x, norm_att, cla

    def nonlinear_transform(self, x):
        """Apply the configured activation to ``x``.

        Raises:
            ValueError: if ``self.activation`` is neither ``"linear"`` nor
                ``"sigmoid"``. Previously this case fell through and returned
                None, which only failed later inside ``forward``.
        """
        if self.activation == "linear":
            return x
        if self.activation == "sigmoid":
            return torch.sigmoid(x)
        raise ValueError(
            f"Unsupported activation {self.activation!r}; "
            "expected 'linear' or 'sigmoid'"
        )
class Cnn14(nn.Module):
    """Six-``ConvBlock`` CNN audio encoder with optional feature fusion.

    Without fusion the model computes a log-mel spectrogram from
    ``input["waveform"]``. With fusion it consumes a precomputed
    ``input["mel_fusion"]`` tensor whose channel 0 is treated as the global
    view and whose remaining channels are treated as local views, merged
    according to ``fusion_type``.
    """

    def __init__(
        self,
        sample_rate,
        window_size,
        hop_size,
        mel_bins,
        fmin,
        fmax,
        classes_num,
        enable_fusion=False,
        fusion_type="None",
    ):
        """Build the encoder.

        Args:
            sample_rate: passed to ``LogmelFilterBank`` as ``sr``.
            window_size: used as both ``n_fft`` and ``win_length``.
            hop_size: STFT hop length.
            mel_bins: number of mel filters (``n_mels``).
            fmin: lower mel filterbank frequency.
            fmax: upper mel filterbank frequency.
            classes_num: output size of the ``fc_audioset`` classifier.
            enable_fusion: whether ``forward`` uses the fusion input path.
            fusion_type: one of ``"channel_map"``, ``"daf_1d"``, ``"aff_1d"``,
                ``"iaff_1d"``, ``"daf_2d"``, ``"aff_2d"``, ``"iaff_2d"``; only
                consulted when ``enable_fusion`` is true.
        """
        super(Cnn14, self).__init__()

        window = "hann"
        center = True
        pad_mode = "reflect"
        ref = 1.0
        amin = 1e-10
        top_db = None

        self.enable_fusion = enable_fusion
        self.fusion_type = fusion_type

        # Spectrogram extractor
        self.spectrogram_extractor = Spectrogram(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True,
        )

        # Logmel feature extractor
        self.logmel_extractor = LogmelFilterBank(
            sr=sample_rate,
            n_fft=window_size,
            n_mels=mel_bins,
            fmin=fmin,
            fmax=fmax,
            ref=ref,
            amin=amin,
            top_db=top_db,
            freeze_parameters=True,
        )

        # Spec augmenter
        self.spec_augmenter = SpecAugmentation(
            time_drop_width=64,
            time_stripes_num=2,
            freq_drop_width=8,
            freq_stripes_num=2,
        )

        # Applied in ``forward`` after swapping axes 1 and 3, i.e. over the
        # last input axis. NOTE(review): the hard-coded 64 presumably has to
        # equal ``mel_bins`` — confirm before using another value.
        self.bn0 = nn.BatchNorm2d(64)

        if (self.enable_fusion) and (self.fusion_type == "channel_map"):
            # All four fusion channels are fed straight into the first block.
            self.conv_block1 = ConvBlock(in_channels=4, out_channels=64)
        else:
            self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
        self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
        self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)

        self.fc1 = nn.Linear(2048, 2048, bias=True)
        self.fc_audioset = nn.Linear(2048, classes_num, bias=True)

        if (self.enable_fusion) and (
            self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]
        ):
            self.mel_conv1d = nn.Sequential(
                nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2),
                nn.BatchNorm1d(64),  # No Relu
            )
            if self.fusion_type == "daf_1d":
                self.fusion_model = DAF()
            elif self.fusion_type == "aff_1d":
                self.fusion_model = AFF(channels=64, type="1D")
            elif self.fusion_type == "iaff_1d":
                self.fusion_model = iAFF(channels=64, type="1D")

        if (self.enable_fusion) and (
            self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
        ):
            self.mel_conv2d = nn.Sequential(
                nn.Conv2d(1, 64, kernel_size=(5, 5), stride=(6, 2), padding=(2, 2)),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
            )

            if self.fusion_type == "daf_2d":
                self.fusion_model = DAF()
            elif self.fusion_type == "aff_2d":
                self.fusion_model = AFF(channels=64, type="2D")
            elif self.fusion_type == "iaff_2d":
                self.fusion_model = iAFF(channels=64, type="2D")

        self.init_weight()

    def init_weight(self):
        """Initialize the input batch norm and both fully connected layers."""
        init_bn(self.bn0)
        init_layer(self.fc1)
        init_layer(self.fc_audioset)

    def forward(self, input, mixup_lambda=None, device=None):
        """Encode a batch of audio.

        Args:
            input: dict. Without fusion it must hold ``"waveform"``
                (presumably (batch_size, data_length) — confirm with the data
                loader). With fusion it must hold ``"longer"`` (boolean mask
                per sample) and ``"mel_fusion"``.
            mixup_lambda: optional mixup weights, used only in training mode.
            device: device the input tensors are moved to; ``None`` leaves
                them where they are.

        Returns:
            dict with ``"clipwise_output"`` (sigmoid of ``fc_audioset``),
            ``"embedding"`` (``fc1`` output after dropout) and
            ``"fine_grained_embedding"`` (frame-level ``fc1`` output passed
            through ``interpolate``).

        Note:
            With fusion enabled and no sample flagged in ``input["longer"]``,
            one randomly chosen entry of that tensor is set to True IN PLACE.
        """
        if self.enable_fusion and input["longer"].sum() == 0:
            # if no audio is longer than 10s, then randomly select one audio to be longer
            input["longer"][torch.randint(0, input["longer"].shape[0], (1,))] = True

        if not self.enable_fusion:
            x = self.spectrogram_extractor(
                input["waveform"].to(device=device, non_blocking=True)
            )  # (batch_size, 1, time_steps, freq_bins)
            x = self.logmel_extractor(x)  # (batch_size, 1, time_steps, mel_bins)

            # Swap axes so bn0 normalizes over the last (mel) axis, then swap back.
            x = x.transpose(1, 3)
            x = self.bn0(x)
            x = x.transpose(1, 3)
        else:
            longer_list = input["longer"].to(device=device, non_blocking=True)
            x = input["mel_fusion"].to(device=device, non_blocking=True)
            longer_list_idx = torch.where(longer_list)[0]

            x = x.transpose(1, 3)
            x = self.bn0(x)
            x = x.transpose(1, 3)

            if self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]:
                # Channel 0 is the global view; channels 1.. are local views.
                new_x = x[:, 0:1, :, :].clone().contiguous()
                # local processing
                if len(longer_list_idx) > 0:
                    fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous()
                    FB, FC, FT, FF = fusion_x_local.size()
                    fusion_x_local = fusion_x_local.view(FB * FC, FT, FF)
                    fusion_x_local = torch.permute(
                        fusion_x_local, (0, 2, 1)
                    ).contiguous()
                    fusion_x_local = self.mel_conv1d(fusion_x_local)
                    fusion_x_local = fusion_x_local.view(
                        FB, FC, FF, fusion_x_local.size(-1)
                    )
                    fusion_x_local = (
                        torch.permute(fusion_x_local, (0, 2, 1, 3))
                        .contiguous()
                        .flatten(2)
                    )

                    # Pad or crop the concatenated local views to FT steps.
                    pad_len = FT - fusion_x_local.size(-1)
                    if pad_len > 0:
                        # new_zeros matches the features' device and dtype.
                        # The previous ``device=device`` put the padding on
                        # the CPU whenever ``device`` was left as None.
                        fusion_x_local = torch.cat(
                            [
                                fusion_x_local,
                                fusion_x_local.new_zeros((FB, FF, pad_len)),
                            ],
                            dim=-1,
                        )
                    else:
                        fusion_x_local = fusion_x_local[:, :, :FT]

                    # 1D fusion
                    new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous()
                    new_x[longer_list_idx] = self.fusion_model(
                        new_x[longer_list_idx], fusion_x_local
                    )
                    x = new_x.permute((0, 2, 1)).contiguous()[:, None, :, :]
                else:
                    x = new_x
            # "daf_2d" / "aff_2d" / "iaff_2d" / "channel_map": x is used as is
            # here; the 2-D variants are fused after the first conv block.

        if self.training:
            x = self.spec_augmenter(x)

        # Mixup on spectrogram
        if self.training and mixup_lambda is not None:
            x = do_mixup(x, mixup_lambda)

        if (self.enable_fusion) and (
            self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
        ):
            # global processing
            global_x = x[:, 0:1, :, :]
            global_x = self.conv_block1(global_x, pool_size=(2, 2), pool_type="avg")

            if len(longer_list_idx) > 0:
                local_x = x[longer_list_idx, 1:, :, :].contiguous()
                TH = global_x.size(-2)

                # local processing
                B, C, H, W = local_x.shape
                local_x = local_x.view(B * C, 1, H, W)
                local_x = self.mel_conv2d(local_x)
                local_x = local_x.view(
                    B, C, local_x.size(1), local_x.size(2), local_x.size(3)
                )
                local_x = local_x.permute((0, 2, 1, 3, 4)).contiguous().flatten(2, 3)
                TB, TC, _, TW = local_x.size()

                # Pad or crop the local map to the global map's height.
                pad_len = TH - local_x.size(-2)
                if pad_len > 0:
                    local_x = torch.cat(
                        [local_x, local_x.new_zeros((TB, TC, pad_len, TW))],
                        dim=-2,
                    )
                else:
                    local_x = local_x[:, :, :TH, :]

                global_x[longer_list_idx] = self.fusion_model(
                    global_x[longer_list_idx], local_x
                )
            x = global_x
        else:
            x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")

        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)

        # Collapse the last axis; x is 3-D from here on.
        x = torch.mean(x, dim=3)

        # Frame-level branch: smooth along axis 2 with max + avg pooling, then
        # project every frame through fc1 (shared with the clip-level branch).
        latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
        latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
        latent_x = latent_x1 + latent_x2
        latent_x = latent_x.transpose(1, 2)
        latent_x = F.relu_(self.fc1(latent_x))
        # NOTE(review): 32 presumably undoes the five 2x poolings along time —
        # confirm against ``interpolate`` in utils.
        latent_output = interpolate(latent_x, 32)

        # Clip-level branch: max + mean over axis 2.
        (x1, _) = torch.max(x, dim=2)
        x2 = torch.mean(x, dim=2)
        x = x1 + x2
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu_(self.fc1(x))
        embedding = F.dropout(x, p=0.5, training=self.training)
        # The classifier reads the pre-dropout fc1 output.
        clipwise_output = torch.sigmoid(self.fc_audioset(x))

        output_dict = {
            "clipwise_output": clipwise_output,
            "embedding": embedding,
            "fine_grained_embedding": latent_output,
        }
        return output_dict
class Cnn6(nn.Module):
    """Four-``ConvBlock5x5`` CNN audio encoder operating on raw waveforms.

    ``enable_fusion`` and ``fusion_type`` are stored but not used by
    ``forward``.
    """

    def __init__(
        self,
        sample_rate,
        window_size,
        hop_size,
        mel_bins,
        fmin,
        fmax,
        classes_num,
        enable_fusion=False,
        fusion_type="None",
    ):
        """Build the encoder.

        Args:
            sample_rate: passed to ``LogmelFilterBank`` as ``sr``.
            window_size: used as both ``n_fft`` and ``win_length``.
            hop_size: STFT hop length.
            mel_bins: number of mel filters (``n_mels``).
            fmin: lower mel filterbank frequency.
            fmax: upper mel filterbank frequency.
            classes_num: output size of the ``fc_audioset`` classifier.
            enable_fusion: stored on the module only.
            fusion_type: stored on the module only.
        """
        super(Cnn6, self).__init__()

        window = "hann"
        center = True
        pad_mode = "reflect"
        ref = 1.0
        amin = 1e-10
        top_db = None

        self.enable_fusion = enable_fusion
        self.fusion_type = fusion_type

        # Spectrogram extractor
        self.spectrogram_extractor = Spectrogram(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True,
        )

        # Logmel feature extractor
        self.logmel_extractor = LogmelFilterBank(
            sr=sample_rate,
            n_fft=window_size,
            n_mels=mel_bins,
            fmin=fmin,
            fmax=fmax,
            ref=ref,
            amin=amin,
            top_db=top_db,
            freeze_parameters=True,
        )

        # Spec augmenter
        self.spec_augmenter = SpecAugmentation(
            time_drop_width=64,
            time_stripes_num=2,
            freq_drop_width=8,
            freq_stripes_num=2,
        )

        # Applied in ``forward`` after swapping axes 1 and 3, i.e. over the
        # mel axis. NOTE(review): the hard-coded 64 presumably has to equal
        # ``mel_bins`` — confirm before using another value.
        self.bn0 = nn.BatchNorm2d(64)

        self.conv_block1 = ConvBlock5x5(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock5x5(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock5x5(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock5x5(in_channels=256, out_channels=512)

        self.fc1 = nn.Linear(512, 512, bias=True)
        self.fc_audioset = nn.Linear(512, classes_num, bias=True)

        self.init_weight()

    def init_weight(self):
        """Initialize the input batch norm and both fully connected layers."""
        init_bn(self.bn0)
        init_layer(self.fc1)
        init_layer(self.fc_audioset)

    def forward(self, input, mixup_lambda=None, device=None):
        """Encode a batch of audio.

        Args:
            input: waveform tensor, presumably (batch_size, data_length), or
                a dict holding that tensor under ``"waveform"`` (the batch
                format ``Cnn14.forward`` reads).
            mixup_lambda: optional mixup weights, used only in training mode.
            device: if given, the waveform is moved to this device first.

        Returns:
            dict with ``"clipwise_output"`` (sigmoid of ``fc_audioset``),
            ``"embedding"`` (``fc1`` output after dropout) and
            ``"fine_grained_embedding"`` (frame-level ``fc1`` output passed
            through ``interpolate``).
        """
        # Accept the dict batch used by Cnn14 as well as a bare tensor, so all
        # models built by the same factory can be called the same way.
        if isinstance(input, dict):
            input = input["waveform"]
        if device is not None:
            input = input.to(device=device, non_blocking=True)

        x = self.spectrogram_extractor(input)  # (batch_size, 1, time_steps, freq_bins)
        x = self.logmel_extractor(x)  # (batch_size, 1, time_steps, mel_bins)

        # Swap axes so bn0 normalizes over the mel axis, then swap back.
        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)

        if self.training:
            x = self.spec_augmenter(x)

        # Mixup on spectrogram
        if self.training and mixup_lambda is not None:
            x = do_mixup(x, mixup_lambda)

        x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)

        # Collapse the mel axis: (batch, channels, time).
        x = torch.mean(x, dim=3)

        # Frame-level branch: smooth along time with max + avg pooling, then
        # project every frame through fc1 (shared with the clip-level branch).
        latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
        latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
        latent_x = latent_x1 + latent_x2
        latent_x = latent_x.transpose(1, 2)
        latent_x = F.relu_(self.fc1(latent_x))
        # NOTE(review): 16 presumably undoes the four 2x poolings along time —
        # confirm against ``interpolate`` in utils.
        latent_output = interpolate(latent_x, 16)

        # Clip-level branch: max + mean over time.
        (x1, _) = torch.max(x, dim=2)
        x2 = torch.mean(x, dim=2)
        x = x1 + x2
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu_(self.fc1(x))
        embedding = F.dropout(x, p=0.5, training=self.training)
        # The classifier reads the pre-dropout fc1 output.
        clipwise_output = torch.sigmoid(self.fc_audioset(x))

        output_dict = {
            "clipwise_output": clipwise_output,
            "embedding": embedding,
            "fine_grained_embedding": latent_output,
        }
        return output_dict
class Cnn10(nn.Module):
    """Five-``ConvBlock`` CNN audio encoder operating on raw waveforms.

    ``enable_fusion`` and ``fusion_type`` are stored but not used by
    ``forward``.
    """

    def __init__(
        self,
        sample_rate,
        window_size,
        hop_size,
        mel_bins,
        fmin,
        fmax,
        classes_num,
        enable_fusion=False,
        fusion_type="None",
    ):
        """Build the encoder.

        Args:
            sample_rate: passed to ``LogmelFilterBank`` as ``sr``.
            window_size: used as both ``n_fft`` and ``win_length``.
            hop_size: STFT hop length.
            mel_bins: number of mel filters (``n_mels``).
            fmin: lower mel filterbank frequency.
            fmax: upper mel filterbank frequency.
            classes_num: output size of the ``fc_audioset`` classifier.
            enable_fusion: stored on the module only.
            fusion_type: stored on the module only.
        """
        super(Cnn10, self).__init__()

        window = "hann"
        center = True
        pad_mode = "reflect"
        ref = 1.0
        amin = 1e-10
        top_db = None

        self.enable_fusion = enable_fusion
        self.fusion_type = fusion_type

        # Spectrogram extractor
        self.spectrogram_extractor = Spectrogram(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True,
        )

        # Logmel feature extractor
        self.logmel_extractor = LogmelFilterBank(
            sr=sample_rate,
            n_fft=window_size,
            n_mels=mel_bins,
            fmin=fmin,
            fmax=fmax,
            ref=ref,
            amin=amin,
            top_db=top_db,
            freeze_parameters=True,
        )

        # Spec augmenter
        self.spec_augmenter = SpecAugmentation(
            time_drop_width=64,
            time_stripes_num=2,
            freq_drop_width=8,
            freq_stripes_num=2,
        )

        # Applied in ``forward`` after swapping axes 1 and 3, i.e. over the
        # mel axis. NOTE(review): the hard-coded 64 presumably has to equal
        # ``mel_bins`` — confirm before using another value.
        self.bn0 = nn.BatchNorm2d(64)

        self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
        self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)

        self.fc1 = nn.Linear(1024, 1024, bias=True)
        self.fc_audioset = nn.Linear(1024, classes_num, bias=True)

        self.init_weight()

    def init_weight(self):
        """Initialize the input batch norm and both fully connected layers."""
        init_bn(self.bn0)
        init_layer(self.fc1)
        init_layer(self.fc_audioset)

    def forward(self, input, mixup_lambda=None, device=None):
        """Encode a batch of audio.

        Args:
            input: waveform tensor, presumably (batch_size, data_length), or
                a dict holding that tensor under ``"waveform"`` (the batch
                format ``Cnn14.forward`` reads).
            mixup_lambda: optional mixup weights, used only in training mode.
            device: if given, the waveform is moved to this device first.

        Returns:
            dict with ``"clipwise_output"`` (sigmoid of ``fc_audioset``),
            ``"embedding"`` (``fc1`` output after dropout) and
            ``"fine_grained_embedding"`` (frame-level ``fc1`` output passed
            through ``interpolate``).
        """
        # Accept the dict batch used by Cnn14 as well as a bare tensor, so all
        # models built by the same factory can be called the same way.
        if isinstance(input, dict):
            input = input["waveform"]
        if device is not None:
            input = input.to(device=device, non_blocking=True)

        x = self.spectrogram_extractor(input)  # (batch_size, 1, time_steps, freq_bins)
        x = self.logmel_extractor(x)  # (batch_size, 1, time_steps, mel_bins)

        # Swap axes so bn0 normalizes over the mel axis, then swap back.
        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)

        if self.training:
            x = self.spec_augmenter(x)

        # Mixup on spectrogram
        if self.training and mixup_lambda is not None:
            x = do_mixup(x, mixup_lambda)

        x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)

        # Collapse the mel axis: (batch, channels, time).
        x = torch.mean(x, dim=3)

        # Frame-level branch: smooth along time with max + avg pooling, then
        # project every frame through fc1 (shared with the clip-level branch).
        latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
        latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
        latent_x = latent_x1 + latent_x2
        latent_x = latent_x.transpose(1, 2)
        latent_x = F.relu_(self.fc1(latent_x))
        # NOTE(review): 32 presumably undoes the five 2x poolings along time —
        # confirm against ``interpolate`` in utils.
        latent_output = interpolate(latent_x, 32)

        # Clip-level branch: max + mean over time.
        (x1, _) = torch.max(x, dim=2)
        x2 = torch.mean(x, dim=2)
        x = x1 + x2
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu_(self.fc1(x))
        embedding = F.dropout(x, p=0.5, training=self.training)
        # The classifier reads the pre-dropout fc1 output.
        clipwise_output = torch.sigmoid(self.fc_audioset(x))

        output_dict = {
            "clipwise_output": clipwise_output,
            "embedding": embedding,
            "fine_grained_embedding": latent_output,
        }
        return output_dict
def create_pann_model(audio_cfg, enable_fusion=False, fusion_type="None"):
    """Instantiate the model class named by ``audio_cfg.model_name``.

    Args:
        audio_cfg: object exposing ``model_name``, ``sample_rate``,
            ``window_size``, ``hop_size``, ``mel_bins``, ``fmin``, ``fmax``
            and ``class_num`` attributes.
        enable_fusion: forwarded to the model constructor.
        fusion_type: forwarded to the model constructor.

    Returns:
        The constructed model.

    Raises:
        RuntimeError: if ``model_name`` does not name an ``nn.Module``
            subclass defined in this module, or if constructing it fails
            (for example because ``audio_cfg`` lacks a required attribute).
            The underlying error is chained as ``__cause__``.
    """
    model_name = getattr(audio_cfg, "model_name", None)
    try:
        # SECURITY: this used to be ``eval(audio_cfg.model_name)``, which runs
        # arbitrary expressions from the config. A plain name lookup in this
        # module's namespace resolves the same model classes without
        # executing anything.
        model_cls = globals().get(model_name)
        if not (isinstance(model_cls, type) and issubclass(model_cls, nn.Module)):
            raise LookupError(f"{model_name!r} is not a model class in this module")

        return model_cls(
            sample_rate=audio_cfg.sample_rate,
            window_size=audio_cfg.window_size,
            hop_size=audio_cfg.hop_size,
            mel_bins=audio_cfg.mel_bins,
            fmin=audio_cfg.fmin,
            fmax=audio_cfg.fmax,
            classes_num=audio_cfg.class_num,
            enable_fusion=enable_fusion,
            fusion_type=fusion_type,
        )
    except Exception as err:
        # ``Exception`` (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate; ``from err`` keeps the real cause visible.
        raise RuntimeError(
            f"Import Model for {model_name} not found, or the audio cfg parameters are not enough."
        ) from err