import logging

import numpy as np
import torch
import torch.nn.functional as F

try:
    from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
except ImportError:
    # scikit-learn is optional: without it the F1 metrics fall back to the
    # NumPy implementation in ``GraphMetrics._f1_numpy``.
    accuracy_score = f1_score = roc_auc_score = None

_logger = logging.getLogger(__name__)


class GraphMetrics:
    """Classification metrics and evaluation loops for graph models.

    All methods are static. Metric methods accept either raw scores
    (tensor with more than one dimension, reduced with ``argmax(dim=1)``)
    or already-decoded labels (1-D tensor).
    """

    @staticmethod
    def _to_labels(pred):
        """Return labels: ``argmax`` over dim 1 for multi-dim input, else ``pred``."""
        return pred.argmax(dim=1) if pred.dim() > 1 else pred

    @staticmethod
    def accuracy(pred, target):
        """Return the fraction of predictions equal to ``target``.

        Args:
            pred: Scores (more than one dimension) or labels (1-D tensor).
            target: Ground-truth labels.

        Returns:
            Accuracy as a Python float; ``0.0`` when ``pred`` is empty
            (instead of the NaN produced by a mean over no elements).
        """
        if pred.numel() == 0:
            return 0.0
        pred_labels = GraphMetrics._to_labels(pred)
        # Align devices so the comparison cannot fail on a CPU/GPU mismatch.
        target = target.to(pred_labels.device)
        return (pred_labels == target).float().mean().item()

    @staticmethod
    def _f1_numpy(y_true, y_pred, average):
        """Compute macro/micro F1 with NumPy only.

        Labels are the union of the values found in ``y_true`` and
        ``y_pred``; a class with a zero denominator scores 0.

        Raises:
            ValueError: On length mismatch or an unsupported ``average``.
        """
        if average not in ("macro", "micro"):
            raise ValueError(f"Unsupported average: {average!r}")
        y_true = np.asarray(y_true).ravel()
        y_pred = np.asarray(y_pred).ravel()
        if y_true.shape != y_pred.shape:
            raise ValueError(
                f"Inconsistent number of samples: {y_true.shape[0]} vs {y_pred.shape[0]}"
            )
        labels = np.union1d(y_true, y_pred)
        if labels.size == 0:
            return 0.0

        # One row per label, one column per sample.
        true_hit = y_true[None, :] == labels[:, None]
        pred_hit = y_pred[None, :] == labels[:, None]
        tp = (true_hit & pred_hit).sum(axis=1)
        fp = (~true_hit & pred_hit).sum(axis=1)
        fn = (true_hit & ~pred_hit).sum(axis=1)
        if average == "micro":
            # Micro-averaging pools the counts before computing a single F1.
            tp, fp, fn = (np.atleast_1d(v.sum()) for v in (tp, fp, fn))

        denom = 2 * tp + fp + fn
        scores = np.divide(
            2 * tp,
            denom,
            out=np.zeros(denom.shape, dtype=float),
            where=denom > 0,
        )
        return float(scores.mean())

    @staticmethod
    def _f1(pred, target, average):
        """Best-effort F1: returns ``0.0`` (and logs) if the metric cannot be computed."""
        try:
            # detach() so tensors that require grad can be converted to NumPy.
            y_pred = GraphMetrics._to_labels(pred).detach().cpu().numpy()
            y_true = target.detach().cpu().numpy()
            if y_true.size == 0 or y_pred.size == 0:
                return 0.0
            if f1_score is not None:
                return float(
                    f1_score(y_true, y_pred, average=average, zero_division=0)
                )
            return GraphMetrics._f1_numpy(y_true, y_pred, average)
        except Exception as exc:
            # Deliberate best-effort: a broken metric must not abort evaluation,
            # but unlike a bare ``except`` this lets KeyboardInterrupt/SystemExit
            # propagate and leaves a trace of what went wrong.
            _logger.warning("F1 (%s) could not be computed: %s", average, exc)
            return 0.0

    @staticmethod
    def f1_score_macro(pred, target):
        """Return the macro-averaged F1 score, or ``0.0`` if it cannot be computed.

        Args:
            pred: Scores (more than one dimension) or labels (1-D tensor).
            target: Ground-truth labels.
        """
        return GraphMetrics._f1(pred, target, 'macro')

    @staticmethod
    def f1_score_micro(pred, target):
        """Return the micro-averaged F1 score, or ``0.0`` if it cannot be computed.

        Args:
            pred: Scores (more than one dimension) or labels (1-D tensor).
            target: Ground-truth labels.
        """
        return GraphMetrics._f1(pred, target, 'micro')

    @staticmethod
    def _classification_metrics(pred, target):
        """Bundle accuracy, macro F1 and micro F1 into one dict."""
        return {
            'accuracy': GraphMetrics.accuracy(pred, target),
            'f1_macro': GraphMetrics.f1_score_macro(pred, target),
            'f1_micro': GraphMetrics.f1_score_micro(pred, target),
        }

    @staticmethod
    def _classify(model, embeddings, labels, device):
        """Apply ``model.classifier``, creating it first if it is missing or ``None``."""
        if getattr(model, 'classifier', None) is None:
            # NOTE(review): the head is sized from the distinct labels visible
            # here, which under-counts if some class is absent (likely for a
            # single mini-batch), and a head created during evaluation is
            # presumably untrained - confirm this path is intended.
            num_classes = len(torch.unique(labels))
            model._init_classifier(num_classes, device)
        return model.classifier(embeddings)

    @staticmethod
    def evaluate_node_classification(model, data, mask, device):
        """Evaluate node classification on the nodes selected by ``mask``.

        Puts ``model`` in eval mode (it is not switched back), moves ``data``
        and ``model`` to ``device`` and runs one forward pass without gradients.

        Args:
            model: Callable as ``model(x, edge_index)``; its ``classifier``
                attribute maps the output to class scores.
            data: Object exposing ``to(device)``, ``x``, ``edge_index`` and ``y``.
            mask: Index or boolean mask selecting the nodes to score.
            device: Target device.

        Returns:
            Dict with ``accuracy``, ``f1_macro`` and ``f1_micro``. On failure
            the three metrics are ``0.0`` and an ``error`` key holds the message.
        """
        model.eval()
        try:
            with torch.no_grad():
                # Ensure data is on correct device
                data = data.to(device)
                model = model.to(device)

                h = model(data.x, data.edge_index)
                pred = GraphMetrics._classify(model, h, data.y, device)

                # A mask created on another device would otherwise break indexing.
                if torch.is_tensor(mask):
                    mask = mask.to(pred.device)

                metrics = GraphMetrics._classification_metrics(
                    pred[mask], data.y[mask]
                )
        except Exception as e:
            print(f"Evaluation error: {e}")
            metrics = {
                'accuracy': 0.0,
                'f1_macro': 0.0,
                'f1_micro': 0.0,
                'error': str(e)
            }

        return metrics

    @staticmethod
    def evaluate_graph_classification(model, dataloader, device):
        """Evaluate graph classification over every batch of ``dataloader``.

        Puts ``model`` in eval mode (it is not switched back). Each batch is
        moved to ``device``; ``model`` itself is not moved by this method.

        Args:
            model: Callable as ``model(x, edge_index, batch)`` and exposing
                ``get_graph_embedding(h, batch)``; its ``classifier``
                attribute maps graph embeddings to class scores.
            dataloader: Iterable of batches exposing ``to(device)``, ``x``,
                ``edge_index``, ``batch`` and ``y``.
            device: Target device for the batches.

        Returns:
            Dict with ``accuracy``, ``f1_macro`` and ``f1_micro``; or a dict
            holding only ``error`` when no batch was produced or an exception
            was raised.
        """
        model.eval()
        batch_preds = []
        batch_targets = []

        try:
            with torch.no_grad():
                for batch in dataloader:
                    batch = batch.to(device)
                    h = model(batch.x, batch.edge_index, batch.batch)

                    # Graph-level prediction
                    graph_h = model.get_graph_embedding(h, batch.batch)
                    pred = GraphMetrics._classify(model, graph_h, batch.y, device)

                    # Collect on CPU so accelerator memory is not held across batches.
                    batch_preds.append(pred.cpu())
                    batch_targets.append(batch.y.cpu())

            if batch_preds:
                metrics = GraphMetrics._classification_metrics(
                    torch.cat(batch_preds, dim=0),
                    torch.cat(batch_targets, dim=0),
                )
            else:
                metrics = {'error': 'No predictions generated'}

        except Exception as e:
            print(f"Graph classification evaluation error: {e}")
            metrics = {'error': str(e)}

        return metrics