kfoughali committed on
Commit
ed6db0b
·
verified ·
1 Parent(s): 5f6f777

Update utils/visualization.py

Browse files
Files changed (1) hide show
  1. utils/visualization.py +223 -674
utils/visualization.py CHANGED
@@ -11,31 +11,15 @@ import logging
11
  logger = logging.getLogger(__name__)
12
 
13
  class GraphVisualizer:
14
- """
15
- Advanced graph visualization utilities
16
- Enterprise-grade with comprehensive error handling
17
- """
18
 
19
  @staticmethod
20
  def create_graph_plot(data, max_nodes=500, layout_algorithm='spring', node_size_factor=1.0):
21
- """
22
- Create interactive graph visualization with robust error handling
23
-
24
- Args:
25
- data: PyTorch Geometric data object
26
- max_nodes: Maximum number of nodes to display
27
- layout_algorithm: Layout algorithm ('spring', 'circular', 'kamada_kawai', 'spectral')
28
- node_size_factor: Factor to scale node sizes
29
-
30
- Returns:
31
- Plotly figure object
32
- """
33
  try:
34
- # Validate inputs
35
  if not hasattr(data, 'edge_index') or not hasattr(data, 'num_nodes'):
36
  raise ValueError("Data must have edge_index and num_nodes attributes")
37
 
38
- # Limit nodes for performance
39
  num_nodes = min(data.num_nodes, max_nodes)
40
  if num_nodes <= 0:
41
  raise ValueError("No nodes to visualize")
@@ -45,8 +29,6 @@ class GraphVisualizer:
45
 
46
  if data.edge_index.size(1) > 0:
47
  edge_list = data.edge_index.t().cpu().numpy()
48
-
49
- # Filter edges to include only first max_nodes
50
  edge_list = edge_list[
51
  (edge_list[:, 0] < num_nodes) & (edge_list[:, 1] < num_nodes)
52
  ]
@@ -54,25 +36,30 @@ class GraphVisualizer:
54
  if len(edge_list) > 0:
55
  G.add_edges_from(edge_list)
56
 
57
- # Add isolated nodes
58
  G.add_nodes_from(range(num_nodes))
59
 
60
- # Choose layout algorithm with fallback
61
- pos = GraphVisualizer._compute_layout(G, layout_algorithm)
62
 
63
- # Node colors and sizes
64
- node_colors, color_map = GraphVisualizer._get_node_colors(data, num_nodes)
65
- node_sizes = GraphVisualizer._get_node_sizes(G, node_size_factor)
 
 
66
 
67
  # Create edge traces
68
- edge_x, edge_y, edge_info = GraphVisualizer._create_edge_traces(G, pos)
 
 
 
 
 
 
69
 
70
  # Create node traces
71
- node_x, node_y, node_text, node_info = GraphVisualizer._create_node_traces(
72
- G, pos, node_colors, data
73
- )
74
 
75
- # Create figure
76
  fig = go.Figure()
77
 
78
  # Add edges
@@ -82,7 +69,6 @@ class GraphVisualizer:
82
  line=dict(width=0.8, color='rgba(125,125,125,0.5)'),
83
  hoverinfo='none',
84
  mode='lines',
85
- name='Edges',
86
  showlegend=False
87
  ))
88
 
@@ -90,41 +76,23 @@ class GraphVisualizer:
90
  fig.add_trace(go.Scatter(
91
  x=node_x, y=node_y,
92
  mode='markers',
93
- hoverinfo='text',
94
- hovertext=node_info,
95
- text=node_text,
96
  marker=dict(
97
- size=node_sizes,
98
  color=node_colors,
99
  colorscale='Viridis',
100
  line=dict(width=2, color='white'),
101
- opacity=0.8,
102
- colorbar=dict(title="Node Label") if hasattr(data, 'y') and data.y is not None else None
103
  ),
104
- name='Nodes',
 
105
  showlegend=False
106
  ))
107
 
