File size: 4,858 Bytes
6d0498a 850d736 6d0498a 850d736 6d0498a 850d736 6d0498a 850d736 6d0498a 850d736 6d0498a 850d736 6d0498a 850d736 6d0498a 850d736 6d0498a 850d736 6d0498a 850d736 6d0498a 850d736 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import torch
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
import numpy as np
class GraphMetrics:
    """Evaluation metrics for node- and graph-classification models.

    All methods are static. Metric helpers accept either raw scores of shape
    (N, C), reduced to class indices with ``argmax(dim=1)``, or 1-D predicted
    labels.
    """

    @staticmethod
    def _to_labels(pred, target):
        """Return ``(pred_labels, target)`` aligned for element-wise comparison.

        A ``pred`` with more than one dimension is reduced with
        ``argmax(dim=1)``. ``target`` is moved to the predictions' device and,
        when it has the same number of elements but a different shape (e.g. an
        ``(N, 1)`` label column), reshaped to match so that ``==`` does not
        silently broadcast into an N x N comparison.
        """
        pred_labels = pred.argmax(dim=1) if pred.dim() > 1 else pred
        target = target.to(pred_labels.device)
        if (target.shape != pred_labels.shape
                and target.numel() == pred_labels.numel()):
            target = target.reshape(pred_labels.shape)
        return pred_labels, target

    @staticmethod
    def accuracy(pred, target):
        """Classification accuracy as a Python float.

        Args:
            pred: (N, C) scores or (N,) predicted labels.
            target: ground-truth labels with N elements.

        Returns:
            Fraction of matching labels; ``nan`` when there are no elements,
            because it is the mean of an empty tensor.
        """
        pred_labels, target = GraphMetrics._to_labels(pred, target)
        return (pred_labels == target).float().mean().item()

    @staticmethod
    def _f1(pred, target, average):
        """Compute sklearn's F1 with the given ``average`` mode.

        Best-effort: any failure is reported as a warning and scored 0.0, so
        that a single bad metric never aborts an evaluation loop.
        """
        try:
            pred_labels, target = GraphMetrics._to_labels(pred, target)
            return f1_score(
                target.cpu().numpy(),
                pred_labels.cpu().numpy(),
                average=average,
                zero_division=0,
            )
        except Exception as exc:  # narrower than the former bare `except:`
            import warnings
            warnings.warn(
                f"F1 ({average}) computation failed: {exc}", RuntimeWarning
            )
            return 0.0

    @staticmethod
    def f1_score_macro(pred, target):
        """Return the macro-averaged F1 score, or 0.0 if it cannot be computed."""
        return GraphMetrics._f1(pred, target, 'macro')

    @staticmethod
    def f1_score_micro(pred, target):
        """Return the micro-averaged F1 score, or 0.0 if it cannot be computed."""
        return GraphMetrics._f1(pred, target, 'micro')

    @staticmethod
    def _compute_metrics(pred, target):
        """Bundle accuracy, macro-F1 and micro-F1 into one result dict."""
        return {
            'accuracy': GraphMetrics.accuracy(pred, target),
            'f1_macro': GraphMetrics.f1_score_macro(pred, target),
            'f1_micro': GraphMetrics.f1_score_micro(pred, target),
        }

    @staticmethod
    def _error_metrics(message):
        """Return the zeroed metric dict used on failure, with ``'error'`` set."""
        return {
            'accuracy': 0.0,
            'f1_macro': 0.0,
            'f1_micro': 0.0,
            'error': message,
        }

    @staticmethod
    def _classify(model, h, y, device):
        """Apply ``model.classifier`` to ``h``, creating the head lazily if absent.

        ``y`` is only used to size a lazily created head.
        """
        if getattr(model, 'classifier', None) is None:
            # Size the head by the largest label index, not by the count of
            # distinct labels seen: absent classes or 1-based labels would
            # otherwise give a head too small for argmax to reach every label.
            num_classes = int(y.max().item()) + 1
            model._init_classifier(num_classes, device)
            # A freshly attached submodule may be in training mode (the
            # nn.Module default); re-apply eval() so e.g. dropout stays off.
            model.eval()
        return model.classifier(h)

    @staticmethod
    def evaluate_node_classification(model, data, mask, device):
        """Evaluate node classification on the nodes selected by ``mask``.

        Args:
            model: module called as ``model(x, edge_index)``; uses
                ``model.classifier``, or creates it via
                ``model._init_classifier(num_classes, device)``.
            data: graph object with ``x``, ``edge_index``, ``y`` and ``.to()``.
            mask: index used to select rows of predictions and ``data.y``
                (presumably a boolean node mask; if it is a tensor it is moved
                to the predictions' device).
            device: device the model and data are moved to.

        Returns:
            dict with 'accuracy', 'f1_macro' and 'f1_micro'. On any exception
            the error is printed, all three values are 0.0 and 'error' holds
            the message.
        """
        model.eval()
        try:
            with torch.no_grad():
                data = data.to(device)
                model = model.to(device)
                h = model(data.x, data.edge_index)
                pred = GraphMetrics._classify(model, h, data.y, device)
                if isinstance(mask, torch.Tensor):
                    # The caller's mask may still live on the original device.
                    mask = mask.to(pred.device)
                return GraphMetrics._compute_metrics(pred[mask], data.y[mask])
        except Exception as e:
            print(f"Evaluation error: {e}")
            return GraphMetrics._error_metrics(str(e))

    @staticmethod
    def evaluate_graph_classification(model, dataloader, device):
        """Evaluate graph classification over every batch of ``dataloader``.

        Args:
            model: module called as ``model(x, edge_index, batch)`` that also
                provides ``get_graph_embedding(h, batch)`` and a classifier
                (created lazily as in node classification).
            dataloader: iterable of batches with ``x``, ``edge_index``,
                ``batch``, ``y`` and ``.to()``.
            device: device each batch is moved to.

        Returns:
            dict with 'accuracy', 'f1_macro' and 'f1_micro' computed over all
            batches. When no batch was produced, or on any exception, the values
            are 0.0 and 'error' holds a message (the exception is also printed).
        """
        model.eval()
        preds, targets = [], []
        try:
            with torch.no_grad():
                for batch in dataloader:
                    batch = batch.to(device)
                    h = model(batch.x, batch.edge_index, batch.batch)
                    graph_h = model.get_graph_embedding(h, batch.batch)
                    # NOTE: a lazily created classifier is sized from the first
                    # batch's labels only; initialize it up front when possible.
                    pred = GraphMetrics._classify(model, graph_h, batch.y, device)
                    preds.append(pred.cpu())
                    targets.append(batch.y.cpu())
            if not preds:
                return GraphMetrics._error_metrics('No predictions generated')
            return GraphMetrics._compute_metrics(
                torch.cat(preds, dim=0), torch.cat(targets, dim=0)
            )
        except Exception as e:
            print(f"Graph classification evaluation error: {e}")
            return GraphMetrics._error_metrics(str(e))