import torch
import torch.nn as nn
from transformers import PreTrainedModel
from typing import List

from .config import LidirlCNNConfig
# MCSoftmaxDenseFA (the noise-aware Monte Carlo softmax projection) is assumed
# to be provided by a sibling module of this package; it is referenced below
# but was not imported in the original source.
from .layers import MCSoftmaxDenseFA


def torch_max_no_pads(model_out, lengths):
    """Max-pool over the time dimension while ignoring padded positions."""
    indices = torch.arange(model_out.size(1), device=model_out.device)
    mask = (indices < lengths.view(-1, 1)).unsqueeze(-1).expand(model_out.size())
    # Fill padded positions with a large negative value so they never win the max.
    model_out = model_out.masked_fill(~mask, -1e9)
    max_pool = torch.max(model_out, 1)[0]
    return max_pool


class TransposeModule(nn.Module):
    """Swaps the time and channel dimensions so Conv1d layers can be used
    inside an nn.Sequential operating on (batch, time, channels) tensors."""

    def forward(self, x):
        return x.transpose(1, 2)


class ProjectionLayer(nn.Module):
    """
    Noise-aware labels layer or traditional linear projection
    """
    def __init__(self, hidden_dim, label_size, montecarlo_layer=False):
        super().__init__()
        self.montecarlo_layer = montecarlo_layer
        if montecarlo_layer:
            self.proj = MCSoftmaxDenseFA(hidden_dim, label_size, 1, logits_only=True)
        else:
            self.proj = nn.Linear(hidden_dim, label_size)

    def forward(self, x):
        return self.proj(x)


class ConvolutionalBlock(nn.Module):
    """
    Convolutional block
    https://jonathanbgn.com/2021/09/30/illustrated-wav2vec-2.html
    """
    def __init__(self, embed_dim: int, channels: List[int], kernels: List[int], strides: List[int]):
        super().__init__()
        layers = []
        self.embed_dim = embed_dim
        input_dimension = embed_dim
        for channel, kernel, stride in zip(channels, kernels, strides):
            next_layer = nn.Conv1d(
                in_channels=input_dimension,
                out_channels=channel,
                kernel_size=kernel,
                stride=stride,
                padding='valid',  # we handle the padding ourselves in the forward function
            )
            input_dimension = channel
            layers.append(TransposeModule())
            layers.append(next_layer)
            layers.append(TransposeModule())
            layers.append(nn.LayerNorm(channel, elementwise_affine=True))
            layers.append(nn.GELU())
            layers.append(nn.Dropout(0.1))
        self.model = nn.Sequential(*layers)
        self.output_dim = channels[-1]

        # Walk backwards through the conv stack to find the shortest input that
        # still yields at least one output frame from every layer.
        self.min_length = 1
        for kernel, stride in zip(kernels[::-1], strides[::-1]):
            self.min_length = ((self.min_length - 1) * stride) + kernel

    def forward(self, inputs, lengths):
        # Pad short batches up to min_length instead of padding every input.
        if inputs.size(1) < self.min_length:
            pads = torch.zeros(
                (inputs.size(0), self.min_length - inputs.size(1), self.embed_dim),
                device=inputs.device,
            )
            inputs = torch.cat((inputs, pads), dim=1)

        outputs = self.model(inputs)

        # Each conv block contributes 6 modules and its Conv1d sits at offset 1,
        # so step through the Sequential in strides of 6 to update the lengths.
        for layer_i in range(1, len(self.model), 6):
            lengths = torch.floor(
                ((lengths - self.model[layer_i].kernel_size[0]) / self.model[layer_i].stride[0]) + 1
            ).to(lengths.device, dtype=torch.long)
            lengths[lengths < 1] = 1

        return outputs, lengths


class LidirlCNN(PreTrainedModel):
    """
    Defines the Lidirl CNN model
    """
    config_class = LidirlCNNConfig

    def __init__(self, config):
        super().__init__(config)
        self.encoder = ConvolutionalBlock(
            config.embed_dim, config.channels, config.kernels, config.strides
        )
        self.embed_layer = nn.Embedding(config.vocab_size, config.embed_dim)
        self.proj = ProjectionLayer(
            self.encoder.output_dim, config.label_size, config.montecarlo_layer
        )

        self.label_size = config.label_size
        self.max_length = config.max_length
        self.multilabel = config.multilabel
        self.monte_carlo = config.montecarlo_layer

        # Invert the config's label -> index mapping into an index -> label list.
        self.labels = ["" for _ in config.labels]
        for key, value in config.labels.items():
            self.labels[value] = key

    def forward(self, inputs, lengths):
        inputs = inputs[:, :self.max_length]
        lengths = lengths.clamp(max=self.max_length)
        embeddings = self.embed_layer(inputs)
        encoding, lengths = self.encoder(embeddings, lengths=lengths)
        max_pool = torch_max_no_pads(encoding, lengths)
        projection = self.proj(max_pool)
        return projection

    def __call__(self, inputs, lengths):
        # This model is inference-only, so skip gradient tracking.
        with torch.no_grad():
            logits = self.forward(inputs, lengths)
            if self.multilabel:
                probs = torch.sigmoid(logits)
            else:
                probs = torch.softmax(logits, dim=-1)
            return probs

    def predict(self, inputs, lengths, threshold=0.5, top_k=None):
        probs = self.__call__(inputs, lengths)
        if top_k is not None and top_k > 0:
            top_k_preds = torch.topk(probs, top_k, dim=1)
            pred_labels = []
            for pred, prob in zip(top_k_preds.indices, top_k_preds.values):
                pred_labels.append(
                    [(self.labels[p.item()], pr.item()) for (p, pr) in zip(pred, prob)]
                )
            return pred_labels
        if self.multilabel:
            batch_idx, label_idx = torch.where(probs > threshold)
            output = [[] for _ in range(len(inputs))]
            for batch, label in zip(batch_idx, label_idx):
                output[batch.item()].append(
                    (self.labels[label.item()], probs[batch, label].item())
                )
            return output
        # Single-label case: return the argmax label and its probability for
        # each item in the batch, matching the format of the other branches.
        top_preds = torch.max(probs, dim=1)
        return [
            [(self.labels[p.item()], pr.item())]
            for p, pr in zip(top_preds.indices, top_preds.values)
        ]
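

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module. It assumes
    # LidirlCNNConfig accepts the fields read in LidirlCNN.__init__ as keyword
    # arguments; the values below are illustrative, not canonical defaults.
    config = LidirlCNNConfig(
        embed_dim=32,
        channels=[64, 64],
        kernels=[3, 3],
        strides=[1, 1],
        vocab_size=100,
        label_size=4,
        max_length=128,
        multilabel=False,
        montecarlo_layer=False,
        labels={"eng": 0, "fra": 1, "deu": 2, "spa": 3},
    )
    model = LidirlCNN(config).eval()

    # A toy batch of two token-id sequences with their true (unpadded) lengths.
    inputs = torch.randint(0, config.vocab_size, (2, 16))
    lengths = torch.tensor([16, 9])

    probs = model(inputs, lengths)  # (2, label_size) class probabilities
    print(model.predict(inputs, lengths, top_k=2))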