import torch
from torch import nn
from transformers import DistilBertModel, DistilBertTokenizer


class MultilabelClassifier(nn.Module):
    """Multilabel classifier: a DistilBERT backbone plus a linear head.

    The backbone is loaded with ``DistilBertModel.from_pretrained``; the
    head is dropout followed by one linear layer and an element-wise
    sigmoid, so ``forward`` yields one independent score per label.

    Args:
        model_name: Identifier or path handed to
            ``DistilBertModel.from_pretrained``.
        num_labels: Number of output labels (width of the linear head).
        dropout: Dropout probability applied to the pooled representation
            before the linear head.
    """

    def __init__(self, model_name: str, num_labels: int, dropout: float = 0.1) -> None:
        super().__init__()
        self.backbone = DistilBertModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        # Size the head from the loaded checkpoint instead of assuming 768,
        # so checkpoints with a different hidden width do not fail with a
        # shape mismatch at the first forward pass.
        hidden_size = self.backbone.config.hidden_size
        self.classifier = nn.Linear(hidden_size, num_labels)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Return per-label sigmoid scores for a batch of encoded inputs.

        Args:
            input_ids: Token ids forwarded to the backbone.
            attention_mask: Attention mask forwarded to the backbone.

        Returns:
            Sigmoid-activated outputs of the linear head, with
            ``num_labels`` values along the last dimension. These are
            already squashed into (0, 1), not raw logits, so a loss that
            applies its own sigmoid must not be stacked on top.
        """
        outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        # Use the hidden state at sequence position 0 as the pooled
        # representation (presumably the [CLS] token -- depends on the
        # tokenizer used by the caller).
        pooled_output = outputs.last_hidden_state[:, 0]
        x = self.dropout(pooled_output)
        logits = self.classifier(x)
        return self.sigmoid(logits)