kfoughali commited on
Commit
272c30c
·
verified ·
1 Parent(s): 8a7e32b

Update utils/metrics.py

Browse files
Files changed (1) hide show
  1. utils/metrics.py +93 -1
utils/metrics.py CHANGED
@@ -119,4 +119,96 @@ class GraphMetrics:
119
  except Exception as e:
120
  print(f"Evaluation error: {e}")
121
  metrics = {
122
- 'accuracy': 0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  except Exception as e:
120
  print(f"Evaluation error: {e}")
121
  metrics = {
122
+ 'accuracy': 0.0
123
+ 'f1_macro': 0.0,
124
+ 'f1_micro': 0.0,
125
+ 'loss': float('inf'),
126
+ 'error': str(e)
127
+ }
128
+
129
+ return metrics
130
+
131
+ @staticmethod
132
+ def evaluate_graph_classification(model, dataloader, device, detailed=False):
133
+ """Comprehensive graph classification evaluation"""
134
+ model.eval()
135
+
136
+ all_preds = []
137
+ all_targets = []
138
+ total_loss = 0.0
139
+ num_batches = 0
140
+
141
+ try:
142
+ criterion = torch.nn.CrossEntropyLoss()
143
+
144
+ with torch.no_grad():
145
+ for batch in dataloader:
146
+ batch = batch.to(device)
147
+ h = model(batch.x, batch.edge_index, batch.batch)
148
+
149
+ # Graph-level prediction
150
+ graph_h = model.get_graph_embedding(h, batch.batch)
151
+
152
+ if hasattr(model, 'classifier') and model.classifier is not None:
153
+ pred = model.classifier(graph_h)
154
+ else:
155
+ # Initialize classifier
156
+ num_classes = len(torch.unique(batch.y))
157
+ model._init_classifier(num_classes, device)
158
+ pred = model.classifier(graph_h)
159
+
160
+ all_preds.append(pred.cpu())
161
+ all_targets.append(batch.y.cpu())
162
+
163
+ # Calculate loss
164
+ loss = criterion(pred, batch.y)
165
+ total_loss += loss.item()
166
+ num_batches += 1
167
+
168
+ if all_preds:
169
+ all_preds = torch.cat(all_preds, dim=0)
170
+ all_targets = torch.cat(all_targets, dim=0)
171
+
172
+ metrics = {
173
+ 'accuracy': GraphMetrics.accuracy(all_preds, all_targets),
174
+ 'f1_macro': GraphMetrics.f1_score_macro(all_preds, all_targets),
175
+ 'f1_micro': GraphMetrics.f1_score_micro(all_preds, all_targets),
176
+ 'loss': total_loss / max(num_batches, 1)
177
+ }
178
+
179
+ if detailed:
180
+ metrics['roc_auc'] = GraphMetrics.roc_auc(all_preds, all_targets)
181
+ metrics['classification_report'] = GraphMetrics.classification_report_dict(all_preds, all_targets)
182
+ else:
183
+ metrics = {'error': 'No predictions generated'}
184
+
185
+ except Exception as e:
186
+ print(f"Graph classification evaluation error: {e}")
187
+ metrics = {
188
+ 'accuracy': 0.0,
189
+ 'f1_macro': 0.0,
190
+ 'f1_micro': 0.0,
191
+ 'loss': float('inf'),
192
+ 'error': str(e)
193
+ }
194
+
195
+ return metrics
196
+
197
+ @staticmethod
198
+ def compare_models(results_dict):
199
+ """Compare multiple model results"""
200
+ comparison = {}
201
+
202
+ for model_name, metrics in results_dict.items():
203
+ comparison[model_name] = {
204
+ 'accuracy': metrics.get('accuracy', 0.0),
205
+ 'f1_macro': metrics.get('f1_macro', 0.0),
206
+ 'f1_micro': metrics.get('f1_micro', 0.0),
207
+ 'loss': metrics.get('loss', float('inf'))
208
+ }
209
+
210
+ # Find best performing model
211
+ best_model = max(comparison.keys(), key=lambda k: comparison[k]['accuracy'])
212
+ comparison['best_model'] = best_model
213
+
214
+ return comparison