kfoughali committed on
Commit
8d3a013
·
verified ·
1 Parent(s): 3fb1716

Update utils/metrics.py

Browse files
Files changed (1) hide show
  1. utils/metrics.py +76 -378
utils/metrics.py CHANGED
@@ -1,420 +1,118 @@
1
  import torch
2
  import torch.nn.functional as F
3
- from torch_geometric.data import Data
4
- from torch_geometric.transforms import Compose
5
  import numpy as np
6
  import logging
7
 
8
  logger = logging.getLogger(__name__)
9
 
10
- class GraphProcessor:
11
- """
12
- Advanced data preprocessing utilities
13
- Enterprise-grade with comprehensive validation
14
- """
15
 
16
  @staticmethod
17
- def normalize_features(x, method: str = 'l2'):
18
- """
19
- Normalize node features with validation
20
-
21
- Args:
22
- x: Feature tensor
23
- method: Normalization method ('l2', 'minmax', 'standard')
24
-
25
- Returns:
26
- Normalized feature tensor
27
- """
28
- if not isinstance(x, torch.Tensor):
29
- raise TypeError("x must be a torch.Tensor")
30
- if x.dim() != 2:
31
- raise ValueError("x must be a 2D tensor")
32
-
33
  try:
34
- if method == 'l2':
35
- # L2 normalization with numerical stability
36
- norms = torch.norm(x, p=2, dim=1, keepdim=True)
37
- norms = torch.clamp(norms, min=1e-8) # Avoid division by zero
38
- return x / norms
39
-
40
- elif method == 'minmax':
41
- # Min-max normalization
42
- x_min = x.min(dim=0, keepdim=True)[0]
43
- x_max = x.max(dim=0, keepdim=True)[0]
44
- x_range = x_max - x_min
45
- x_range = torch.clamp(x_range, min=1e-8) # Avoid division by zero
46
- return (x - x_min) / x_range
47
-
48
- elif method == 'standard':
49
- # Standard normalization (z-score)
50
- x_mean = x.mean(dim=0, keepdim=True)
51
- x_std = x.std(dim=0, keepdim=True)
52
- x_std = torch.clamp(x_std, min=1e-8) # Avoid division by zero
53
- return (x - x_mean) / x_std
54
-
55
  else:
56
- logger.warning(f"Unknown normalization method: {method}, returning original")
57
- return x
58
 
59
- except Exception as e:
60
- logger.error(f"Feature normalization failed: {e}")
61
- return x
62
-
63
- @staticmethod
64
- def add_self_loops(edge_index, num_nodes):
65
- """
66
- Add self loops to graph with validation
67
-
68
- Args:
69
- edge_index: Edge connectivity tensor
70
- num_nodes: Number of nodes
71
 
72
- Returns:
73
- Edge index with self loops
74
- """
75
- if not isinstance(edge_index, torch.Tensor):
76
- raise TypeError("edge_index must be a torch.Tensor")
77
- if edge_index.dim() != 2 or edge_index.size(0) != 2:
78
- raise ValueError("edge_index must have shape (2, num_edges)")
79
- if num_nodes <= 0:
80
- raise ValueError("num_nodes must be positive")
81
-
82
- try:
83
- device = edge_index.device
84
- self_loops = torch.arange(num_nodes, device=device).unsqueeze(0).repeat(2, 1)
85
-
86
- # Check if self loops already exist
87
- existing_self_loops = (edge_index[0] == edge_index[1]).any()
88
 
89
- if not existing_self_loops:
90
- edge_index = torch.cat([edge_index, self_loops], dim=1)
91
- logger.debug(f"Added {num_nodes} self loops")
92
-
93
- return edge_index
94
 
95
  except Exception as e:
96
- logger.error(f"Adding self loops failed: {e}")
97
- return edge_index
98
 
99
- @staticmethod
100
- def remove_self_loops(edge_index):
101
- """
102
- Remove self loops from graph
103
-
104
- Args:
105
- edge_index: Edge connectivity tensor
106
-
107
- Returns:
108
- Edge index without self loops
109
- """
110
- if not isinstance(edge_index, torch.Tensor):
111
- raise TypeError("edge_index must be a torch.Tensor")
112
- if edge_index.dim() != 2 or edge_index.size(0) != 2:
113
- raise ValueError("edge_index must have shape (2, num_edges)")
114
-
115
  try:
