kfoughali commited on
Commit
8a7e32b
·
verified ·
1 Parent(s): 92e45d6

Update utils/metrics.py

Browse files
Files changed (1) hide show
  1. utils/metrics.py +46 -56
utils/metrics.py CHANGED
@@ -1,10 +1,10 @@
1
  import torch
2
  import torch.nn.functional as F
3
- from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
4
  import numpy as np
5
 
6
  class GraphMetrics:
7
- """Production-ready evaluation metrics - device safe"""
8
 
9
  @staticmethod
10
  def accuracy(pred, target):
@@ -44,7 +44,40 @@ class GraphMetrics:
44
  return 0.0
45
 
46
  @staticmethod
47
- def evaluate_node_classification(model, data, mask, device):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  """Comprehensive node classification evaluation"""
49
  model.eval()
50
 
@@ -74,59 +107,16 @@ class GraphMetrics:
74
  'f1_micro': GraphMetrics.f1_score_micro(pred_masked, target_masked),
75
  }
76
 
77
- except Exception as e:
78
- print(f"Evaluation error: {e}")
79
- metrics = {
80
- 'accuracy': 0.0,
81
- 'f1_macro': 0.0,
82
- 'f1_micro': 0.0,
83
- 'error': str(e)
84
- }
85
-
86
- return metrics
87
-
88
- @staticmethod
89
- def evaluate_graph_classification(model, dataloader, device):
90
- """Comprehensive graph classification evaluation"""
91
- model.eval()
92
-
93
- all_preds = []
94
- all_targets = []
95
-
96
- try:
97
- with torch.no_grad():
98
- for batch in dataloader:
99
- batch = batch.to(device)
100
- h = model(batch.x, batch.edge_index, batch.batch)
101
-
102
- # Graph-level prediction
103
- graph_h = model.get_graph_embedding(h, batch.batch)
104
-
105
- if hasattr(model, 'classifier') and model.classifier is not None:
106
- pred = model.classifier(graph_h)
107
- else:
108
- # Initialize classifier
109
- num_classes = len(torch.unique(batch.y))
110
- model._init_classifier(num_classes, device)
111
- pred = model.classifier(graph_h)
112
-
113
- all_preds.append(pred.cpu())
114
- all_targets.append(batch.y.cpu())
115
-
116
- if all_preds:
117
- all_preds = torch.cat(all_preds, dim=0)
118
- all_targets = torch.cat(all_targets, dim=0)
119
 
120
- metrics = {
121
- 'accuracy': GraphMetrics.accuracy(all_preds, all_targets),
122
- 'f1_macro': GraphMetrics.f1_score_macro(all_preds, all_targets),
123
- 'f1_micro': GraphMetrics.f1_score_micro(all_preds, all_targets),
124
- }
125
- else:
126
- metrics = {'error': 'No predictions generated'}
127
 
128
  except Exception as e:
129
- print(f"Graph classification evaluation error: {e}")
130
- metrics = {'error': str(e)}
131
-
132
- return metrics
 
1
  import torch
2
  import torch.nn.functional as F
3
+ from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, classification_report
4
  import numpy as np
5
 
6
  class GraphMetrics:
7
+ """Comprehensive evaluation metrics for graph learning"""
8
 
9
  @staticmethod
10
  def accuracy(pred, target):
 
44
  return 0.0
45
 
46
  @staticmethod
47
+ def roc_auc(pred, target):
48
+ """ROC AUC score"""
49
+ try:
50
+ if pred.dim() > 1:
51
+ # Multi-class
52
+ pred_probs = F.softmax(pred, dim=1).cpu().numpy()
53
+ target_onehot = F.one_hot(target, num_classes=pred.size(1)).cpu().numpy()
54
+ return roc_auc_score(target_onehot, pred_probs, multi_class='ovr', average='macro')
55
+ else:
56
+ # Binary
57
+ pred_probs = torch.sigmoid(pred).cpu().numpy()
58
+ target_labels = target.cpu().numpy()
59
+ return roc_auc_score(target_labels, pred_probs)
60
+ except:
61
+ return 0.0
62
+
63
+ @staticmethod
64
+ def classification_report_dict(pred, target):
65
+ """Detailed classification report"""
66
+ try:
67
+ if pred.dim() > 1:
68
+ pred_labels = pred.argmax(dim=1)
69
+ else:
70
+ pred_labels = pred
71
+ pred_labels = pred_labels.cpu().numpy()
72
+ target_labels = target.cpu().numpy()
73
+
74
+ report = classification_report(target_labels, pred_labels, output_dict=True, zero_division=0)
75
+ return report
76
+ except:
77
+ return {}
78
+
79
+ @staticmethod
80
+ def evaluate_node_classification(model, data, mask, device, detailed=False):
81
  """Comprehensive node classification evaluation"""
82
  model.eval()
83
 
 
107
  'f1_micro': GraphMetrics.f1_score_micro(pred_masked, target_masked),
108
  }
109
 
110
+ # Add detailed metrics if requested
111
+ if detailed:
112
+ metrics['roc_auc'] = GraphMetrics.roc_auc(pred_masked, target_masked)
113
+ metrics['classification_report'] = GraphMetrics.classification_report_dict(pred_masked, target_masked)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
+ # Add loss
116
+ criterion = torch.nn.CrossEntropyLoss()
117
+ metrics['loss'] = criterion(pred_masked, target_masked).item()
 
 
 
 
118
 
119
  except Exception as e:
120
+ print(f"Evaluation error: {e}")
121
+ metrics = {
122
+ 'accuracy': 0.0