Spaces:
Build error
Build error
| import numpy as np | |
| import torch.nn.functional as F | |
| from torch import nn | |
| from .model import MLPLayers | |
class LinearProbe(nn.Module):
    """Classification probe (linear or MLP) on top of a CLAP audio encoder.

    The CLAP model's ``audio_branch`` followed by ``audio_projection`` produces an
    embedding. That embedding is fed through ``lp_layer`` and then through an
    optional output activation.
    """

    def __init__(self, model, mlp, freeze, in_ch, out_ch, act=None):
        """
        Args:
            model: nn.Module, the CLAP model. Its ``text_branch`` is set to None
                to save memory. ``audio_branch`` and ``audio_projection`` are
                used in ``forward``.
            mlp: bool. If True, use an MLP (in_ch -> 2 * in_ch -> out_ch) as the
                probe module. Otherwise use a single nn.Linear(in_ch, out_ch).
            freeze: bool. If True, freeze all of the CLAP model's parameters
                while training the probe.
            in_ch: int, the dimension of the embedding produced by
                ``model.audio_projection`` (e.g. 512).
            out_ch: int, the number of output channels of the probe (the
                number of classes).
            act: str or None, the activation applied to the probe output before
                the loss function. One of None / "None", "relu", "elu",
                "prelu", "softmax", "sigmoid".

        Raises:
            ValueError: if ``act`` is not one of the supported values.
        """
        super().__init__()
        self.clap_model = model
        self.clap_model.text_branch = None  # to save memory
        self.freeze = freeze
        if mlp:
            self.lp_layer = MLPLayers(units=[in_ch, in_ch * 2, out_ch])
        else:
            self.lp_layer = nn.Linear(in_ch, out_ch)

        if self.freeze:
            for param in self.clap_model.parameters():
                param.requires_grad = False

        # The activation acts on the probe output, so any per-channel
        # parameters (PReLU) are sized by out_ch, not in_ch.
        self.act = self._build_activation(act, out_ch)

    @staticmethod
    def _build_activation(act, num_channels):
        """Return the activation module named by *act*, or None for no activation."""
        # The string "None" is accepted alongside Python None, presumably
        # because the value may come from a string-valued config/CLI option.
        if act is None or act == "None":
            return None
        if act == "relu":
            return nn.ReLU()
        if act == "elu":
            return nn.ELU()
        if act == "prelu":
            return nn.PReLU(num_parameters=num_channels)
        if act == "softmax":
            return nn.Softmax(dim=-1)
        if act == "sigmoid":
            return nn.Sigmoid()
        raise ValueError(
            f"Unsupported activation {act!r}; expected one of "
            "None, 'None', 'relu', 'elu', 'prelu', 'softmax', 'sigmoid'"
        )

    def forward(self, x, mix_lambda=None, device=None):
        """
        Args:
            x: input to ``clap_model.audio_branch``. It is a waveform tensor
                [batch, t_samples] or a batch of mel_spec and longer list; the
                exact format depends on the audio branch.
            mix_lambda: torch.tensor [batch], the mixup lambda. It is passed
                to the audio branch as ``mixup_lambda``.
            device: passed to the audio branch unchanged.

        Returns:
            torch.tensor [batch, out_ch]: the probe output, after ``self.act``
            if an activation was configured.
        """
        if self.freeze:
            # Re-apply eval mode on every call. Otherwise a later ``.train()``
            # on the probe would switch the frozen encoder's modules (e.g.
            # batchnorm) back to training behaviour.
            self.clap_model.eval()
        embedding = self.clap_model.audio_branch(
            x, mixup_lambda=mix_lambda, device=device
        )["embedding"]
        x = self.clap_model.audio_projection(embedding)
        out = self.lp_layer(x)
        if self.act is not None:
            out = self.act(out)
        return out