116
- mask = edge_index[0] != edge_index[1]
117
- filtered_edges = edge_index[:, mask]
118
-
119
- removed_count = edge_index.size(1) - filtered_edges.size(1)
120
- if removed_count > 0:
121
- logger.debug(f"Removed {removed_count} self loops")
 
122
 
123
- return filtered_edges
 
124
 
125
- except Exception as e:
126
- logger.error(f"Removing self loops failed: {e}")
127
- return edge_index
128
-
129
- @staticmethod
130
- def add_positional_features(data, encoding_dim: int = 8):
131
- """
132
- Add positional encodings as features with validation
133
-
134
- Args:
135
- data: PyTorch Geometric data object
136
- encoding_dim: Dimension of positional encoding
137
 
138
- Returns:
139
- Data object with enhanced features
140
- """
141
- if not hasattr(data, 'x') or not hasattr(data, 'edge_index'):
142
- raise ValueError("Data must have x and edge_index attributes")
143
-
144
- num_nodes = data.num_nodes
145
- encoding_dim = max(1, min(encoding_dim, num_nodes))
146
-
147
- try:
148
- # Random walk positional encoding
149
- if data.edge_index.size(1) > 0:
150
- # Create adjacency matrix
151
- adj = torch.zeros(num_nodes, num_nodes, dtype=torch.float)
152
- adj[data.edge_index[0], data.edge_index[1]] = 1.0
153
- adj = adj + adj.t() # Make symmetric
154
-
155
- # Compute degree
156
- degree = adj.sum(dim=1)
157
- degree = torch.clamp(degree, min=1e-8) # Avoid division by zero
158
 
159
- # Degree normalization
160
- D_inv_sqrt = torch.diag(1.0 / torch.sqrt(degree))
161
-
162
- # Normalized adjacency
163
- A_norm = D_inv_sqrt @ adj @ D_inv_sqrt
164
-
165
- # Check for numerical issues
166
- if torch.isnan(A_norm).any() or torch.isinf(A_norm).any():
167
- logger.warning("Numerical issues in adjacency matrix, using simple encoding")
168
- pos_encoding = torch.eye(num_nodes)[:, :encoding_dim]
169
- else:
170
- # Random walk features
171
- rw_features = []
172
- A_power = torch.eye(num_nodes)
173
-
174
- for k in range(encoding_dim):
175
- A_power = A_power @ A_norm
176
- rw_features.append(A_power.diag().unsqueeze(1))
177
-
178
- pos_encoding = torch.cat(rw_features, dim=1)
179
- else:
180
- # No edges - use one-hot encoding
181
- pos_encoding = torch.zeros(num_nodes, encoding_dim)
182
- for i in range(min(encoding_dim, num_nodes)):
183
- pos_encoding[i, i] = 1.0
184
-
185
- # Concatenate with existing features
186
- if data.x is not None:
187
- data.x = torch.cat([data.x, pos_encoding], dim=1)
188
- else:
189
- data.x = pos_encoding
190
-
191
- logger.debug(f"Added positional features of dimension {encoding_dim}")
192
 
193
  except Exception as e:
194
- logger.error(f"Adding positional features failed: {e}")
195
- # Don't modify data on failure
196
- pass
197
-
198
- return data
199
 
200
  @staticmethod
201
- def augment_graph(data, aug_type: str = 'edge_drop', aug_ratio: float = 0.1):
202
- """
203
- Graph augmentation for training with validation
204
-
205
- Args:
206
- data: PyTorch Geometric data object
207
- aug_type: Type of augmentation
208
- aug_ratio: Augmentation strength
209
-
210
- Returns:
211
- Augmented data object
212
- """
213
- if not (0.0 <= aug_ratio <= 1.0):
214
- raise ValueError("aug_ratio must be between 0 and 1")
215
-
216
- # Create a copy to avoid modifying original
217
- aug_data = data.clone()
218
-
219
  try:
