import torch
import torch.nn.functional as F
from sklearn.metrics import f1_score, roc_auc_score


class GraphMetrics:
    """Evaluation metrics for node- and graph-level classification."""

    @staticmethod
    def accuracy(pred, target):
        """Classification accuracy computed from raw logits."""
        pred_labels = pred.argmax(dim=1)
        return (pred_labels == target).float().mean().item()

    @staticmethod
    def f1_score_macro(pred, target):
        """Macro-averaged F1 score."""
        pred_labels = pred.argmax(dim=1).cpu().numpy()
        target_labels = target.cpu().numpy()
        return f1_score(target_labels, pred_labels, average='macro')

    @staticmethod
    def f1_score_micro(pred, target):
        """Micro-averaged F1 score."""
        pred_labels = pred.argmax(dim=1).cpu().numpy()
        target_labels = target.cpu().numpy()
        return f1_score(target_labels, pred_labels, average='micro')

    @staticmethod
    def roc_auc(pred, target, num_classes):
        """ROC AUC for binary or multi-class classification."""
        if num_classes == 2:
            # Binary: score on the probability of the positive class.
            pred_probs = F.softmax(pred, dim=1)[:, 1].cpu().numpy()
            target_labels = target.cpu().numpy()
            return roc_auc_score(target_labels, pred_probs)
        else:
            # Multi-class: one-vs-rest AUC with macro averaging over classes.
            pred_probs = F.softmax(pred, dim=1).cpu().numpy()
            target_onehot = F.one_hot(target, num_classes).cpu().numpy()
            return roc_auc_score(target_onehot, pred_probs,
                                 multi_class='ovr', average='macro')

    @staticmethod
    def evaluate_node_classification(model, data, mask, device):
        """Comprehensive node classification evaluation."""
        model.eval()
        with torch.no_grad():
            data = data.to(device)
            h = model(data.x, data.edge_index)

            # Assumes the model exposes a classification head; if it does
            # not, fall back to returning the masked node embeddings.
            if hasattr(model, 'classifier'):
                pred = model.classifier(h)
            else:
                return {'embeddings': h[mask].cpu()}

            pred_masked = pred[mask]
            target_masked = data.y[mask]

            metrics = {
                'accuracy': GraphMetrics.accuracy(pred_masked, target_masked),
                'f1_macro': GraphMetrics.f1_score_macro(pred_masked, target_masked),
                'f1_micro': GraphMetrics.f1_score_micro(pred_masked, target_masked),
            }

            # ROC AUC is undefined when a class is absent from the mask;
            # skip it rather than fail the whole evaluation.
            try:
                num_classes = pred.size(1)
                metrics['roc_auc'] = GraphMetrics.roc_auc(
                    pred_masked, target_masked, num_classes)
            except ValueError:
                pass

            return metrics

    @staticmethod
    def evaluate_graph_classification(model, dataloader, device):
        """Comprehensive graph classification evaluation."""
        model.eval()
        all_preds = []
        all_targets = []

        with torch.no_grad():
            for batch in dataloader:
                batch = batch.to(device)
                h = model(batch.x, batch.edge_index, batch.batch)

                # Pool node embeddings into a graph-level representation.
                graph_h = model.get_graph_embedding(h, batch.batch)

                # Append predictions and targets together so the two lists
                # stay aligned even when the model has no classifier head.
                if hasattr(model, 'classifier'):
                    pred = model.classifier(graph_h)
                    all_preds.append(pred)
                    all_targets.append(batch.y)

        if not all_preds:
            return {'error': 'No predictions generated'}

        all_preds = torch.cat(all_preds, dim=0)
        all_targets = torch.cat(all_targets, dim=0)

        metrics = {
            'accuracy': GraphMetrics.accuracy(all_preds, all_targets),
            'f1_macro': GraphMetrics.f1_score_macro(all_preds, all_targets),
            'f1_micro': GraphMetrics.f1_score_micro(all_preds, all_targets),
        }

        try:
            num_classes = all_preds.size(1)
            metrics['roc_auc'] = GraphMetrics.roc_auc(
                all_preds, all_targets, num_classes)
        except ValueError:
            pass

        return metrics
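

# --- Usage sketch ---
# A minimal smoke test of the static metrics on synthetic logits. The tensor
# shapes and class count below are illustrative assumptions, not part of the
# module; real callers would pass model outputs and ground-truth labels.
if __name__ == '__main__':
    # Synthetic logits for 8 samples over 3 classes.
    logits = torch.randn(8, 3)
    labels = torch.randint(0, 3, (8,))

    print('accuracy:', GraphMetrics.accuracy(logits, labels))
    print('f1_macro:', GraphMetrics.f1_score_macro(logits, labels))
    print('f1_micro:', GraphMetrics.f1_score_micro(logits, labels))

    # ROC AUC raises ValueError if a class happens to be absent from the
    # random labels, mirroring the guard used in the evaluate_* methods.
    try:
        print('roc_auc:', GraphMetrics.roc_auc(logits, labels, num_classes=3))
    except ValueError:
        print('roc_auc: undefined for this sample')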