Update utils/metrics.py
Browse files- utils/metrics.py +93 -93
utils/metrics.py
CHANGED
@@ -119,96 +119,96 @@ class GraphMetrics:
|
|
119 |
except Exception as e:
|
120 |
print(f"Evaluation error: {e}")
|
121 |
metrics = {
|
122 |
-
'accuracy': 0.0
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
|
|
119 |
except Exception as e:
|
120 |
print(f"Evaluation error: {e}")
|
121 |
metrics = {
|
122 |
+
'accuracy': 0.0,
|
123 |
+
'f1_macro': 0.0,
|
124 |
+
'f1_micro': 0.0,
|
125 |
+
'loss': float('inf'),
|
126 |
+
'error': str(e)
|
127 |
+
}
|
128 |
+
|
129 |
+
return metrics
|
130 |
+
|
131 |
+
@staticmethod
def evaluate_graph_classification(model, dataloader, device, detailed=False):
    """Evaluate a graph-classification model over every batch of a dataloader.

    The model is switched to eval mode and run under ``torch.no_grad()``.
    For each batch the code calls ``model(batch.x, batch.edge_index,
    batch.batch)``, pools the result with ``model.get_graph_embedding`` and
    feeds it to ``model.classifier``. If the model has no classifier (the
    attribute is missing or ``None``), one is created on the fly through
    ``model._init_classifier``.

    Args:
        model: Object exposing ``eval()``, ``__call__``,
            ``get_graph_embedding`` and, optionally, ``classifier`` /
            ``_init_classifier``.
        dataloader: Iterable of batches providing ``to(device)``, ``x``,
            ``edge_index``, ``batch`` and ``y``.
        device: Device every batch is moved to before the forward pass.
        detailed: When true, also report ``roc_auc`` and
            ``classification_report``.

    Returns:
        dict: Always contains ``accuracy``, ``f1_macro``, ``f1_micro`` and
        ``loss`` (mean of the per-batch cross-entropy losses). On failure,
        or when the dataloader yields no batches, those metrics carry
        placeholder values (``0.0`` / ``inf``) and an ``error`` message is
        added, so callers can index the metric keys unconditionally.
    """
    model.eval()

    all_preds = []
    all_targets = []
    total_loss = 0.0
    num_batches = 0

    try:
        criterion = torch.nn.CrossEntropyLoss()

        with torch.no_grad():
            for batch in dataloader:
                batch = batch.to(device)
                h = model(batch.x, batch.edge_index, batch.batch)

                # Graph-level prediction
                graph_h = model.get_graph_embedding(h, batch.batch)

                if getattr(model, 'classifier', None) is None:
                    # Lazily create the classifier head.
                    # NOTE(review): the class count is inferred from the
                    # labels present in this batch only -- confirm every
                    # class is guaranteed to appear in the first batch.
                    num_classes = len(torch.unique(batch.y))
                    model._init_classifier(num_classes, device)
                pred = model.classifier(graph_h)

                all_preds.append(pred.cpu())
                all_targets.append(batch.y.cpu())

                # Accumulate the per-batch loss; averaged over batches below.
                loss = criterion(pred, batch.y)
                total_loss += loss.item()
                num_batches += 1

        if all_preds:
            all_preds = torch.cat(all_preds, dim=0)
            all_targets = torch.cat(all_targets, dim=0)

            metrics = {
                'accuracy': GraphMetrics.accuracy(all_preds, all_targets),
                'f1_macro': GraphMetrics.f1_score_macro(all_preds, all_targets),
                'f1_micro': GraphMetrics.f1_score_micro(all_preds, all_targets),
                'loss': total_loss / max(num_batches, 1),
            }

            if detailed:
                metrics['roc_auc'] = GraphMetrics.roc_auc(all_preds, all_targets)
                metrics['classification_report'] = GraphMetrics.classification_report_dict(all_preds, all_targets)
        else:
            # Empty dataloader: report the same schema as the exception
            # path so the metric keys are always present.
            metrics = {
                'accuracy': 0.0,
                'f1_macro': 0.0,
                'f1_micro': 0.0,
                'loss': float('inf'),
                'error': 'No predictions generated',
            }

    except Exception as e:
        # Deliberate best-effort: evaluation failures are reported in the
        # result instead of propagating to the caller.
        print(f"Graph classification evaluation error: {e}")
        metrics = {
            'accuracy': 0.0,
            'f1_macro': 0.0,
            'f1_micro': 0.0,
            'loss': float('inf'),
            'error': str(e),
        }

    return metrics
|
196 |
+
|
197 |
+
@staticmethod
def compare_models(results_dict):
    """Summarise several models' metrics and name the most accurate one.

    Args:
        results_dict: Mapping of model name to a metrics dict. Missing
            ``accuracy`` / ``f1_macro`` / ``f1_micro`` entries default to
            ``0.0`` and a missing ``loss`` defaults to ``inf``.

    Returns:
        dict: One entry per model holding its four headline metrics, plus a
        ``'best_model'`` entry naming the model with the highest accuracy
        (the first one encountered wins ties). ``'best_model'`` is ``None``
        when ``results_dict`` is empty. Because the winner is stored in the
        same dict, a model literally named ``'best_model'`` would have its
        entry overwritten.
    """
    comparison = {
        model_name: {
            'accuracy': metrics.get('accuracy', 0.0),
            'f1_macro': metrics.get('f1_macro', 0.0),
            'f1_micro': metrics.get('f1_micro', 0.0),
            'loss': metrics.get('loss', float('inf')),
        }
        for model_name, metrics in results_dict.items()
    }

    # `default=None` keeps an empty comparison from raising ValueError.
    best_model = max(
        comparison,
        key=lambda name: comparison[name]['accuracy'],
        default=None,
    )
    comparison['best_model'] = best_model

    return comparison
|