220
- if aug_type == 'edge_drop':
221
- # Randomly drop edges
222
- if aug_data.edge_index.size(1) > 0:
223
- num_edges = aug_data.edge_index.size(1)
224
- mask = torch.rand(num_edges) > aug_ratio
225
- aug_data.edge_index = aug_data.edge_index[:, mask]
226
- logger.debug(f"Dropped {(~mask).sum()} edges")
227
-
228
- elif aug_type == 'node_drop':
229
- # Randomly drop nodes
230
- num_nodes = aug_data.num_nodes
231
- if num_nodes > 1:
232
- keep_mask = torch.rand(num_nodes) > aug_ratio
233
-
234
- # Ensure at least one node remains
235
- if not keep_mask.any():
236
- keep_mask[0] = True
237
-
238
- keep_nodes = torch.where(keep_mask)[0]
239
-
240
- # Update node features
241
- if aug_data.x is not None:
242
- aug_data.x = aug_data.x[keep_nodes]
243
-
244
- # Update labels if they exist and are node-level
245
- if hasattr(aug_data, 'y') and aug_data.y.size(0) == num_nodes:
246
- aug_data.y = aug_data.y[keep_nodes]
247
-
248
- # Update edge index
249
- if aug_data.edge_index.size(1) > 0:
250
- # Create node mapping
251
- node_map = torch.full((num_nodes,), -1, dtype=torch.long)
252
- node_map[keep_nodes] = torch.arange(len(keep_nodes))
253
-
254
- # Filter edges
255
- edge_mask = keep_mask[aug_data.edge_index[0]] & keep_mask[aug_data.edge_index[1]]
256
- if edge_mask.any():
257
- filtered_edges = aug_data.edge_index[:, edge_mask]
258
- aug_data.edge_index = node_map[filtered_edges]
259
- else:
260
- aug_data.edge_index = torch.empty((2, 0), dtype=torch.long)
261
-
262
- logger.debug(f"Kept {len(keep_nodes)} out of {num_nodes} nodes")
263
-
264
- elif aug_type == 'feature_noise':
265
- # Add Gaussian noise to features
266
- if aug_data.x is not None:
267
- noise = torch.randn_like(aug_data.x) * aug_ratio
268
- aug_data.x = aug_data.x + noise
269
- logger.debug(f"Added noise with std {aug_ratio}")
270
-
271
- elif aug_type == 'feature_mask':
272
- # Randomly mask features
273
- if aug_data.x is not None:
274
- mask = torch.rand_like(aug_data.x) > aug_ratio
275
- aug_data.x = aug_data.x * mask
276
- logger.debug(f"Masked {(~mask).sum()} feature values")
277
-
278
  else:
279
- logger.warning(f"Unknown augmentation type: {aug_type}")
280
 
281
- except Exception as e:
282
- logger.error(f"Graph augmentation failed: {e}")
283
- return data # Return original on failure
284
-
285
- return aug_data
286
-
287
- @staticmethod
288
- def to_device_safe(data, device):
289
- """
290
- Move data to device safely with validation
291
-
292
- Args:
293
- data: Data to move
294
- device: Target device
295
-
296
- Returns:
297
- Data on target device
298
- """
299
- try:
300
- if hasattr(data, 'to'):
301
- return data.to(device)
302
- elif isinstance(data, (list, tuple)):
303
- return [GraphProcessor.to_device_safe(item, device) for item in data]
304
- elif isinstance(data, dict):
305
- return {k: GraphProcessor.to_device_safe(v, device) for k, v in data.items()}
306
- else:
307
- return data
308
- except Exception as e:
309
- logger.error(f"Device transfer failed: {e}")
310
- return data
311
-
312
- @staticmethod
313
- def validate_data(data):
314
- """
315
- Validate graph data integrity with comprehensive checks
316
-
317
- Args:
318
- data: PyTorch Geometric data object
319
-
320
- Returns:
321
- List of validation errors (empty if valid)
322
- """
323
- errors = []
324
-
325
- try:
326
- # Check basic structure
327
- if not hasattr(data, 'edge_index'):
328
- errors.append("Missing edge_index attribute")
329
- elif not isinstance(data.edge_index, torch.Tensor):
330
- errors.append("edge_index must be a tensor")
331
- elif data.edge_index.dim() != 2 or data.edge_index.size(0) != 2:
332
- errors.append("edge_index must have shape (2, num_edges)")
333
 
