kfoughali committed
Commit 972fdf4 · verified · 1 parent: c3b0dee

Update utils/metrics.py

Files changed (1)
  1. utils/metrics.py +382 -176
utils/metrics.py CHANGED
@@ -1,214 +1,420 @@
 import torch
 import torch.nn.functional as F
-from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, classification_report
+from torch_geometric.data import Data
+from torch_geometric.transforms import Compose
 import numpy as np
+import logging

-class GraphMetrics:
-    """Comprehensive evaluation metrics for graph learning"""
+logger = logging.getLogger(__name__)
+
+class GraphProcessor:
+    """
+    Advanced data preprocessing utilities
+    Enterprise-grade with comprehensive validation
+    """

     @staticmethod
-    def accuracy(pred, target):
-        """Classification accuracy"""
-        if pred.dim() > 1:
-            pred_labels = pred.argmax(dim=1)
-        else:
-            pred_labels = pred
-        return (pred_labels == target).float().mean().item()
-
-    @staticmethod
-    def f1_score_macro(pred, target):
-        """Macro F1 score"""
+    def normalize_features(x, method: str = 'l2'):
+        """
+        Normalize node features with validation
+
+        Args:
+            x: Feature tensor
+            method: Normalization method ('l2', 'minmax', 'standard')
+
+        Returns:
+            Normalized feature tensor
+        """
+        if not isinstance(x, torch.Tensor):
+            raise TypeError("x must be a torch.Tensor")
+        if x.dim() != 2:
+            raise ValueError("x must be a 2D tensor")
+
         try:
-            if pred.dim() > 1:
-                pred_labels = pred.argmax(dim=1)
+            if method == 'l2':
+                # L2 normalization with numerical stability
+                norms = torch.norm(x, p=2, dim=1, keepdim=True)
+                norms = torch.clamp(norms, min=1e-8)  # Avoid division by zero
+                return x / norms
+
+            elif method == 'minmax':
+                # Min-max normalization
+                x_min = x.min(dim=0, keepdim=True)[0]
+                x_max = x.max(dim=0, keepdim=True)[0]
+                x_range = x_max - x_min
+                x_range = torch.clamp(x_range, min=1e-8)  # Avoid division by zero
+                return (x - x_min) / x_range
+
+            elif method == 'standard':
+                # Standard normalization (z-score)
+                x_mean = x.mean(dim=0, keepdim=True)
+                x_std = x.std(dim=0, keepdim=True)
+                x_std = torch.clamp(x_std, min=1e-8)  # Avoid division by zero
+                return (x - x_mean) / x_std
+
             else:
-                pred_labels = pred
-            pred_labels = pred_labels.cpu().numpy()
-            target_labels = target.cpu().numpy()
-            return f1_score(target_labels, pred_labels, average='macro', zero_division=0)
-        except:
-            return 0.0
+                logger.warning(f"Unknown normalization method: {method}, returning original")
+                return x
+
+        except Exception as e:
+            logger.error(f"Feature normalization failed: {e}")
+            return x

     @staticmethod
-    def f1_score_micro(pred, target):
-        """Micro F1 score"""
+    def add_self_loops(edge_index, num_nodes):
+        """
+        Add self loops to graph with validation
+
+        Args:
+            edge_index: Edge connectivity tensor
+            num_nodes: Number of nodes
+
+        Returns:
+            Edge index with self loops
+        """
+        if not isinstance(edge_index, torch.Tensor):
+            raise TypeError("edge_index must be a torch.Tensor")
+        if edge_index.dim() != 2 or edge_index.size(0) != 2:
+            raise ValueError("edge_index must have shape (2, num_edges)")
+        if num_nodes <= 0:
+            raise ValueError("num_nodes must be positive")
+
         try:
-            if pred.dim() > 1:
-                pred_labels = pred.argmax(dim=1)
-            else:
-                pred_labels = pred
-            pred_labels = pred_labels.cpu().numpy()
-            target_labels = target.cpu().numpy()
-            return f1_score(target_labels, pred_labels, average='micro', zero_division=0)
-        except:
-            return 0.0
+            device = edge_index.device
+            self_loops = torch.arange(num_nodes, device=device).unsqueeze(0).repeat(2, 1)
+
+            # Check if self loops already exist
+            existing_self_loops = (edge_index[0] == edge_index[1]).any()
+
+            if not existing_self_loops:
+                edge_index = torch.cat([edge_index, self_loops], dim=1)
+                logger.debug(f"Added {num_nodes} self loops")
+
+            return edge_index
+
+        except Exception as e:
+            logger.error(f"Adding self loops failed: {e}")
+            return edge_index

     @staticmethod
-    def roc_auc(pred, target):
-        """ROC AUC score"""
+    def remove_self_loops(edge_index):
+        """
+        Remove self loops from graph
+
+        Args:
+            edge_index: Edge connectivity tensor
+
+        Returns:
+            Edge index without self loops
+        """
+        if not isinstance(edge_index, torch.Tensor):
+            raise TypeError("edge_index must be a torch.Tensor")
+        if edge_index.dim() != 2 or edge_index.size(0) != 2:
+            raise ValueError("edge_index must have shape (2, num_edges)")
+
         try:
-            if pred.dim() > 1:
-                # Multi-class
-                pred_probs = F.softmax(pred, dim=1).cpu().numpy()
-                target_onehot = F.one_hot(target, num_classes=pred.size(1)).cpu().numpy()
-                return roc_auc_score(target_onehot, pred_probs, multi_class='ovr', average='macro')
-            else:
-                # Binary
-                pred_probs = torch.sigmoid(pred).cpu().numpy()
-                target_labels = target.cpu().numpy()
-                return roc_auc_score(target_labels, pred_probs)
-        except:
-            return 0.0
+            mask = edge_index[0] != edge_index[1]
+            filtered_edges = edge_index[:, mask]
+
+            removed_count = edge_index.size(1) - filtered_edges.size(1)
+            if removed_count > 0:
+                logger.debug(f"Removed {removed_count} self loops")
+
+            return filtered_edges
+
+        except Exception as e:
+            logger.error(f"Removing self loops failed: {e}")
+            return edge_index

     @staticmethod
-    def classification_report_dict(pred, target):
-        """Detailed classification report"""
+    def add_positional_features(data, encoding_dim: int = 8):
+        """
+        Add positional encodings as features with validation
+
+        Args:
+            data: PyTorch Geometric data object
+            encoding_dim: Dimension of positional encoding
+
+        Returns:
+            Data object with enhanced features
+        """
+        if not hasattr(data, 'x') or not hasattr(data, 'edge_index'):
+            raise ValueError("Data must have x and edge_index attributes")
+
+        num_nodes = data.num_nodes
+        encoding_dim = max(1, min(encoding_dim, num_nodes))
+
         try:
-            if pred.dim() > 1:
-                pred_labels = pred.argmax(dim=1)
+            # Random walk positional encoding
+            if data.edge_index.size(1) > 0:
+                # Create adjacency matrix
+                adj = torch.zeros(num_nodes, num_nodes, dtype=torch.float)
+                adj[data.edge_index[0], data.edge_index[1]] = 1.0
+                adj = adj + adj.t()  # Make symmetric
+
+                # Compute degree
+                degree = adj.sum(dim=1)
+                degree = torch.clamp(degree, min=1e-8)  # Avoid division by zero
+
+                # Degree normalization
+                D_inv_sqrt = torch.diag(1.0 / torch.sqrt(degree))
+
+                # Normalized adjacency
+                A_norm = D_inv_sqrt @ adj @ D_inv_sqrt
+
+                # Check for numerical issues
+                if torch.isnan(A_norm).any() or torch.isinf(A_norm).any():
+                    logger.warning("Numerical issues in adjacency matrix, using simple encoding")
+                    pos_encoding = torch.eye(num_nodes)[:, :encoding_dim]
+                else:
+                    # Random walk features
+                    rw_features = []
+                    A_power = torch.eye(num_nodes)
+
+                    for k in range(encoding_dim):
+                        A_power = A_power @ A_norm
+                        rw_features.append(A_power.diag().unsqueeze(1))
+
+                    pos_encoding = torch.cat(rw_features, dim=1)
+            else:
+                # No edges - use one-hot encoding
+                pos_encoding = torch.zeros(num_nodes, encoding_dim)
+                for i in range(min(encoding_dim, num_nodes)):
+                    pos_encoding[i, i] = 1.0
+
+            # Concatenate with existing features
+            if data.x is not None:
+                data.x = torch.cat([data.x, pos_encoding], dim=1)
             else:
-                pred_labels = pred
-            pred_labels = pred_labels.cpu().numpy()
-            target_labels = target.cpu().numpy()
-
-            report = classification_report(target_labels, pred_labels, output_dict=True, zero_division=0)
-            return report
-        except:
-            return {}
+                data.x = pos_encoding
+
+            logger.debug(f"Added positional features of dimension {encoding_dim}")
+
+        except Exception as e:
+            logger.error(f"Adding positional features failed: {e}")
+            # Don't modify data on failure
+            pass
+
+        return data

     @staticmethod
-    def evaluate_node_classification(model, data, mask, device, detailed=False):
-        """Comprehensive node classification evaluation"""
-        model.eval()
+    def augment_graph(data, aug_type: str = 'edge_drop', aug_ratio: float = 0.1):
+        """
+        Graph augmentation for training with validation
+
+        Args:
+            data: PyTorch Geometric data object
+            aug_type: Type of augmentation
+            aug_ratio: Augmentation strength
+
+        Returns:
+            Augmented data object
+        """
+        if not (0.0 <= aug_ratio <= 1.0):
+            raise ValueError("aug_ratio must be between 0 and 1")
+
+        # Create a copy to avoid modifying original
+        aug_data = data.clone()

         try:
-            with torch.no_grad():
-                # Ensure data is on correct device
-                data = data.to(device)
-                model = model.to(device)
-
-                h = model(data.x, data.edge_index)
+            if aug_type == 'edge_drop':
+                # Randomly drop edges
+                if aug_data.edge_index.size(1) > 0:
+                    num_edges = aug_data.edge_index.size(1)
+                    mask = torch.rand(num_edges) > aug_ratio
+                    aug_data.edge_index = aug_data.edge_index[:, mask]
+                    logger.debug(f"Dropped {(~mask).sum()} edges")

-                # Get predictions
-                if hasattr(model, 'classifier') and model.classifier is not None:
-                    pred = model.classifier(h)
-                else:
-                    # Initialize classifier if needed
-                    num_classes = len(torch.unique(data.y))
-                    model._init_classifier(num_classes, device)
-                    pred = model.classifier(h)
-
-                pred_masked = pred[mask]
-                target_masked = data.y[mask]
+            elif aug_type == 'node_drop':
+                # Randomly drop nodes
+                num_nodes = aug_data.num_nodes
+                if num_nodes > 1:
+                    keep_mask = torch.rand(num_nodes) > aug_ratio
+
+                    # Ensure at least one node remains
+                    if not keep_mask.any():
+                        keep_mask[0] = True
+
+                    keep_nodes = torch.where(keep_mask)[0]
+
+                    # Update node features
+                    if aug_data.x is not None:
+                        aug_data.x = aug_data.x[keep_nodes]
+
+                    # Update labels if they exist and are node-level
+                    if hasattr(aug_data, 'y') and aug_data.y.size(0) == num_nodes:
+                        aug_data.y = aug_data.y[keep_nodes]
+
+                    # Update edge index
+                    if aug_data.edge_index.size(1) > 0:
+                        # Create node mapping
+                        node_map = torch.full((num_nodes,), -1, dtype=torch.long)
+                        node_map[keep_nodes] = torch.arange(len(keep_nodes))
+
+                        # Filter edges
+                        edge_mask = keep_mask[aug_data.edge_index[0]] & keep_mask[aug_data.edge_index[1]]
+                        if edge_mask.any():
+                            filtered_edges = aug_data.edge_index[:, edge_mask]
+                            aug_data.edge_index = node_map[filtered_edges]
+                        else:
+                            aug_data.edge_index = torch.empty((2, 0), dtype=torch.long)
+
+                    logger.debug(f"Kept {len(keep_nodes)} out of {num_nodes} nodes")

-                metrics = {
-                    'accuracy': GraphMetrics.accuracy(pred_masked, target_masked),
-                    'f1_macro': GraphMetrics.f1_score_macro(pred_masked, target_masked),
-                    'f1_micro': GraphMetrics.f1_score_micro(pred_masked, target_masked),
-                }
+            elif aug_type == 'feature_noise':
+                # Add Gaussian noise to features
+                if aug_data.x is not None:
+                    noise = torch.randn_like(aug_data.x) * aug_ratio
+                    aug_data.x = aug_data.x + noise
+                    logger.debug(f"Added noise with std {aug_ratio}")

-                # Add detailed metrics if requested
-                if detailed:
-                    metrics['roc_auc'] = GraphMetrics.roc_auc(pred_masked, target_masked)
-                    metrics['classification_report'] = GraphMetrics.classification_report_dict(pred_masked, target_masked)
+            elif aug_type == 'feature_mask':
+                # Randomly mask features
+                if aug_data.x is not None:
+                    mask = torch.rand_like(aug_data.x) > aug_ratio
+                    aug_data.x = aug_data.x * mask
+                    logger.debug(f"Masked {(~mask).sum()} feature values")

-                # Add loss
-                criterion = torch.nn.CrossEntropyLoss()
-                metrics['loss'] = criterion(pred_masked, target_masked).item()
+            else:
+                logger.warning(f"Unknown augmentation type: {aug_type}")

         except Exception as e:
-            print(f"Evaluation error: {e}")
-            metrics = {
-                'accuracy': 0.0,
-                'f1_macro': 0.0,
-                'f1_micro': 0.0,
-                'loss': float('inf'),
-                'error': str(e)
-            }
-
-        return metrics
+            logger.error(f"Graph augmentation failed: {e}")
+            return data  # Return original on failure
+
+        return aug_data

     @staticmethod
-    def evaluate_graph_classification(model, dataloader, device, detailed=False):
-        """Comprehensive graph classification evaluation"""
-        model.eval()
+    def to_device_safe(data, device):
+        """
+        Move data to device safely with validation

-        all_preds = []
-        all_targets = []
-        total_loss = 0.0
-        num_batches = 0
+        Args:
+            data: Data to move
+            device: Target device
+
+        Returns:
+            Data on target device
+        """
+        try:
+            if hasattr(data, 'to'):
+                return data.to(device)
+            elif isinstance(data, (list, tuple)):
+                return [GraphProcessor.to_device_safe(item, device) for item in data]
+            elif isinstance(data, dict):
+                return {k: GraphProcessor.to_device_safe(v, device) for k, v in data.items()}
+            else:
+                return data
+        except Exception as e:
+            logger.error(f"Device transfer failed: {e}")
+            return data
+
+    @staticmethod
+    def validate_data(data):
+        """
+        Validate graph data integrity with comprehensive checks
+
+        Args:
+            data: PyTorch Geometric data object
+
+        Returns:
+            List of validation errors (empty if valid)
+        """
+        errors = []

         try:
-            criterion = torch.nn.CrossEntropyLoss()
+            # Check basic structure
+            if not hasattr(data, 'edge_index'):
+                errors.append("Missing edge_index attribute")
+            elif not isinstance(data.edge_index, torch.Tensor):
+                errors.append("edge_index must be a tensor")
+            elif data.edge_index.dim() != 2 or data.edge_index.size(0) != 2:
+                errors.append("edge_index must have shape (2, num_edges)")

-            with torch.no_grad():
-                for batch in dataloader:
-                    batch = batch.to(device)
-                    h = model(batch.x, batch.edge_index, batch.batch)
-
-                    # Graph-level prediction
-                    graph_h = model.get_graph_embedding(h, batch.batch)
-
-                    if hasattr(model, 'classifier') and model.classifier is not None:
-                        pred = model.classifier(graph_h)
-                    else:
-                        # Initialize classifier
-                        num_classes = len(torch.unique(batch.y))
-                        model._init_classifier(num_classes, device)
-                        pred = model.classifier(graph_h)
-
-                    all_preds.append(pred.cpu())
-                    all_targets.append(batch.y.cpu())
-
-                    # Calculate loss
-                    loss = criterion(pred, batch.y)
-                    total_loss += loss.item()
-                    num_batches += 1
-
-            if all_preds:
-                all_preds = torch.cat(all_preds, dim=0)
-                all_targets = torch.cat(all_targets, dim=0)
-
-                metrics = {
-                    'accuracy': GraphMetrics.accuracy(all_preds, all_targets),
-                    'f1_macro': GraphMetrics.f1_score_macro(all_preds, all_targets),
-                    'f1_micro': GraphMetrics.f1_score_micro(all_preds, all_targets),
-                    'loss': total_loss / max(num_batches, 1)
-                }
-
-                if detailed:
-                    metrics['roc_auc'] = GraphMetrics.roc_auc(all_preds, all_targets)
-                    metrics['classification_report'] = GraphMetrics.classification_report_dict(all_preds, all_targets)
-            else:
-                metrics = {'error': 'No predictions generated'}
+            # Check node features
+            if hasattr(data, 'x') and data.x is not None:
+                if not isinstance(data.x, torch.Tensor):
+                    errors.append("Node features x must be a tensor")
+                elif data.x.dim() != 2:
+                    errors.append("Node features x must be 2D")
+                elif hasattr(data, 'num_nodes') and data.x.size(0) != data.num_nodes:
+                    errors.append("Feature matrix size mismatch with num_nodes")
+
+            # Check labels
+            if hasattr(data, 'y') and data.y is not None:
+                if not isinstance(data.y, torch.Tensor):
+                    errors.append("Labels y must be a tensor")
+
+            # Check edge indices bounds
+            if hasattr(data, 'edge_index') and data.edge_index.size(1) > 0:
+                max_idx = data.edge_index.max().item()
+                min_idx = data.edge_index.min().item()

+                if min_idx < 0:
+                    errors.append("Edge indices contain negative values")
+
+                if hasattr(data, 'num_nodes') and max_idx >= data.num_nodes:
+                    errors.append("Edge indices exceed number of nodes")
+
+            # Check for NaN or infinite values
+            if hasattr(data, 'x') and data.x is not None:
+                if torch.isnan(data.x).any():
+                    errors.append("Node features contain NaN values")
+                if torch.isinf(data.x).any():
+                    errors.append("Node features contain infinite values")
+
+            # Check data types
+            if hasattr(data, 'edge_index'):
+                if data.edge_index.dtype not in [torch.long, torch.int]:
+                    errors.append("edge_index must have integer dtype")
+
         except Exception as e:
-            print(f"Graph classification evaluation error: {e}")
-            metrics = {
-                'accuracy': 0.0,
-                'f1_macro': 0.0,
-                'f1_micro': 0.0,
-                'loss': float('inf'),
-                'error': str(e)
-            }
-
-        return metrics
+            errors.append(f"Validation error: {str(e)}")
+
+        if errors:
+            logger.warning(f"Data validation found {len(errors)} errors")
+
+        return errors

     @staticmethod
-    def compare_models(results_dict):
-        """Compare multiple model results"""
-        comparison = {}
-
-        for model_name, metrics in results_dict.items():
-            comparison[model_name] = {
-                'accuracy': metrics.get('accuracy', 0.0),
-                'f1_macro': metrics.get('f1_macro', 0.0),
-                'f1_micro': metrics.get('f1_micro', 0.0),
-                'loss': metrics.get('loss', float('inf'))
-            }
-
-        # Find best performing model
-        best_model = max(comparison.keys(), key=lambda k: comparison[k]['accuracy'])
-        comparison['best_model'] = best_model
-
-        return comparison
+    def repair_data(data):
+        """
+        Attempt to repair common data issues
+
+        Args:
+            data: PyTorch Geometric data object
+
+        Returns:
+            Repaired data object
+        """
+        try:
+            # Fix edge index dtype
+            if hasattr(data, 'edge_index') and data.edge_index.dtype not in [torch.long, torch.int]:
+                data.edge_index = data.edge_index.long()
+                logger.info("Fixed edge_index dtype")
+
+            # Remove invalid edges
+            if hasattr(data, 'edge_index') and hasattr(data, 'num_nodes'):
+                valid_mask = (
+                    (data.edge_index[0] >= 0) & (data.edge_index[0] < data.num_nodes) &
+                    (data.edge_index[1] >= 0) & (data.edge_index[1] < data.num_nodes)
+                )
+
+                if not valid_mask.all():
+                    data.edge_index = data.edge_index[:, valid_mask]
+                    logger.info(f"Removed {(~valid_mask).sum()} invalid edges")
+
+            # Handle NaN values in features
+            if hasattr(data, 'x') and data.x is not None:
+                if torch.isnan(data.x).any():
+                    data.x = torch.where(torch.isnan(data.x), torch.zeros_like(data.x), data.x)
+                    logger.info("Replaced NaN values in features with zeros")
+
+                if torch.isinf(data.x).any():
+                    data.x = torch.where(torch.isinf(data.x), torch.zeros_like(data.x), data.x)
+                    logger.info("Replaced infinite values in features with zeros")
+
+        except Exception as e:
+            logger.error(f"Data repair failed: {e}")
+
+        return data
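
A minimal usage sketch of the GraphProcessor API introduced above, assuming torch_geometric is installed and the repo root is on sys.path. Only the GraphProcessor class, its method names, and the module path utils/metrics.py come from this commit; the toy Data object and the chosen arguments are illustrative.

import torch
from torch_geometric.data import Data

from utils.metrics import GraphProcessor  # module updated by this commit

# Toy two-node graph with one directed edge (illustrative values only)
data = Data(
    x=torch.tensor([[1.0, 2.0], [3.0, 4.0]]),
    edge_index=torch.tensor([[0], [1]], dtype=torch.long),
    y=torch.tensor([0, 1]),
)

# Validate first; repair_data fixes dtype issues, out-of-range edges, and NaN/inf features
errors = GraphProcessor.validate_data(data)
if errors:
    data = GraphProcessor.repair_data(data)

# Feature preprocessing: z-score normalization, then random-walk positional encodings
data.x = GraphProcessor.normalize_features(data.x, method='standard')
data = GraphProcessor.add_positional_features(data, encoding_dim=2)

# Structural edit: returns a new edge_index rather than mutating the old one
data.edge_index = GraphProcessor.add_self_loops(data.edge_index, data.num_nodes)

# Training-time augmentation operates on a clone, leaving `data` itself untouched
aug = GraphProcessor.augment_graph(data, aug_type='edge_drop', aug_ratio=0.1)

# Move everything (Data objects, lists, dicts) to the training device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = GraphProcessor.to_device_safe(data, device)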