import torch.nn as nn
import torch.nn.functional as F
import torch
from transformers import EsmModel, AutoModel, PreTrainedModel, AutoConfig
import evaluate
import numpy as np
from sklearn.metrics import accuracy_score, classification_report
import wandb

accuracy_metric = evaluate.load("accuracy")
precision_metric = evaluate.load("precision")
recall_metric = evaluate.load("recall")
f1_metric = evaluate.load("f1")


class CleavageSiteModel(nn.Module):
    """ESM encoder with a linear classification head for cleavage-site prediction."""

    def __init__(self, base_model, num_classes=75, class_weights=None):
        super().__init__()
        self.model = EsmModel.from_pretrained(base_model)
        self.classifier = nn.Linear(self.model.config.hidden_size, num_classes)

        if class_weights is not None:
            # class_weights maps class index -> weight; classes missing from the
            # dict keep a weight of 0 and therefore do not contribute to the loss.
            weight_tensor = torch.zeros(num_classes)
            for class_idx, weight in class_weights.items():
                weight_tensor[class_idx] = weight
            self.loss_fn = nn.CrossEntropyLoss(weight=weight_tensor)
        else:
            self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # Use the <cls> token embedding (position 0) as the sequence representation.
        cls_output = outputs.last_hidden_state[:, 0]
        logits = self.classifier(cls_output)

        # Return a dict with "loss" and "logits" (the format the Hugging Face Trainer expects).
        if labels is not None:
            loss = self.loss_fn(logits, labels)
            return {"loss": loss, "logits": logits}
        return {"logits": logits}


def compute_metrics(eval_pred):
    """Compute overall and per-class accuracy and log a classification report to W&B."""
    # eval_pred unpacks into the model's logits and the reference labels.
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=1)

    accuracy = accuracy_score(labels, predictions)

    # Log the full scikit-learn classification report as an HTML panel in W&B.
    report = classification_report(labels, predictions, digits=4)
    wandb.log({"classification_report": wandb.Html(report.replace('\n', '<br>'))})

    # Per-class accuracy, computed only for classes that appear in the labels.
    unique_classes = np.unique(labels)
    per_class_acc = {}
    for cls in unique_classes:
        class_mask = labels == cls
        per_class_acc[f"accuracy_class_{cls}"] = (predictions[class_mask] == labels[class_mask]).mean()

    wandb.log({"overall_accuracy": accuracy, **per_class_acc})

    return {"accuracy": accuracy, **per_class_acc}
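

# --- Hedged usage sketch (not part of the original script) ---------------------
# A minimal smoke test showing the expected class_weights format, the forward-pass
# output, and how compute_metrics consumes (logits, labels). The checkpoint name,
# sequences, labels, and weights below are illustrative assumptions only; in
# practice the model and compute_metrics would typically be handed to a Hugging
# Face Trainer.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    wandb.init(mode="disabled")  # keep the wandb.log calls in compute_metrics as no-ops

    checkpoint = "facebook/esm2_t6_8M_UR50D"  # assumed small public ESM-2 checkpoint
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    # class_weights maps class index -> weight; every class gets 1.0 here, with one
    # arbitrarily chosen class upweighted for illustration.
    class_weights = {i: 1.0 for i in range(75)}
    class_weights[3] = 2.5
    model = CleavageSiteModel(checkpoint, num_classes=75, class_weights=class_weights)

    batch = tokenizer(["MKTAYIAKQR", "MGSSHHHHHHGT"], return_tensors="pt", padding=True)
    labels = torch.tensor([3, 7])  # dummy cleavage-site labels

    model.eval()
    with torch.no_grad():
        out = model(batch["input_ids"], batch["attention_mask"], labels=labels)
    print(out["loss"].item(), out["logits"].shape)  # logits shape: (2, 75)

    # compute_metrics expects (logits, labels) as arrays, as the Trainer provides them.
    metrics = compute_metrics((out["logits"].numpy(), labels.numpy()))
    print(metrics["accuracy"])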