import torch
import torch.nn.functional as F
from sklearn.metrics import f1_score, roc_auc_score

class GraphMetrics:
    """Production-ready evaluation metrics - device safe"""
    
    @staticmethod
    def accuracy(pred, target):
        """Classification accuracy"""
        if pred.dim() > 1:
            # Treat multi-dimensional predictions as per-class logits/scores.
            pred_labels = pred.argmax(dim=1)
        else:
            pred_labels = pred
        return (pred_labels == target).float().mean().item()
    
    @staticmethod 
    def f1_score_macro(pred, target):
        """Macro F1 score"""
        try:
            if pred.dim() > 1:
                pred_labels = pred.argmax(dim=1)
            else:
                pred_labels = pred
            pred_labels = pred_labels.cpu().numpy()
            target_labels = target.cpu().numpy()
            return f1_score(target_labels, pred_labels, average='macro', zero_division=0)
        except Exception:
            return 0.0
    
    @staticmethod
    def f1_score_micro(pred, target):
        """Micro F1 score"""
        try:
            if pred.dim() > 1:
                pred_labels = pred.argmax(dim=1)
            else:
                pred_labels = pred
            pred_labels = pred_labels.cpu().numpy() 
            target_labels = target.cpu().numpy()
            return f1_score(target_labels, pred_labels, average='micro', zero_division=0)
        except Exception:
            return 0.0
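    
    # Hypothetical addition (not in the original module): a ROC-AUC helper
    # sketched to match the `roc_auc_score` import above. It assumes `pred`
    # holds raw class logits of shape [num_samples, num_classes].
    @staticmethod
    def roc_auc_macro(pred, target):
        """Macro one-vs-rest ROC-AUC from logits (sketch)."""
        try:
            # Softmax logits into the probability matrix scikit-learn expects.
            probs = F.softmax(pred, dim=1).cpu().numpy()
            target_labels = target.cpu().numpy()
            if probs.shape[1] == 2:
                # Binary case: pass only the positive-class probability.
                return roc_auc_score(target_labels, probs[:, 1])
            return roc_auc_score(target_labels, probs,
                                 multi_class='ovr', average='macro')
        except Exception:
            return 0.0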
    
    @staticmethod
    def evaluate_node_classification(model, data, mask, device):
        """Comprehensive node classification evaluation"""
        model.eval()
        
        try:
            with torch.no_grad():
                # Move inputs and model to the same device; the mask must
                # live on the same device as the predictions it indexes.
                data = data.to(device)
                mask = mask.to(device)
                model = model.to(device)
                
                h = model(data.x, data.edge_index)
                
                # Get predictions
                if hasattr(model, 'classifier') and model.classifier is not None:
                    pred = model.classifier(h)
                else:
                    # Initialize classifier if needed
                    num_classes = len(torch.unique(data.y))
                    model._init_classifier(num_classes, device)
                    pred = model.classifier(h)
                
                pred_masked = pred[mask]
                target_masked = data.y[mask]
                
                metrics = {
                    'accuracy': GraphMetrics.accuracy(pred_masked, target_masked),
                    'f1_macro': GraphMetrics.f1_score_macro(pred_masked, target_masked),
                    'f1_micro': GraphMetrics.f1_score_micro(pred_masked, target_masked),
                }
                
        except Exception as e:
            print(f"Evaluation error: {e}")
            metrics = {
                'accuracy': 0.0,
                'f1_macro': 0.0,
                'f1_micro': 0.0,
                'error': str(e)
            }
                
        return metrics
    
    @staticmethod
    def evaluate_graph_classification(model, dataloader, device):
        """Comprehensive graph classification evaluation"""
        model.eval()
        
        all_preds = []
        all_targets = []
        
        try:
            with torch.no_grad():
                for batch in dataloader:
                    batch = batch.to(device)
                    h = model(batch.x, batch.edge_index, batch.batch)
                    
                    # Pool node embeddings into one embedding per graph.
                    graph_h = model.get_graph_embedding(h, batch.batch)
                    
                    if hasattr(model, 'classifier') and model.classifier is not None:
                        pred = model.classifier(graph_h)
                    else:
                        # Initialize classifier
                        num_classes = len(torch.unique(batch.y))
                        model._init_classifier(num_classes, device)
                        pred = model.classifier(graph_h)
                    
                    all_preds.append(pred.cpu())
                    all_targets.append(batch.y.cpu())
            
            if all_preds:
                all_preds = torch.cat(all_preds, dim=0)
                all_targets = torch.cat(all_targets, dim=0)
                
                metrics = {
                    'accuracy': GraphMetrics.accuracy(all_preds, all_targets),
                    'f1_macro': GraphMetrics.f1_score_macro(all_preds, all_targets),
                    'f1_micro': GraphMetrics.f1_score_micro(all_preds, all_targets),
                }
            else:
                metrics = {'error': 'No predictions generated'}
                
        except Exception as e:
            print(f"Graph classification evaluation error: {e}")
            metrics = {'error': str(e)}
        
        return metrics
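

# --- Usage sketch (illustrative, not part of the original module) ---
# GraphMetrics expects a model exposing a `classifier` attribute and the
# `_init_classifier(num_classes, device)` hook used above. The stub model
# and the synthetic graph below are assumptions for demonstration only.
if __name__ == '__main__':
    import torch.nn as nn
    from torch_geometric.data import Data

    class TinyGNN(nn.Module):
        """Toy stand-in; the real project model is assumed to expose
        the same `classifier` / `_init_classifier` interface."""

        def __init__(self, in_dim, hidden_dim):
            super().__init__()
            self.encoder = nn.Linear(in_dim, hidden_dim)
            self.hidden_dim = hidden_dim
            self.classifier = None

        def _init_classifier(self, num_classes, device):
            self.classifier = nn.Linear(self.hidden_dim, num_classes).to(device)

        def forward(self, x, edge_index):
            # Edges are ignored in this stub; a real GNN would use them.
            return torch.relu(self.encoder(x))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    data = Data(
        x=torch.randn(100, 16),                      # 100 nodes, 16 features
        edge_index=torch.randint(0, 100, (2, 200)),  # 200 random edges
        y=torch.randint(0, 3, (100,)),               # 3 classes
    )
    mask = torch.rand(100) < 0.3  # evaluate on ~30% of the nodes

    model = TinyGNN(in_dim=16, hidden_dim=32)
    print(GraphMetrics.evaluate_node_classification(model, data, mask, device))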