108
- # Update layout
109
  fig.update_layout(
110
- title=dict(
111
- text=f'Graph Visualization ({num_nodes} nodes, {len(edge_x)//3 if edge_x else 0} edges)',
112
- x=0.5,
113
- font=dict(size=16)
114
- ),
115
  showlegend=False,
116
  hovermode='closest',
117
  margin=dict(b=20, l=5, r=5, t=40),
118
- annotations=[
119
- dict(
120
- text=f"Layout: {layout_algorithm.title()}",
121
- showarrow=False,
122
- xref="paper", yref="paper",
123
- x=0.005, y=-0.002,
124
- xanchor='left', yanchor='bottom',
125
- font=dict(color="gray", size=10)
126
- )
127
- ],
128
  xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
129
  yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
130
  plot_bgcolor='white',
@@ -138,126 +106,10 @@ class GraphVisualizer:
138
  logger.error(f"Graph visualization error: {e}")
139
  return GraphVisualizer._create_error_figure(f"Visualization error: {str(e)}")
140
 
141
- @staticmethod
142
- def _compute_layout(G, algorithm):
143
- """Compute graph layout with fallback options"""
144
- try:
145
- if algorithm == 'spring':
146
- if len(G.nodes()) > 100:
147
- return nx.spring_layout(G, k=0.5, iterations=20, seed=42)
148
- else:
149
- return nx.spring_layout(G, k=1, iterations=50, seed=42)
150
- elif algorithm == 'circular':
151
- return nx.circular_layout(G)
152
- elif algorithm == 'kamada_kawai':
153
- if len(G.nodes()) <= 500: # Too slow for large graphs
154
- return nx.kamada_kawai_layout(G)
155
- else:
156
- raise ValueError("Too many nodes for Kamada-Kawai")
157
- elif algorithm == 'spectral':
158
- if len(G.edges()) > 0:
159
- return nx.spectral_layout(G)
160
- else:
161
- return nx.circular_layout(G)
162
- else:
163
- return nx.spring_layout(G, seed=42)
164
- except Exception as e:
165
- logger.warning(f"Layout algorithm {algorithm} failed: {e}, using spring layout")
166
- return nx.spring_layout(G, seed=42)
167
-
168
- @staticmethod
169
- def _get_node_colors(data, num_nodes):
170
- """Get node colors based on labels"""
171
- try:
172
- if hasattr(data, 'y') and data.y is not None:
173
- node_colors = data.y.cpu().numpy()[:num_nodes]
174
- unique_labels = np.unique(node_colors)
175
- color_map = px.colors.qualitative.Set3[:len(unique_labels)]
176
- else:
177
- node_colors = [0] * num_nodes
178
- color_map = ['lightblue']
179
-
180
- return node_colors, color_map
181
- except Exception as e:
182
- logger.warning(f"Node color computation failed: {e}")
183
- return [0] * num_nodes, ['lightblue']
184
-
185
- @staticmethod
186
- def _get_node_sizes(G, size_factor):
187
- """Get node sizes based on degree"""
188
- try:
189
- node_sizes = []
190
- for node in G.nodes():
191
- degree = G.degree(node)
192
- size = max(5, min(20, 5 + degree * 2)) * size_factor
193
- node_sizes.append(size)
194
- return node_sizes
195
- except Exception as e:
196
- logger.warning(f"Node size computation failed: {e}")
197
- return [8] * len(G.nodes())
198
-
199
- @staticmethod
200
- def _create_edge_traces(G, pos):
201
- """Create edge traces for plotting"""
202
- try:
203
- edge_x, edge_y, edge_info = [], [], []
204
-
205
- for edge in G.edges():
206
- if edge[0] in pos and edge[1] in pos:
207
- x0, y0 = pos[edge[0]]
208
- x1, y1 = pos[edge[1]]
209
- edge_x.extend([x0, x1, None])
210
- edge_y.extend([y0, y1, None])
211
- edge_info.append(f"Edge: {edge[0]} - {edge[1]}")
212
-
213
- return edge_x, edge_y, edge_info
214
- except Exception as e:
215
- logger.warning(f"Edge trace creation failed: {e}")
216
- return [], [], []
217
-
218
- @staticmethod
219
- def _create_node_traces(G, pos, node_colors, data):
220
- """Create node traces for plotting"""
221
- try:
222
- node_x, node_y, node_text, node_info = [], [], [], []
223
-
224
- for node in G.nodes():
225
- if node in pos:
226
- x, y = pos[node]
227
- node_x.append(x)
228
- node_y.append(y)
229
-
230
- # Node info
231
- degree = G.degree(node)
232
- label = node_colors[node] if node < len(node_colors) else 0
233
- node_text.append(f"Node {node}")
234
-
235
- # Enhanced node info
236
- info = f"Node: {node}<br>Degree: {degree}<br>Label: {label}"
237
- if hasattr(data, 'x') and data.x is not None and node < data.x.size(0):
238
- feature_sum = data.x[node].sum().item()
239
- info += f"<br>Feature Sum: {feature_sum:.2f}"
240
-
241
- node_info.append(info)
242
-
243
- return node_x, node_y, node_text, node_info
244
- except Exception as e:
245
- logger.warning(f"Node trace creation failed: {e}")
246
- return [], [], [], []
247
-
248
  @staticmethod
249
  def create_metrics_plot(metrics):
250
- """
251
- Create comprehensive metrics visualization
252
-
253
- Args:
254
- metrics: Dictionary of metrics
255
-
256
- Returns:
257
- Plotly figure object
258
- """
259
  try:
260
- # Filter and validate numeric metrics
261
  metric_names = []
262
  metric_values = []
263
 
@@ -270,505 +122,202 @@ class GraphVisualizer:
270
  if not metric_names:
271
  return GraphVisualizer._create_error_figure("No valid metrics to display")
272
 
273
- # Create subplots
274
- fig = make_subpl
275
- # Create subplots
276
- fig = make_subplots(
277
- rows=1, cols=2,
278
- subplot_titles=('Performance Metrics', 'Metric Radar Chart'),
279
- specs=[[{"type": "bar"}, {"type": "polar"}]]
280
- )
281
-
282
- # Bar chart
283
- colors = px.colors.qualitative.Set3[:len(metric_names)]
284
-
285
- fig.add_trace(
286
- go.Bar(
287
- x=metric_names,
288
- y=metric_values,
289
- marker_color=colors,
290
- text=[f'{v:.3f}' for v in metric_values],
291
- textposition='auto',
292
- name='Metrics',
293
- showlegend=False
294
- ),
295
- row=1, col=1
296
- )
297
-
298
- # Radar chart
299
- fig.add_trace(
300
- go.Scatterpolar(
301
- r=metric_values + [metric_values[0]], # Close the polygon
302
- theta=metric_names + [metric_names[0]],
303
- fill='toself',
304
- name='Performance',
305
- line=dict(color='blue'),
306
- marker=dict(size=8),
307
- showlegend=False
308
- ),
309
- row=1, col=2
310
- )
311
-
312
- fig.update_layout(
313
- title=dict(
314
- text='Model Performance Dashboard',
315
- x=0.5,
316
- font=dict(size=18)
317
- ),
318
- height=400,
319
- showlegend=False
320
- )
321
-
322
- # Update bar chart axes
323
- fig.update_xaxes(title_text="Metrics", tickangle=45, row=1, col=1)
324
- fig.update_yaxes(title_text="Score", range=[0, 1], row=1, col=1)
325
-
326
- # Update polar chart
327
- fig.update_polars(
328
- radialaxis=dict(range=[0, 1], showticklabels=True, tickfont=dict(size=10)),
329
- angularaxis=dict(tickfont=dict(size=10)),
330
- row=1, col=2
331
- )
332
-
333
- return fig
334
-
335
- except Exception as e:
336
- logger.error(f"Metrics plot error: {e}")
337
- return GraphVisualizer._create_error_figure(f"Metrics plot error: {str(e)}")
338
-
339
- @staticmethod
340
- def create_training_history_plot(history):
341
- """
342
- Create comprehensive training history visualization
343
-
344
- Args:
345
- history: Training history dictionary
346
-
347
- Returns:
348
- Plotly figure object
349
- """
350
- try:
351
- if not isinstance(history, dict) or not history:
352
- return GraphVisualizer._create_error_figure("No training history available")
353
-
354
- # Validate history data
355
- required_keys = ['train_loss', 'train_acc']
356
- for key in required_keys:
357
- if key not in history or not history[key]:
358
- return GraphVisualizer._create_error_figure(f"Missing {key} in training history")
359
-
360
- epochs = list(range(len(history['train_loss'])))
361
-
362
- # Create subplots
363
- fig = make_subplots(
364
- rows=2, cols=2,
365
- subplot_titles=('Loss Over Time', 'Accuracy Over Time', 'Learning Rate', 'Training Progress'),
366
- specs=[[{"secondary_y": False}, {"secondary_y": False}],
367
- [{"secondary_y": False}, {"secondary_y": False}]]
368
- )
369
-
370
- # Training loss
371
- fig.add_trace(
372
- go.Scatter(
373
- x=epochs, y=history['train_loss'],
374
- mode='lines', name='Train Loss',
375
- line=dict(color='blue', width=2),
376
- showlegend=False
377
- ),
378
- row=1, col=1
379
- )
380
-
381
- if 'val_loss' in history and history['val_loss']:
382
- fig.add_trace(
383
- go.Scatter(
384
- x=epochs, y=history['val_loss'],
385
- mode='lines', name='Val Loss',
386
- line=dict(color='red', width=2),
387
- showlegend=False
388
- ),
389
- row=1, col=1
390
- )
391
-
392
- # Training accuracy
393
- fig.add_trace(
394
- go.Scatter(
395
- x=epochs, y=history['train_acc'],
396
- mode='lines', name='Train Acc',
397
- line=dict(color='green', width=2),
398
- showlegend=False
399
- ),
400
- row=1, col=2
401
- )
402
-
403
- if 'val_acc' in history and history['val_acc']:
404
- fig.add_trace(
405
- go.Scatter(
406
- x=epochs, y=history['val_acc'],
407
- mode='lines', name='Val Acc',
408
- line=dict(color='orange', width=2),
409
- showlegend=False
410
- ),
411
- row=1, col=2
412
- )
413
-
414
- # Learning rate
415
- if 'lr' in history and history['lr']:
416
- fig.add_trace(
417
- go.Scatter(
418
- x=epochs, y=history['lr'],
419
- mode='lines', name='Learning Rate',
420
- line=dict(color='purple', width=2),
421
- showlegend=False
422
- ),
423
- row=2, col=1
424
- )
425
-
426
- # Training progress summary
427
- final_metrics = {
428
- 'Final Train Acc': history['train_acc'][-1] if history['train_acc'] else 0,
429
- 'Final Train Loss': history['train_loss'][-1] if history['train_loss'] else 0,
430
- }
431
-
432
- if 'val_acc' in history and history['val_acc']:
433
- final_metrics['Final Val Acc'] = history['val_acc'][-1]
434
- final_metrics['Best Val Acc'] = max(history['val_acc'])
435
-
436
- metric_names = list(final_metrics.keys())
437
- metric_values = list(final_metrics.values())
438
-
439
- fig.add_trace(
440
- go.Bar(
441
- x=metric_names,
442
- y=metric_values,
443
- marker_color=['lightblue', 'lightcoral', 'lightgreen', 'gold'],
444
- text=[f'{v:.3f}' for v in metric_values],
445
- textposition='auto',
446
- showlegend=False
447
- ),
448
- row=2, col=2
449
- )
450
-
451
- fig.update_layout(
452
- title=dict(
453
- text='Training History Dashboard',
454
- x=0.5,
455
- font=dict(size=18)
456
- ),
457
- height=600,
458
- showlegend=True
459
- )
460
-
461
- # Update axes
462
- fig.update_xaxes(title_text="Epoch", row=1, col=1)
463
- fig.update_xaxes(title_text="Epoch", row=1, col=2)
464
- fig.update_xaxes(title_text="Epoch", row=2, col=1)
465
- fig.update_xaxes(title_text="Metric", tickangle=45, row=2, col=2)
466
-
467
- fig.update_yaxes(title_text="Loss", row=1, col=1)
468
- fig.update_yaxes(title_text="Accuracy", range=[0, 1], row=1, col=2)
469
- fig.update_yaxes(title_text="Learning Rate", type="log", row=2, col=1)
470
- fig.update_yaxes(title_text="Value", row=2, col=2)
471
-
472
- return fig
473
-
474
- except Exception as e:
475
- logger.error(f"Training history plot error: {e}")
476
- return GraphVisualizer._create_error_figure(f"Training history plot error: {str(e)}")
477
-
478
- @staticmethod
479
- def create_dataset_stats_plot(dataset_info):
480
- """
481
- Create dataset statistics visualization
482
-
483
- Args:
484
- dataset_info: Dictionary containing dataset statistics
485
-
486
- Returns:
487
- Plotly figure object
488
- """
489
- try:
490
- if not isinstance(dataset_info, dict) or not dataset_info:
491
- return GraphVisualizer._create_error_figure("No dataset information available")
492
-
493
- # Prepare data
494
- stats_data = []
495
- for key, value in dataset_info.items():
496
- if isinstance(value, (int, float)) and not np.isnan(value) and np.isfinite(value):
497
- stats_data.append({
498
- 'Metric': key.replace('_', ' ').title(),
499
- 'Value': value
500
- })
501
-
502
- if not stats_data:
503
- return GraphVisualizer._create_error_figure("No valid statistics to display")
504
-
505
- df = pd.DataFrame(stats_data)
506
-
507
- # Create subplots
508
- fig = make_subplots(
509
- rows=2, cols=2,
510
- subplot_titles=('Dataset Overview', 'Graph Size Distribution', 'Feature Statistics', 'Connectivity'),
511
- specs=[[{"type": "bar"}, {"type": "histogram"}],
512
- [{"type": "bar"}, {"type": "bar"}]]
513
- )
514
-
515
- # Main statistics bar chart
516
- main_metrics = ['Num Features', 'Num Classes', 'Num Graphs', 'Total Nodes', 'Total Edges']
517
- main_data = df[df['Metric'].isin(main_metrics)]
518
-
519
- if not main_data.empty:
520
- fig.add_trace(
521
- go.Bar(
522
- x=main_data['Metric'],
523
- y=main_data['Value'],
524
- marker_color=px.colors.qualitative.Pastel1[:len(main_data)],
525
- text=[f'{int(v):,}' if v >= 1 else f'{v:.2f}' for v in main_data['Value']],
526
- textposition='auto',
527
- showlegend=False
528
- ),
529
- row=1, col=1
530
- )
531
-
532
- # Graph size distribution
533
- if 'num_graphs' in dataset_info and dataset_info['num_graphs'] > 1:
534
- # Simulate distribution based on min/max/avg
535
- avg_nodes = dataset_info.get('avg_nodes', 100)
536
- min_nodes = dataset_info.get('min_nodes', avg_nodes * 0.5)
537
- max_nodes = dataset_info.get('max_nodes', avg_nodes * 1.5)
538
-
539
- # Generate realistic distribution
540
- np.random.seed(42)
541
- if max_nodes > min_nodes:
542
- node_dist = np.random.lognormal(
543
- mean=np.log(avg_nodes),
544
- sigma=0.5,
545
- size=min(100, int(dataset_info['num_graphs']))
546
- )
547
- node_dist = np.clip(node_dist, min_nodes, max_nodes)
548
- else:
549
- node_dist = [avg_nodes] * min(100, int(dataset_info['num_graphs']))
550
-
551
- fig.add_trace(
552
- go.Histogram(
553
- x=node_dist,
554
- nbinsx=20,
555
- marker_color='lightblue',
556
- opacity=0.7,
557
- showlegend=False
558
- ),
559
- row=1, col=2
560
- )
561
- else:
562
- # Single graph - show as point
563
- fig.add_trace(
564
- go.Scatter(
565
- x=['Nodes'],
566
- y=[dataset_info.get('avg_nodes', 0)],
567
- mode='markers',
568
- marker=dict(size=20, color='blue'),
569
- showlegend=False
570
- ),
571
- row=1, col=2
572
- )
573
-
574
- # Feature statistics
575
- feature_metrics = ['Avg Nodes', 'Avg Edges', 'Avg Degree']
576
- feature_data = df[df['Metric'].isin(feature_metrics)]
577
-
578
- if not feature_data.empty:
579
- fig.add_trace(
580
- go.Bar(
581
- x=feature_data['Metric'],
582
- y=feature_data['Value'],
583
- marker_color=['lightgreen', 'lightcoral', 'lightyellow'],
584
- text=[f'{v:.1f}' for v in feature_data['Value']],
585
- textposition='auto',
586
- showlegend=False
587
- ),
588
- row=2, col=1
589
- )
590
-
591
- # Connectivity analysis
592
- connectivity_data = []
593
- if 'total_nodes' in dataset_info and 'total_edges' in dataset_info:
594
- total_nodes = dataset_info['total_nodes']
595
- total_edges = dataset_info['total_edges']
596
-
597
- if total_nodes > 0:
598
- max_possible_edges = total_nodes * (total_nodes - 1) / 2
599
- density = total_edges / max_possible_edges if max_possible_edges > 0 else 0
600
- avg_degree = dataset_info.get('avg_degree', 0)
601
-
602
- connectivity_data = [
603
- {'Metric': 'Graph Density', 'Value': density},
604
- {'Metric': 'Avg Degree', 'Value': avg_degree / total_nodes if total_nodes > 0 else 0},
605
- {'Metric': 'Edge Ratio', 'Value': total_edges / total_nodes if total_nodes > 0 else 0}
606
- ]
607
-
608
- if connectivity_data:
609
- conn_df = pd.DataFrame(connectivity_data)
610
- fig.add_trace(
611
- go.Bar(
612
- x=conn_df['Metric'],
613
- y=conn_df['Value'],
614
- marker_color=['lightpink', 'lightsteelblue', 'lightgoldenrodyellow'],
615
- text=[f'{v:.3f}' for v in conn_df['Value']],
616
- textposition='auto',
617
- showlegend=False
618
- ),
619
- row=2, col=2
620
- )
621
-
622
- fig.update_layout(
623
- title=dict(
624
- text='Dataset Statistics Dashboard',
625
- x=0.5,
626
- font=dict(size=16)
627
- ),
628
- height=600,
629
- showlegend=False
630
- )
631
-
632
- # Update axes
633
- fig.update_xaxes(title_text="Metrics", tickangle=45, row=1, col=1)
634
- fig.update_xaxes(title_text="Number of Nodes", row=1, col=2)
635
- fig.update_xaxes(title_text="Statistics", tickangle=45, row=2, col=1)
636
- fig.update_xaxes(title_text="Connectivity Metrics", tickangle=45, row=2, col=2)
637
-
638
- fig.update_yaxes(title_text="Count", row=1, col=1)
639
- fig.update_yaxes(title_text="Frequency", row=1, col=2)
640
- fig.update_yaxes(title_text="Average Value", row=2, col=1)
641
- fig.update_yaxes(title_text="Ratio", row=2, col=2)
642
-
643
- return fig
644
-
645
- except Exception as e:
646
- logger.error(f"Dataset stats plot error: {e}")
647
- return GraphVisualizer._create_error_figure(f"Dataset stats error: {str(e)}")
648
-
649
- @staticmethod
650
- def _create_error_figure(error_message):
651
- """Create an error figure with message"""
652
- fig = go.Figure()
653
- fig.add_annotation(
654
- text=error_message,
655
- x=0.5, y=0.5,
656
- xref="paper", yref="paper",
657
- showarrow=False,
658
- font=dict(size=14, color="red"),
659
- bgcolor="rgba(255,255,255,0.8)",
660
- bordercolor="red",
661
- borderwidth=1
662
- )
663
- fig.update_layout(
664
- title="Visualization Error",
665
- xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
666
- yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
667
- plot_bgcolor='white',
668
- width=600,
669
- height=400
670
- )
671
- return fig
672
-
673
- @staticmethod
674
- def create_comparison_plot(results_dict):
675
- """
676
- Create model comparison visualization
677
-
678
- Args:
679
- results_dict: Dictionary mapping model names to their results
680
-
681
- Returns:
682
- Plotly figure object
683
- """
684
- try:
685
- if not isinstance(results_dict, dict) or not results_dict:
686
- return GraphVisualizer._create_error_figure("No comparison data available")
687
-
688
- # Extract metrics for comparison
689
- models = []
690
- accuracies = []
691
- f1_scores = []
692
- losses = []
693
-
694
- for model_name, metrics in results_dict.items():
695
- if isinstance(metrics, dict):
696
- models.append(model_name)
697
- accuracies.append(metrics.get('accuracy', 0))
698
- f1_scores.append(metrics.get('f1_macro', 0))
699
- losses.append(metrics.get('loss', float('inf')))
700
-
701
- if not models:
702
- return GraphVisualizer._create_error_figure("No valid model results to compare")
703
-
704
- # Create comparison figure
705
- fig = make_subplots(
706
- rows=1, cols=3,
707
- subplot_titles=('Accuracy Comparison', 'F1 Score Comparison', 'Loss Comparison')
708
- )
709
-
710
- # Accuracy comparison
711
- fig.add_trace(
712
- go.Bar(
713
- x=models,
714
- y=accuracies,
715
- name='Accuracy',
716
- marker_color='lightblue',
717
- text=[f'{acc:.3f}' for acc in accuracies],
718
- textposition='auto',
719
- showlegend=False
720
- ),
721
- row=1, col=1
722
- )
723
-
724
- # F1 Score comparison
725
- fig.add_trace(
726
- go.Bar(
727
- x=models,
728
- y=f1_scores,
729
- name='F1 Score',
730
- marker_color='lightgreen',
731
- text=[f'{f1:.3f}' for f1 in f1_scores],
732
- textposition='auto',
733
- showlegend=False
734
- ),
735
- row=1, col=2
736
- )
737
-
738
- # Loss comparison (filter out infinite values)
739
- finite_losses = [loss if np.isfinite(loss) else 0 for loss in losses]
740
- fig.add_trace(
741
- go.Bar(
742
- x=models,
743
- y=finite_losses,
744
- name='Loss',
745
- marker_color='lightcoral',
746
- text=[f'{loss:.3f}' if np.isfinite(loss) else 'inf' for loss in losses],
747
- textposition='auto',
748
- showlegend=False
749
- ),
750
- row=1, col=3
751
- )
752
-
753
- fig.update_layout(
754
- title=dict(
755
- text='Model Performance Comparison',
756
- x=0.5,
757
- font=dict(size=18)
758
- ),
759
- height=400
760
- )
761
-
762
- # Update axes
763
- fig.update_xaxes(tickangle=45, row=1, col=1)
764
- fig.update_xaxes(tickangle=45, row=1, col=2)
765
- fig.update_xaxes(tickangle=45, row=1, col=3)
766
-
767
- fig.update_yaxes(range=[0, 1], row=1, col=1)
768
- fig.update_yaxes(range=[0, 1], row=1, col=2)
769
-
770
- return fig
771
-
772
- except Exception as e:
773
- logger.error(f"Comparison plot error: {e}")
774
- return GraphVisualizer._create_error_figure(f"Comparison plot error: {str(e)}")
 
11
  logger = logging.getLogger(__name__)
12
 
13
  class GraphVisualizer:
14
+ """Advanced graph visualization utilities"""
 
 
 
15
 
16
  @staticmethod
17
  def create_graph_plot(data, max_nodes=500, layout_algorithm='spring', node_size_factor=1.0):
18
+ """Create interactive graph visualization"""
 
 
 
 
 
 
 
 
 
 
 
19
  try:
 
20
  if not hasattr(data, 'edge_index') or not hasattr(data, 'num_nodes'):
21
  raise ValueError("Data must have edge_index and num_nodes attributes")
22
 
 
23
  num_nodes = min(data.num_nodes, max_nodes)
24
  if num_nodes <= 0:
25
  raise ValueError("No nodes to visualize")
 
29
 
30
  if data.edge_index.size(1) > 0:
31
  edge_list = data.edge_index.t().cpu().numpy()
 
 
32
  edge_list = edge_list[
33
  (edge_list[:, 0] < num_nodes) & (edge_list[:, 1] < num_nodes)
34
  ]
 
36
  if len(edge_list) > 0:
37
  G.add_edges_from(edge_list)
38
 
 
39
  G.add_nodes_from(range(num_nodes))
40
 
41
+ # Layout
42
+ pos = nx.spring_layout(G, seed=42)
43
 
44
+ # Node colors
45
+ if hasattr(data, 'y') and data.y is not None:
46
+ node_colors = data.y.cpu().numpy()[:num_nodes]
47
+ else:
48
+ node_colors = [0] * num_nodes
49
 
50
  # Create edge traces
51
+ edge_x, edge_y = [], []
52
+ for edge in G.edges():
53
+ if edge[0] in pos and edge[1] in pos:
54
+ x0, y0 = pos[edge[0]]
55
+ x1, y1 = pos[edge[1]]
56
+ edge_x.extend([x0, x1, None])
57
+ edge_y.extend([y0, y1, None])
58
 
59
  # Create node traces
60
+ node_x = [pos[node][0] for node in G.nodes()]
61
+ node_y = [pos[node][1] for node in G.nodes()]
 
62
 
 
63
  fig = go.Figure()
64
 
65
  # Add edges
 
69
  line=dict(width=0.8, color='rgba(125,125,125,0.5)'),
70
  hoverinfo='none',
71
  mode='lines',
 
72
  showlegend=False
73
  ))