334
- # Check node features
335
- if hasattr(data, 'x') and data.x is not None:
336
- if not isinstance(data.x, torch.Tensor):
337
- errors.append("Node features x must be a tensor")
338
- elif data.x.dim() != 2:
339
- errors.append("Node features x must be 2D")
340
- elif hasattr(data, 'num_nodes') and data.x.size(0) != data.num_nodes:
341
- errors.append("Feature matrix size mismatch with num_nodes")
342
 
343
- # Check labels
344
- if hasattr(data, 'y') and data.y is not None:
345
- if not isinstance(data.y, torch.Tensor):
346
- errors.append("Labels y must be a tensor")
347
 
348
- # Check edge indices bounds
349
- if hasattr(data, 'edge_index') and data.edge_index.size(1) > 0:
350
- max_idx = data.edge_index.max().item()
351
- min_idx = data.edge_index.min().item()
352
-
353
- if min_idx < 0:
354
- errors.append("Edge indices contain negative values")
355
 
356
- if hasattr(data, 'num_nodes') and max_idx >= data.num_nodes:
357
- errors.append("Edge indices exceed number of nodes")
358
-
359
- # Check for NaN or infinite values
360
- if hasattr(data, 'x') and data.x is not None:
361
- if torch.isnan(data.x).any():
362
- errors.append("Node features contain NaN values")
363
- if torch.isinf(data.x).any():
364
- errors.append("Node features contain infinite values")
365
-
366
- # Check data types
367
- if hasattr(data, 'edge_index'):
368
- if data.edge_index.dtype not in [torch.long, torch.int]:
369
- errors.append("edge_index must have integer dtype")
370
 
371
  except Exception as e:
372
- errors.append(f"Validation error: {str(e)}")
373
-
374
- if errors:
375
- logger.warning(f"Data validation found {len(errors)} errors")
376
-
377
- return errors
378
-
379
  @staticmethod
380
- def repair_data(data):
381
- """
382
- Attempt to repair common data issues
383
-
384
- Args:
385
- data: PyTorch Geometric data object
386
-
387
- Returns:
388
- Repaired data object
389
- """
390
  try:
391
- # Fix edge index dtype
392
- if hasattr(data, 'edge_index') and data.edge_index.dtype not in [torch.long, torch.int]:
393
- data.edge_index = data.edge_index.long()
394
- logger.info("Fixed edge_index dtype")
395
-
396
- # Remove invalid edges
397
- if hasattr(data, 'edge_index') and hasattr(data, 'num_nodes'):
398
- valid_mask = (
399
- (data.edge_index[0] >= 0) & (data.edge_index[0] < data.num_nodes) &
400
- (data.edge_index[1] >= 0) & (data.edge_index[1] < data.num_nodes)
401
- )
402
 
403
- if not valid_mask.all():
404
- data.edge_index = data.edge_index[:, valid_mask]
405
- logger.info(f"Removed {(~valid_mask).sum()} invalid edges")
 
 
 
 
 
406
 
407
- # Handle NaN values in features
408
- if hasattr(data, 'x') and data.x is not None:
409
- if torch.isnan(data.x).any():
410
- data.x = torch.where(torch.isnan(data.x), torch.zeros_like(data.x), data.x)
411
- logger.info("Replaced NaN values in features with zeros")
412
 
413
- if torch.isinf(data.x).any():
414
- data.x = torch.where(torch.isinf(data.x), torch.zeros_like(data.x), data.x)
415
- logger.info("Replaced infinite values in features with zeros")
416
 
417
  except Exception as e:
418
- logger.error(f"Data repair failed: {e}")
419
-
420
- return data
 
1
  import torch
2
  import torch.nn.functional as F
3
+ from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, classification_report, precision_score, recall_score
 
4
  import numpy as np
5
  import logging
6
 
7
  logger = logging.getLogger(__name__)
8
 
9
+ class GraphMetrics:
10
+ """Comprehensive evaluation metrics for graph learning"""
 
 
 
