Update utils/metrics.py

utils/metrics.py (CHANGED, +93 -1)
@@ -119,4 +119,96 @@ class GraphMetrics:
         except Exception as e:
             print(f"Evaluation error: {e}")
             metrics = {
-                'accuracy': 0.0
+                'accuracy': 0.0,
+                'f1_macro': 0.0,
+                'f1_micro': 0.0,
+                'loss': float('inf'),
+                'error': str(e)
+            }
+
+        return metrics
+
+    @staticmethod
+    def evaluate_graph_classification(model, dataloader, device, detailed=False):
+        """Comprehensive graph classification evaluation"""
+        model.eval()
+
+        all_preds = []
+        all_targets = []
+        total_loss = 0.0
+        num_batches = 0
+
+        try:
+            criterion = torch.nn.CrossEntropyLoss()
+
+            with torch.no_grad():
+                for batch in dataloader:
+                    batch = batch.to(device)
+                    h = model(batch.x, batch.edge_index, batch.batch)
+
+                    # Graph-level prediction
+                    graph_h = model.get_graph_embedding(h, batch.batch)
+
+                    if hasattr(model, 'classifier') and model.classifier is not None:
+                        pred = model.classifier(graph_h)
+                    else:
+                        # Initialize classifier lazily from the labels in this batch
+                        num_classes = len(torch.unique(batch.y))
+                        model._init_classifier(num_classes, device)
+                        pred = model.classifier(graph_h)
+
+                    all_preds.append(pred.cpu())
+                    all_targets.append(batch.y.cpu())
+
+                    # Calculate loss
+                    loss = criterion(pred, batch.y)
+                    total_loss += loss.item()
+                    num_batches += 1
+
+            if all_preds:
+                all_preds = torch.cat(all_preds, dim=0)
+                all_targets = torch.cat(all_targets, dim=0)
+
+                metrics = {
+                    'accuracy': GraphMetrics.accuracy(all_preds, all_targets),
+                    'f1_macro': GraphMetrics.f1_score_macro(all_preds, all_targets),
+                    'f1_micro': GraphMetrics.f1_score_micro(all_preds, all_targets),
+                    'loss': total_loss / max(num_batches, 1)
+                }
+
+                if detailed:
+                    metrics['roc_auc'] = GraphMetrics.roc_auc(all_preds, all_targets)
+                    metrics['classification_report'] = GraphMetrics.classification_report_dict(all_preds, all_targets)
+            else:
+                metrics = {'error': 'No predictions generated'}
+
+        except Exception as e:
+            print(f"Graph classification evaluation error: {e}")
+            metrics = {
+                'accuracy': 0.0,
+                'f1_macro': 0.0,
+                'f1_micro': 0.0,
+                'loss': float('inf'),
+                'error': str(e)
+            }
+
+        return metrics
+
+    @staticmethod
+    def compare_models(results_dict):
+        """Compare multiple model results"""
+        comparison = {}
+
+        for model_name, metrics in results_dict.items():
+            comparison[model_name] = {
+                'accuracy': metrics.get('accuracy', 0.0),
+                'f1_macro': metrics.get('f1_macro', 0.0),
+                'f1_micro': metrics.get('f1_micro', 0.0),
+                'loss': metrics.get('loss', float('inf'))
+            }
+
+        # Find best performing model by accuracy
+        best_model = max(comparison.keys(), key=lambda k: comparison[k]['accuracy'])
+        comparison['best_model'] = best_model
+
+        return comparison
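For context, here is a minimal usage sketch of the two new helpers (not part of this commit). It assumes a PyTorch Geometric DataLoader and models that match the interface evaluate_graph_classification expects (a forward pass taking x, edge_index, batch, plus get_graph_embedding and an optional classifier head); GCNModel and GINModel are hypothetical stand-ins for the repo's actual model classes.

# Usage sketch only -- not part of the diff above. GCNModel/GINModel and the
# 'models' module are hypothetical placeholders; swap in the repo's actual
# model classes, which must expose forward(x, edge_index, batch),
# get_graph_embedding(h, batch) and (optionally) a 'classifier' head.
import torch
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

from utils.metrics import GraphMetrics
from models import GCNModel, GINModel  # hypothetical

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Any graph classification dataset works; MUTAG is just a small example.
dataset = TUDataset(root='data/TUDataset', name='MUTAG')
loader = DataLoader(dataset, batch_size=32, shuffle=False)

results = {}
for name, model in [('gcn', GCNModel()), ('gin', GINModel())]:
    model = model.to(device)
    # detailed=True also reports ROC-AUC and a per-class classification report
    results[name] = GraphMetrics.evaluate_graph_classification(
        model, loader, device, detailed=True
    )

comparison = GraphMetrics.compare_models(results)
print(f"Best model by accuracy: {comparison['best_model']}")

Note that compare_models ranks purely by accuracy and stores the winner under the 'best_model' key alongside the per-model entries.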