74
 
 
76
  fig.add_trace(go.Scatter(
77
  x=node_x, y=node_y,
78
  mode='markers',
 
 
 
79
  marker=dict(
80
+ size=8,
81
  color=node_colors,
82
  colorscale='Viridis',
83
  line=dict(width=2, color='white'),
84
+ opacity=0.8
 
85
  ),
86
+ text=[f"Node {i}" for i in range(len(node_x))],
87
+ hoverinfo='text',
88
  showlegend=False
89
  ))
90
 
 
91
  fig.update_layout(
92
+ title=f'Graph Visualization ({num_nodes} nodes)',
 
 
 
 
93
  showlegend=False,
94
  hovermode='closest',
95
  margin=dict(b=20, l=5, r=5, t=40),
 
 
 
 
 
 
 
 
 
 
96
  xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
97
  yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
98
  plot_bgcolor='white',
 
106
  logger.error(f"Graph visualization error: {e}")
107
  return GraphVisualizer._create_error_figure(f"Visualization error: {str(e)}")
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  @staticmethod
110
  def create_metrics_plot(metrics):
111
+ """Create comprehensive metrics visualization"""
 
 
 
 
 
 
 
 
112
  try:
 
113
  metric_names = []
114
  metric_values = []
115
 
 
122
  if not metric_names:
123
  return GraphVisualizer._create_error_figure("No valid metrics to display")
124
 
125
+ fig = make_subplots(
126
+ rows=1, cols=2,
127
+ subplot_titles=('Performance Metrics', 'Metric Radar Chart'),
128
+ specs=[[{"type": "bar"}, {"type": "polar"}]]
129
+ )
130
+
131
+ colors = px.colors.qualitative.Set3[:len(metric_names)]
132
+
133
+ fig.add_trace(
134
+ go.Bar(
135
+ x=metric_names,
136
+ y=metric_values,
137
+ marker_color=colors,
138
+ text=[f'{v:.3f}' for v in metric_values],
139
+ textposition='auto',
140
+ showlegend=False
141
+ ),
142
+ row=1, col=1
143
+ )
144
+
145
+ fig.add_trace(
146
+ go.Scatterpolar(
147
+ r=metric_values + [metric_values[0]],
148
+ theta=metric_names + [metric_names[0]],
149
+ fill='toself',
150
+ line=dict(color='blue'),
151
+ marker=dict(size=8),
152
+ showlegend=False
153
+ ),
154
+ row=1, col=2
155
+ )
156
+
157
+ fig.update_layout(
158
+ title='Model Performance Dashboard',
159
+ height=400,
160
+ showlegend=False
161
+ )
162
+
163
+ fig.update_xaxes(title_text="Metrics", tickangle=45, row=1, col=1)
164
+ fig.update_yaxes(title_text="Score", range=[0, 1], row=1, col=1)
165
+
166
+ fig.update_polars(
167
+ radialaxis=dict(range=[0, 1], showticklabels=True),
168
+ row=1, col=2
169
+ )
170
+
171
+ return fig
172
+
173
+ except Exception as e:
174
+ logger.error(f"Metrics plot error: {e}")
175
+ return GraphVisualizer._create_error_figure(f"Metrics plot error: {str(e)}")
176
+
177
@staticmethod
def create_training_history_plot(history):
    """Build a 2x2 dashboard figure summarising a training run.

    Panels: loss curves (top-left), accuracy curves (top-right), learning
    rate on a log axis (bottom-left) and a bar chart of final/best scores
    (bottom-right).

    Args:
        history: Dict of per-epoch sequences. ``train_loss`` and
            ``train_acc`` are required and must be non-empty;
            ``val_loss``, ``val_acc`` and ``lr`` are drawn only when
            present and non-empty.

    Returns:
        A Plotly figure. When the input is invalid or anything fails while
        building the plot, the figure returned by
        ``GraphVisualizer._create_error_figure`` is returned instead.
    """
    def has_values(key):
        # len() instead of truthiness: array-like series (e.g. numpy
        # arrays) raise on bool() once they hold more than one element.
        series = history.get(key)
        return series is not None and len(series) > 0

    try:
        if not isinstance(history, dict) or not history:
            return GraphVisualizer._create_error_figure("No training history available")

        for key in ('train_loss', 'train_acc'):
            if not has_values(key):
                return GraphVisualizer._create_error_figure(f"Missing {key} in training history")

        fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=('Loss Over Time', 'Accuracy Over Time', 'Learning Rate', 'Training Progress'),
            specs=[[{"secondary_y": False}, {"secondary_y": False}],
                   [{"secondary_y": False}, {"secondary_y": False}]]
        )

        def add_curve(key, label, colour, row, col):
            # Each series gets its own epoch axis so a validation or LR
            # history of a different length than train_loss is still
            # drawn in full rather than paired with a mismatched x list.
            series = history[key]
            fig.add_trace(
                go.Scatter(
                    x=list(range(len(series))), y=series,
                    mode='lines', name=label,
                    line=dict(color=colour, width=2),
                    showlegend=False
                ),
                row=row, col=col
            )

        # Loss panel
        add_curve('train_loss', 'Train Loss', 'blue', 1, 1)
        if has_values('val_loss'):
            add_curve('val_loss', 'Val Loss', 'red', 1, 1)

        # Accuracy panel
        add_curve('train_acc', 'Train Acc', 'green', 1, 2)
        if has_values('val_acc'):
            add_curve('val_acc', 'Val Acc', 'orange', 1, 2)

        # Learning-rate panel
        if has_values('lr'):
            add_curve('lr', 'Learning Rate', 'purple', 2, 1)

        # Summary panel: the required series were validated as non-empty
        # above, so indexing [-1] is safe without a fallback.
        final_metrics = {
            'Final Train Acc': history['train_acc'][-1],
            'Final Train Loss': history['train_loss'][-1],
        }
        if has_values('val_acc'):
            final_metrics['Final Val Acc'] = history['val_acc'][-1]
            final_metrics['Best Val Acc'] = max(history['val_acc'])

        metric_names = list(final_metrics.keys())
        metric_values = list(final_metrics.values())
        bar_colours = ['lightblue', 'lightcoral', 'lightgreen', 'gold'][:len(metric_names)]

        fig.add_trace(
            go.Bar(
                x=metric_names,
                y=metric_values,
                marker_color=bar_colours,
                text=[f'{v:.3f}' for v in metric_values],
                textposition='auto',
                showlegend=False
            ),
            row=2, col=2
        )

        # NOTE(review): every trace above sets showlegend=False, so this
        # layout-level flag appears to have no visible effect -- confirm
        # whether a legend was intended.
        fig.update_layout(
            title='Training History Dashboard',
            height=600,
            showlegend=True
        )

        fig.update_xaxes(title_text="Epoch", row=1, col=1)
        fig.update_xaxes(title_text="Epoch", row=1, col=2)
        fig.update_xaxes(title_text="Epoch", row=2, col=1)
        fig.update_xaxes(title_text="Metric", tickangle=45, row=2, col=2)

        fig.update_yaxes(title_text="Loss", row=1, col=1)
        fig.update_yaxes(title_text="Accuracy", range=[0, 1], row=1, col=2)
        fig.update_yaxes(title_text="Learning Rate", type="log", row=2, col=1)
        fig.update_yaxes(title_text="Value", row=2, col=2)

        return fig

    except Exception as e:
        # Top-level boundary for a UI helper: keep the traceback in the
        # log and hand the caller a figure instead of raising.
        logger.exception("Training history plot error: %s", e)
        return GraphVisualizer._create_error_figure(f"Training history plot error: {str(e)}")
300
+
301
@staticmethod
def _create_error_figure(error_message):
    """Return a blank 600x400 figure showing ``error_message`` in red.

    The message is placed as an annotation at the centre of the plotting
    area (paper coordinates 0.5, 0.5); grid lines, zero lines and tick
    labels are hidden on both axes.
    """
    # One shared description of a hidden axis; each axis gets its own copy.
    hidden_axis = dict(showgrid=False, zeroline=False, showticklabels=False)

    message_box = dict(
        text=error_message,
        x=0.5, y=0.5,
        xref="paper", yref="paper",
        showarrow=False,
        font=dict(size=14, color="red"),
        bgcolor="rgba(255,255,255,0.8)",
        bordercolor="red",
        borderwidth=1,
    )

    figure = go.Figure()
    figure.add_annotation(**message_box)
    figure.update_layout(
        title="Visualization Error",
        xaxis=dict(hidden_axis),
        yaxis=dict(hidden_axis),
        plot_bgcolor='white',
        width=600,
        height=400,
    )
    return figure