import torch
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
import numpy as np


class GraphMetrics:
    """Production-ready evaluation metrics."""

    @staticmethod
    def accuracy(pred, target):
        """Classification accuracy."""
        pred_labels = pred.argmax(dim=1)
        return (pred_labels == target).float().mean().item()

    @staticmethod
    def f1_score_macro(pred, target):
        """Macro F1 score."""
        pred_labels = pred.argmax(dim=1).cpu().numpy()
        target_labels = target.cpu().numpy()
        return f1_score(target_labels, pred_labels, average='macro')

    @staticmethod
    def f1_score_micro(pred, target):
        """Micro F1 score."""
        pred_labels = pred.argmax(dim=1).cpu().numpy()
        target_labels = target.cpu().numpy()
        return f1_score(target_labels, pred_labels, average='micro')

    @staticmethod
    def roc_auc(pred, target, num_classes):
        """ROC AUC for binary or multi-class classification."""
        if num_classes == 2:
            # Binary case: score against the probability of the positive class.
            pred_probs = F.softmax(pred, dim=1)[:, 1].cpu().numpy()
            target_labels = target.cpu().numpy()
            return roc_auc_score(target_labels, pred_probs)
        else:
            # Multi-class case: one-vs-rest AUC, macro-averaged over classes.
            # sklearn expects integer labels here, not a one-hot matrix.
            pred_probs = F.softmax(pred, dim=1).cpu().numpy()
            target_labels = target.cpu().numpy()
            return roc_auc_score(target_labels, pred_probs, multi_class='ovr', average='macro')

    @staticmethod
    def evaluate_node_classification(model, data, mask, device):
        """Comprehensive node classification evaluation."""
        model.eval()

        with torch.no_grad():
            data = data.to(device)
            h = model(data.x, data.edge_index)

            if hasattr(model, 'classifier'):
                pred = model.classifier(h)
            else:
                # Encoder-only model: nothing to classify, return embeddings.
                return {'embeddings': h[mask].cpu()}

        pred_masked = pred[mask]
        target_masked = data.y[mask]

        metrics = {
            'accuracy': GraphMetrics.accuracy(pred_masked, target_masked),
            'f1_macro': GraphMetrics.f1_score_macro(pred_masked, target_masked),
            'f1_micro': GraphMetrics.f1_score_micro(pred_masked, target_masked),
        }

        # ROC AUC is undefined when a class is missing from the masked split;
        # sklearn raises ValueError, so skip the metric in that case.
        try:
            num_classes = pred.size(1)
            metrics['roc_auc'] = GraphMetrics.roc_auc(pred_masked, target_masked, num_classes)
        except ValueError:
            pass

        return metrics

    @staticmethod
    def evaluate_graph_classification(model, dataloader, device):
        """Comprehensive graph classification evaluation."""
        model.eval()

        all_preds = []
        all_targets = []

        with torch.no_grad():
            for batch in dataloader:
                batch = batch.to(device)
                h = model(batch.x, batch.edge_index, batch.batch)

                # Pool node embeddings into one embedding per graph
                # (assumes the model implements get_graph_embedding).
                graph_h = model.get_graph_embedding(h, batch.batch)

                if hasattr(model, 'classifier'):
                    pred = model.classifier(graph_h)
                    all_preds.append(pred)
                    all_targets.append(batch.y)

        if all_preds:
            all_preds = torch.cat(all_preds, dim=0)
            all_targets = torch.cat(all_targets, dim=0)

            metrics = {
                'accuracy': GraphMetrics.accuracy(all_preds, all_targets),
                'f1_macro': GraphMetrics.f1_score_macro(all_preds, all_targets),
                'f1_micro': GraphMetrics.f1_score_micro(all_preds, all_targets),
            }

            # Skip ROC AUC when sklearn cannot compute it (e.g. a class
            # is missing from the targets).
            try:
                num_classes = all_preds.size(1)
                metrics['roc_auc'] = GraphMetrics.roc_auc(all_preds, all_targets, num_classes)
            except ValueError:
                pass

            return metrics

        return {'error': 'No predictions generated'}
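

# Minimal usage sketch, not part of the original code: `ToyEncoder` and
# `ToyData` below are hypothetical stand-ins for a trained GNN and a PyTorch
# Geometric-style `Data` object, assuming the `classifier`-head convention the
# methods above rely on. Graph-level evaluation would additionally need a
# model with `get_graph_embedding` and a PyG `DataLoader`.

class ToyEncoder(torch.nn.Module):
    """Hypothetical encoder exposing the `classifier` head the metrics expect."""

    def __init__(self, in_dim, hidden_dim, num_classes):
        super().__init__()
        self.lin = torch.nn.Linear(in_dim, hidden_dim)
        self.classifier = torch.nn.Linear(hidden_dim, num_classes)

    def forward(self, x, edge_index):
        # A real GNN would use edge_index; this stand-in ignores it.
        return torch.relu(self.lin(x))


class ToyData:
    """Hypothetical stand-in for torch_geometric.data.Data."""

    def __init__(self, x, edge_index, y):
        self.x, self.edge_index, self.y = x, edge_index, y

    def to(self, device):
        self.x = self.x.to(device)
        self.edge_index = self.edge_index.to(device)
        self.y = self.y.to(device)
        return self


if __name__ == '__main__':
    data = ToyData(x=torch.randn(100, 16),
                   edge_index=torch.randint(0, 100, (2, 300)),
                   y=torch.randint(0, 3, (100,)))
    test_mask = torch.rand(100) < 0.3  # hypothetical test split
    model = ToyEncoder(in_dim=16, hidden_dim=32, num_classes=3)
    print(GraphMetrics.evaluate_node_classification(model, data, test_mask, 'cpu'))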