kfoughali committed on
Commit
6d0498a
·
verified ·
1 Parent(s): abceea1

Create utils/metrics.py

Browse files
Files changed (1) hide show
  1. utils/metrics.py +116 -0
utils/metrics.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
4
+ import numpy as np
5
+
6
class GraphMetrics:
    """Evaluation metrics for node- and graph-level classification.

    All methods are static. ``pred`` arguments are per-class score tensors
    whose class axis is ``dim=1``; ``target`` arguments hold integer class
    labels, either as a flat ``(N,)`` tensor or as an ``(N, 1)`` column.
    """

    @staticmethod
    def _squeeze_target(target):
        """Return ``target`` with a trailing singleton dimension removed.

        Without this, comparing ``(N,)`` predictions against ``(N, 1)`` labels
        broadcasts to an ``N x N`` matrix and yields a silently wrong score.
        """
        if target.dim() > 1 and target.size(-1) == 1:
            return target.squeeze(-1)
        return target

    @staticmethod
    def accuracy(pred, target):
        """Return the fraction of rows whose arg-max class equals the label.

        Args:
            pred: Score tensor; the predicted label is ``pred.argmax(dim=1)``.
            target: Integer labels, shape ``(N,)`` or ``(N, 1)``.

        Returns:
            Accuracy as a Python float (``nan`` when there are no rows).
        """
        pred_labels = pred.argmax(dim=1)
        target = GraphMetrics._squeeze_target(target)
        return (pred_labels == target).float().mean().item()

    @staticmethod
    def f1_score_macro(pred, target):
        """Return the macro-averaged F1 score computed by scikit-learn."""
        pred_labels = pred.argmax(dim=1).cpu().numpy()
        target_labels = GraphMetrics._squeeze_target(target).cpu().numpy()
        return f1_score(target_labels, pred_labels, average='macro')

    @staticmethod
    def f1_score_micro(pred, target):
        """Return the micro-averaged F1 score computed by scikit-learn."""
        pred_labels = pred.argmax(dim=1).cpu().numpy()
        target_labels = GraphMetrics._squeeze_target(target).cpu().numpy()
        return f1_score(target_labels, pred_labels, average='micro')

    @staticmethod
    def roc_auc(pred, target, num_classes):
        """Return the ROC AUC of softmax-normalised scores.

        Args:
            pred: Score tensor; softmax is applied along ``dim=1``.
            target: Integer labels, shape ``(N,)`` or ``(N, 1)``.
            num_classes: Number of classes. ``2`` selects the binary path,
                which scores the probability of class 1; any other value uses
                one-hot targets with macro one-vs-rest averaging.

        Returns:
            The value returned by ``sklearn.metrics.roc_auc_score``.

        Raises:
            Whatever ``roc_auc_score`` / ``F.one_hot`` raise for unusable
            input (e.g. a class with no positive samples).
        """
        target = GraphMetrics._squeeze_target(target)
        if num_classes == 2:
            # Binary classification: score the positive-class probability.
            pred_probs = F.softmax(pred, dim=1)[:, 1].cpu().numpy()
            target_labels = target.cpu().numpy()
            return roc_auc_score(target_labels, pred_probs)

        # Multi-class: one-hot targets against the full probability matrix.
        pred_probs = F.softmax(pred, dim=1).cpu().numpy()
        target_onehot = F.one_hot(target, num_classes).cpu().numpy()
        return roc_auc_score(target_onehot, pred_probs, multi_class='ovr', average='macro')

    @staticmethod
    def _roc_auc_or_none(pred, target):
        """Compute ROC AUC on a best-effort basis; return ``None`` on failure.

        ROC AUC is optional in the evaluation reports, so ordinary errors are
        logged at debug level and skipped. Only ``Exception`` is caught, so
        ``KeyboardInterrupt`` and ``SystemExit`` still propagate.
        """
        try:
            return GraphMetrics.roc_auc(pred, target, pred.size(1))
        except Exception as exc:
            # Local import keeps this block self-contained.
            import logging

            logging.getLogger(__name__).debug("Skipping ROC AUC: %s", exc)
            return None

    @staticmethod
    def _classification_metrics(pred, target):
        """Build the metrics dict shared by the node and graph evaluators."""
        metrics = {
            'accuracy': GraphMetrics.accuracy(pred, target),
            'f1_macro': GraphMetrics.f1_score_macro(pred, target),
            'f1_micro': GraphMetrics.f1_score_micro(pred, target),
        }
        auc = GraphMetrics._roc_auc_or_none(pred, target)
        if auc is not None:
            metrics['roc_auc'] = auc
        return metrics

    @staticmethod
    def evaluate_node_classification(model, data, mask, device):
        """Evaluate a node classifier on the nodes selected by ``mask``.

        The model is put in eval mode and called as
        ``model(data.x, data.edge_index)`` under ``torch.no_grad()``.

        Args:
            model: Callable model; if it has a ``classifier`` attribute that
                is applied to the model output to obtain class scores.
            data: Object exposing ``to(device)``, ``x``, ``edge_index``, ``y``.
            mask: Index/mask selecting the nodes to evaluate.
            device: Device the data is moved to before the forward pass.

        Returns:
            ``{'embeddings': ...}`` (on CPU) when the model has no
            ``classifier``; otherwise a dict with ``accuracy``, ``f1_macro``,
            ``f1_micro`` and, when computable, ``roc_auc``.
        """
        model.eval()

        with torch.no_grad():
            data = data.to(device)
            h = model(data.x, data.edge_index)

            # Keep the mask on the same device as the tensors it indexes.
            if torch.is_tensor(mask):
                mask = mask.to(device)

            if not hasattr(model, 'classifier'):
                # If no classifier, return embeddings
                return {'embeddings': h[mask].cpu()}

            pred = model.classifier(h)
            return GraphMetrics._classification_metrics(pred[mask], data.y[mask])

    @staticmethod
    def evaluate_graph_classification(model, dataloader, device):
        """Evaluate a graph classifier over every batch of ``dataloader``.

        For each batch the model is called as
        ``model(batch.x, batch.edge_index, batch.batch)``, pooled with
        ``model.get_graph_embedding(h, batch.batch)`` and classified with
        ``model.classifier``.

        Args:
            model: Model exposing ``get_graph_embedding`` and, for metrics to
                be produced, ``classifier``.
            dataloader: Iterable of batches exposing ``to(device)``, ``x``,
                ``edge_index``, ``batch`` and ``y``.
            device: Device each batch is moved to.

        Returns:
            A dict with ``accuracy``, ``f1_macro``, ``f1_micro`` and, when
            computable, ``roc_auc``; or ``{'error': 'No predictions
            generated'}`` when the model has no ``classifier`` or the
            dataloader yields no batches.
        """
        model.eval()

        # Without a classifier no prediction can ever be produced, so skip
        # the forward passes entirely.
        if not hasattr(model, 'classifier'):
            return {'error': 'No predictions generated'}

        all_preds = []
        all_targets = []

        with torch.no_grad():
            for batch in dataloader:
                batch = batch.to(device)
                h = model(batch.x, batch.edge_index, batch.batch)

                # Graph-level prediction
                graph_h = model.get_graph_embedding(h, batch.batch)
                pred = model.classifier(graph_h)

                # Collect on CPU so device memory is freed batch by batch.
                all_preds.append(pred.cpu())
                all_targets.append(batch.y.cpu())

        if not all_preds:
            return {'error': 'No predictions generated'}

        return GraphMetrics._classification_metrics(
            torch.cat(all_preds, dim=0),
            torch.cat(all_targets, dim=0),
        )