11
 
12
  @staticmethod
13
+ def accuracy(pred, target):
14
+ """Classification accuracy with validation"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  try:
16
+ if pred.dim() > 1:
17
+ pred_labels = pred.argmax(dim=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  else:
19
+ pred_labels = pred
 
20
 
21
+ if pred_labels.shape != target.shape:
22
+ raise ValueError("Prediction and target shapes don't match")
 
 
 
 
 
 
 
 
 
 
23
 
24
+ correct = (pred_labels == target).float()
25
+ accuracy = correct.mean().item()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
+ if torch.isnan(torch.tensor(accuracy)) or torch.isinf(torch.tensor(accuracy)):
28
+ logger.warning("Invalid accuracy computed, returning 0.0")
29
+ return 0.0
30
+
31
+ return accuracy
32
 
33
  except Exception as e:
34
+ logger.error(f"Accuracy computation failed: {e}")
35
+ return 0.0
36
 
37
+ @staticmethod
38
+ def f1_score_macro(pred, target):
39
+ """Macro F1 score with robust error handling"""
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  try:
41
+ if pred.dim() > 1:
42
+ pred_labels = pred.argmax(dim=1)
43
+ else:
44
+ pred_labels = pred
45
+
46
+ pred_labels = pred_labels.cpu().numpy()
47
+ target_labels = target.cpu().numpy()
48
 
49
+ if len(pred_labels) == 0 or len(target_labels) == 0:
50
+ return 0.0
51
 
52
+ f1 = f1_score(target_labels, pred_labels, average='macro', zero_division=0)
 
 
 
 
 
 
 
 
 
 
 
53
 
54
+ if np.isnan(f1) or np.isinf(f1):
55
+ logger.warning("Invalid F1 macro score, returning 0.0")
56
+ return 0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
+ return float(f1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  except Exception as e:
61
+ logger.error(f"F1 macro computation failed: {e}")
62
+ return 0.0
 
 
 
63
 
64
  @staticmethod
65
+ def f1_score_micro(pred, target):
66
+ """Micro F1 score with robust error handling"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  try:
68
+ if pred.dim() > 1:
69
+ pred_labels = pred.argmax(dim=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  else:
71
+ pred_labels = pred
72
 
73
+ pred_labels = pred_labels.cpu().numpy()
74
+ target_labels = target.cpu().numpy()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
+ if len(pred_labels) == 0 or len(target_labels) == 0:
77
+ return 0.0
 
 
 
 
 
 
78
 
79
+ f1 = f1_score(target_labels, pred_labels, average='micro', zero_division=0)
 
 
 
80
 
81
+ if np.isnan(f1) or np.isinf(f1):
82
+ logger.warning("Invalid F1 micro score, returning 0.0")
83
+ return 0.0
 
 
 
 
84
 
85
+ return float(f1)
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  except Exception as e:
88
+ logger.error(f"F1 micro computation failed: {e}")
89
+ return 0.0
90
+
 
 
 
 
91
  @staticmethod
92
+ def precision_recall(pred, target, average='macro'):
93
+ """Compute precision and recall scores"""
 
 
 
 
 
 
 
 
94
  try:
95
+ if pred.dim() > 1:
96
+ pred_labels = pred.argmax(dim=1)
97
+ else:
98
+ pred_labels = pred
 
 
 
 
 
 
 
99
 
100
+ pred_labels = pred_labels.cpu().numpy()
101
+ target_labels = target.cpu().numpy()
102
+
103
+ if len(pred_labels) == 0 or len(target_labels) == 0:
104
+ return 0.0, 0.0
105
+
106
+ precision = precision_score(target_labels, pred_labels, average=average, zero_division=0)
107
+ recall = recall_score(target_labels, pred_labels, average=average, zero_division=0)
108
 
109
+ if np.isnan(precision) or np.isinf(precision):
110
+ precision = 0.0
111
+ if np.isnan(recall) or np.isinf(recall):
112
+ recall = 0.0
 
113
 
114
+ return float(precision), float(recall)
 
 
115
 
116
  except Exception as e:
117
+ logger.error(f"Precision/recall computation failed: {e}")
118
+ return 0.0, 0.0