kfoughali commited on
Commit
c3b0dee
·
verified ·
1 Parent(s): 97c533b

Update utils/metrics.py

Browse files
Files changed (1) hide show
  1. utils/metrics.py +93 -93
utils/metrics.py CHANGED
@@ -119,96 +119,96 @@ class GraphMetrics:
119
  except Exception as e:
120
  print(f"Evaluation error: {e}")
121
  metrics = {
122
- 'accuracy': 0.0
123
- 'f1_macro': 0.0,
124
- 'f1_micro': 0.0,
125
- 'loss': float('inf'),
126
- 'error': str(e)
127
- }
128
-
129
- return metrics
130
-
131
- @staticmethod
132
- def evaluate_graph_classification(model, dataloader, device, detailed=False):
133
- """Comprehensive graph classification evaluation"""
134
- model.eval()
135
-
136
- all_preds = []
137
- all_targets = []
138
- total_loss = 0.0
139
- num_batches = 0
140
-
141
- try:
142
- criterion = torch.nn.CrossEntropyLoss()
143
-
144
- with torch.no_grad():
145
- for batch in dataloader:
146
- batch = batch.to(device)
147
- h = model(batch.x, batch.edge_index, batch.batch)
148
-
149
- # Graph-level prediction
150
- graph_h = model.get_graph_embedding(h, batch.batch)
151
-
152
- if hasattr(model, 'classifier') and model.classifier is not None:
153
- pred = model.classifier(graph_h)
154
- else:
155
- # Initialize classifier
156
- num_classes = len(torch.unique(batch.y))
157
- model._init_classifier(num_classes, device)
158
- pred = model.classifier(graph_h)
159
-
160
- all_preds.append(pred.cpu())
161
- all_targets.append(batch.y.cpu())
162
-
163
- # Calculate loss
164
- loss = criterion(pred, batch.y)
165
- total_loss += loss.item()
166
- num_batches += 1
167
-
168
- if all_preds:
169
- all_preds = torch.cat(all_preds, dim=0)
170
- all_targets = torch.cat(all_targets, dim=0)
171
-
172
- metrics = {
173
- 'accuracy': GraphMetrics.accuracy(all_preds, all_targets),
174
- 'f1_macro': GraphMetrics.f1_score_macro(all_preds, all_targets),
175
- 'f1_micro': GraphMetrics.f1_score_micro(all_preds, all_targets),
176
- 'loss': total_loss / max(num_batches, 1)
177
- }
178
-
179
- if detailed:
180
- metrics['roc_auc'] = GraphMetrics.roc_auc(all_preds, all_targets)
181
- metrics['classification_report'] = GraphMetrics.classification_report_dict(all_preds, all_targets)
182
- else:
183
- metrics = {'error': 'No predictions generated'}
184
-
185
- except Exception as e:
186
- print(f"Graph classification evaluation error: {e}")
187
- metrics = {
188
- 'accuracy': 0.0,
189
- 'f1_macro': 0.0,
190
- 'f1_micro': 0.0,
191
- 'loss': float('inf'),
192
- 'error': str(e)
193
- }
194
-
195
- return metrics
196
-
197
- @staticmethod
198
- def compare_models(results_dict):
199
- """Compare multiple model results"""
200
- comparison = {}
201
-
202
- for model_name, metrics in results_dict.items():
203
- comparison[model_name] = {
204
- 'accuracy': metrics.get('accuracy', 0.0),
205
- 'f1_macro': metrics.get('f1_macro', 0.0),
206
- 'f1_micro': metrics.get('f1_micro', 0.0),
207
- 'loss': metrics.get('loss', float('inf'))
208
- }
209
-
210
- # Find best performing model
211
- best_model = max(comparison.keys(), key=lambda k: comparison[k]['accuracy'])
212
- comparison['best_model'] = best_model
213
-
214
- return comparison
 
119
  except Exception as e:
120
  print(f"Evaluation error: {e}")
121
  metrics = {
122
+ 'accuracy': 0.0,
123
+ 'f1_macro': 0.0,
124
+ 'f1_micro': 0.0,
125
+ 'loss': float('inf'),
126
+ 'error': str(e)
127
+ }
128
+
129
+ return metrics
130
+
131
+ @staticmethod
132
+ def evaluate_graph_classification(model, dataloader, device, detailed=False):
133
+ """Comprehensive graph classification evaluation"""
134
+ model.eval()
135
+
136
+ all_preds = []
137
+ all_targets = []
138
+ total_loss = 0.0
139
+ num_batches = 0
140
+
141
+ try:
142
+ criterion = torch.nn.CrossEntropyLoss()
143
+
144
+ with torch.no_grad():
145
+ for batch in dataloader:
146
+ batch = batch.to(device)
147
+ h = model(batch.x, batch.edge_index, batch.batch)
148
+
149
+ # Graph-level prediction
150
+ graph_h = model.get_graph_embedding(h, batch.batch)
151
+
152
+ if hasattr(model, 'classifier') and model.classifier is not None:
153
+ pred = model.classifier(graph_h)
154
+ else:
155
+ # Initialize classifier
156
+ num_classes = len(torch.unique(batch.y))
157
+ model._init_classifier(num_classes, device)
158
+ pred = model.classifier(graph_h)
159
+
160
+ all_preds.append(pred.cpu())
161
+ all_targets.append(batch.y.cpu())
162
+
163
+ # Calculate loss
164
+ loss = criterion(pred, batch.y)
165
+ total_loss += loss.item()
166
+ num_batches += 1
167
+
168
+ if all_preds:
169
+ all_preds = torch.cat(all_preds, dim=0)
170
+ all_targets = torch.cat(all_targets, dim=0)
171
+
172
+ metrics = {
173
+ 'accuracy': GraphMetrics.accuracy(all_preds, all_targets),
174
+ 'f1_macro': GraphMetrics.f1_score_macro(all_preds, all_targets),
175
+ 'f1_micro': GraphMetrics.f1_score_micro(all_preds, all_targets),
176
+ 'loss': total_loss / max(num_batches, 1)
177
+ }
178
+
179
+ if detailed:
180
+ metrics['roc_auc'] = GraphMetrics.roc_auc(all_preds, all_targets)
181
+ metrics['classification_report'] = GraphMetrics.classification_report_dict(all_preds, all_targets)
182
+ else:
183
+ metrics = {'error': 'No predictions generated'}
184
+
185
+ except Exception as e:
186
+ print(f"Graph classification evaluation error: {e}")
187
+ metrics = {
188
+ 'accuracy': 0.0,
189
+ 'f1_macro': 0.0,
190
+ 'f1_micro': 0.0,
191
+ 'loss': float('inf'),
192
+ 'error': str(e)
193
+ }
194
+
195
+ return metrics
196
+
197
+ @staticmethod
198
+ def compare_models(results_dict):
199
+ """Compare multiple model results"""
200
+ comparison = {}
201
+
202
+ for model_name, metrics in results_dict.items():
203
+ comparison[model_name] = {
204
+ 'accuracy': metrics.get('accuracy', 0.0),
205
+ 'f1_macro': metrics.get('f1_macro', 0.0),
206
+ 'f1_micro': metrics.get('f1_micro', 0.0),
207
+ 'loss': metrics.get('loss', float('inf'))
208
+ }
209
+
210
+ # Find best performing model
211
+ best_model = max(comparison.keys(), key=lambda k: comparison[k]['accuracy'])
212
+ comparison['best_model'] = best_model
213
+
214
+ return comparison