Brilleslangen
App for cleavage site prediction.
aec9df8
import torch.nn as nn
import torch.nn.functional as F
import torch
from transformers import EsmModel, AutoModel, PreTrainedModel, AutoConfig
import evaluate
import numpy as np
from sklearn.metrics import accuracy_score, classification_report
import wandb
# Hugging Face `evaluate` metric objects, loaded once when this module is
# imported. The tuple order below fixes both the load order and which name
# each loaded metric is bound to.
accuracy_metric, precision_metric, recall_metric, f1_metric = (
    evaluate.load(metric_name)
    for metric_name in ("accuracy", "precision", "recall", "f1")
)
class CleavageSiteModel(nn.Module):
    """ESM encoder followed by a single linear classification layer.

    The classifier reads the encoder's hidden state at sequence position 0
    (presumably the tokenizer's leading CLS token -- confirm against the
    tokenizer in use) and maps it to ``num_classes`` logits.

    Args:
        base_model: Identifier or path handed to ``EsmModel.from_pretrained``.
        num_classes: Number of output classes (width of the logits).
        class_weights: Optional mapping of class index -> loss weight. Any
            class that is missing from the mapping keeps a weight of 0, so
            samples labelled with it do not contribute to the loss. When
            ``None`` the loss is unweighted.
    """

    def __init__(self, base_model, num_classes=75, class_weights=None):
        super().__init__()
        self.model = EsmModel.from_pretrained(base_model)
        self.classifier = nn.Linear(self.model.config.hidden_size, num_classes)

        # ``weight=None`` is CrossEntropyLoss's unweighted default, so one
        # constructor call covers both the weighted and unweighted cases.
        per_class_weight = None
        if class_weights is not None:
            # Start from all zeros and fill in only the classes we were given.
            per_class_weight = torch.zeros(num_classes)
            for class_idx, weight in class_weights.items():
                per_class_weight[class_idx] = weight
        self.loss_fn = nn.CrossEntropyLoss(weight=per_class_weight)

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the encoder and classify each sequence.

        Args:
            input_ids: Token ids forwarded to the encoder (assumed to be
                shaped ``(batch, seq_len)`` -- confirm against the caller).
            attention_mask: Attention mask forwarded to the encoder.
            labels: Optional target class indices. When given, the
                cross-entropy loss is computed as well.

        Returns:
            dict: ``{"logits": ...}`` without labels, or
            ``{"loss": ..., "logits": ...}`` when labels are supplied.
        """
        encoder_out = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # Use the hidden state at position 0 as the whole-sequence summary.
        first_token_state = encoder_out.last_hidden_state[:, 0]
        logits = self.classifier(first_token_state)

        if labels is None:
            return {"logits": logits}
        return {"loss": self.loss_fn(logits, labels), "logits": logits}
def compute_metrics(eval_pred):
    """Compute overall and per-class accuracy for one evaluation pass.

    Args:
        eval_pred: A ``(logits, labels)`` pair. ``logits`` is indexed as
            ``(sample, class)`` -- the prediction is the argmax over axis 1 --
            and ``labels`` holds one class index per sample.

    Returns:
        dict: ``"accuracy"`` (overall accuracy) plus one
        ``"accuracy_class_<c>"`` entry for every class ``c`` that occurs in
        ``labels``, giving the fraction of that class's samples that were
        predicted correctly.

    Side effects:
        When a wandb run is active, logs the scikit-learn classification
        report (as HTML) together with ``"overall_accuracy"`` and the
        per-class accuracies in a single ``wandb.log`` call. Without an
        active run nothing is logged.
    """
    logits, labels = eval_pred
    # Coerce so the boolean masking below also works for list/tuple labels.
    labels = np.asarray(labels)
    predictions = np.argmax(logits, axis=1)

    accuracy = accuracy_score(labels, predictions)
    # zero_division=0 yields the same numbers as the default "warn" setting
    # but without one warning per class that was never predicted.
    report = classification_report(
        labels, predictions, digits=4, zero_division=0
    )

    # Per-class accuracy: of the samples whose true label is `cls`, the
    # fraction predicted as `cls`.
    per_class_acc = {}
    for cls in np.unique(labels):
        class_mask = labels == cls
        per_class_acc[f"accuracy_class_{cls}"] = float(
            (predictions[class_mask] == cls).mean()
        )

    # wandb.log raises when no run has been initialised, so only log inside
    # an active run. A single call keeps the report and the scalar metrics
    # on the same wandb step.
    if wandb.run is not None:
        wandb.log(
            {
                "classification_report": wandb.Html(report.replace("\n", "<br>")),
                "overall_accuracy": accuracy,
                **per_class_acc,
            }
        )

    return {"accuracy": accuracy, **per_class_acc}