kfoughali commited on
Commit
0470ced
·
verified ·
1 Parent(s): 272c30c

Update utils/visualization.py

Browse files
Files changed (1) hide show
  1. utils/visualization.py +375 -34
utils/visualization.py CHANGED
@@ -1,15 +1,18 @@
1
  import plotly.graph_objects as go
2
  import plotly.express as px
 
 
3
  import networkx as nx
4
  import torch
5
  import numpy as np
 
6
 
7
  class GraphVisualizer:
8
- """Graph visualization utilities"""
9
 
10
  @staticmethod
11
- def create_graph_plot(data, max_nodes=500):
12
- """Create interactive graph visualization"""
13
  try:
14
  # Limit nodes for performance
15
  num_nodes = min(data.num_nodes, max_nodes)
@@ -29,30 +32,71 @@ class GraphVisualizer:
29
  # Add isolated nodes
30
  G.add_nodes_from(range(num_nodes))
31
 
32
- # Layout
33
- if len(G.nodes()) > 100:
34
- pos = nx.spring_layout(G, k=0.5, iterations=20)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  else:
36
- pos = nx.spring_layout(G, k=1, iterations=50)
37
 
38
- # Node colors
39
  if hasattr(data, 'y') and data.y is not None:
40
  node_colors = data.y.cpu().numpy()[:num_nodes]
 
 
41
  else:
42
  node_colors = [0] * num_nodes
 
 
 
 
 
 
 
43
 
44
  # Create edge traces
45
  edge_x, edge_y = [], []
 
 
46
  for edge in G.edges():
47
  if edge[0] in pos and edge[1] in pos:
48
  x0, y0 = pos[edge[0]]
49
  x1, y1 = pos[edge[1]]
50
  edge_x.extend([x0, x1, None])
51
  edge_y.extend([y0, y1, None])
 
52
 
53
  # Create node traces
54
- node_x = [pos[node][0] for node in G.nodes() if node in pos]
55
- node_y = [pos[node][1] for node in G.nodes() if node in pos]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  fig = go.Figure()
58
 
@@ -60,10 +104,11 @@ class GraphVisualizer:
60
  if edge_x:
61
  fig.add_trace(go.Scatter(
62
  x=edge_x, y=edge_y,
63
- line=dict(width=0.5, color='#888'),
64
  hoverinfo='none',
65
  mode='lines',
66
- name='Edges'
 
67
  ))
68
 
69
  # Add nodes
@@ -71,24 +116,43 @@ class GraphVisualizer:
71
  x=node_x, y=node_y,
72
  mode='markers',
73
  hoverinfo='text',
74
- text=[f'Node {i}' for i in range(len(node_x))],
 
75
  marker=dict(
76
- size=8,
77
  color=node_colors[:len(node_x)],
78
  colorscale='Viridis',
79
- line=dict(width=1)
 
80
  ),
81
- name='Nodes'
 
82
  ))
83
 
84
  fig.update_layout(
85
- title=f'Graph Visualization ({num_nodes} nodes)',
 
 
 
 
86
  showlegend=False,
87
  hovermode='closest',
88
  margin=dict(b=20, l=5, r=5, t=40),
 
 
 
 
 
 
 
 
 
 
89
  xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
90
  yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
91
- plot_bgcolor='white'
 
 
92
  )
93
 
94
  return fig
@@ -100,45 +164,96 @@ class GraphVisualizer:
100
  text=f"Visualization error: {str(e)}",
101
  x=0.5, y=0.5,
102
  xref="paper", yref="paper",
103
- showarrow=False
 
 
 
 
 
 
 
104
  )
105
  return fig
106
 
107
  @staticmethod
108
  def create_metrics_plot(metrics):
109
- """Create metrics visualization"""
110
  try:
 
111
  metric_names = []
112
  metric_values = []
113
 
114
  for key, value in metrics.items():
115
- if isinstance(value, (int, float)) and key != 'error':
116
- metric_names.append(key.replace('_', ' ').title())
117
- metric_values.append(value)
 
118
 
119
  if metric_names:
120
- fig = go.Figure(data=[
 
 
 
 
 
 
 
 
 
 
121
  go.Bar(
122
  x=metric_names,
123
  y=metric_values,
124
- marker_color='lightblue'
125
- )
126
- ])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
  fig.update_layout(
129
- title='Model Performance Metrics',
130
- xaxis_title='Metric',
131
- yaxis_title='Value',
132
- yaxis=dict(range=[0, 1])
 
 
 
133
  )
 
 
 
 
 
 
 
 
 
 
 
134
  else:
135
  fig = go.Figure()
136
  fig.add_annotation(
137
- text="No metrics to display",
138
  x=0.5, y=0.5,
139
  xref="paper", yref="paper",
140
- showarrow=False
 
141
  )
 
142
 
143
  return fig
144
 
@@ -148,6 +263,232 @@ class GraphVisualizer:
148
  text=f"Metrics plot error: {str(e)}",
149
  x=0.5, y=0.5,
150
  xref="paper", yref="paper",
151
- showarrow=False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  )
153
  return fig
 
1
  import plotly.graph_objects as go
2
  import plotly.express as px
3
+ import plotly.figure_factory as ff
4
+ from plotly.subplots import make_subplots
5
  import networkx as nx
6
  import torch
7
  import numpy as np
8
+ import pandas as pd
9
 
10
  class GraphVisualizer:
11
+ """Advanced graph visualization utilities"""
12
 
13
  @staticmethod
14
+ def create_graph_plot(data, max_nodes=500, layout_algorithm='spring'):
15
+ """Create interactive graph visualization with multiple layout options"""
16
  try:
17
  # Limit nodes for performance
18
  num_nodes = min(data.num_nodes, max_nodes)
 
32
  # Add isolated nodes
33
  G.add_nodes_from(range(num_nodes))
34
 
35
+ # Choose layout algorithm
36
+ if layout_algorithm == 'spring':
37
+ if len(G.nodes()) > 100:
38
+ pos = nx.spring_layout(G, k=0.5, iterations=20)
39
+ else:
40
+ pos = nx.spring_layout(G, k=1, iterations=50)
41
+ elif layout_algorithm == 'circular':
42
+ pos = nx.circular_layout(G)
43
+ elif layout_algorithm == 'kamada_kawai':
44
+ try:
45
+ pos = nx.kamada_kawai_layout(G)
46
+ except:
47
+ pos = nx.spring_layout(G)
48
+ elif layout_algorithm == 'spectral':
49
+ try:
50
+ pos = nx.spectral_layout(G)
51
+ except:
52
+ pos = nx.spring_layout(G)
53
  else:
54
+ pos = nx.spring_layout(G)
55
 
56
+ # Node colors and sizes
57
  if hasattr(data, 'y') and data.y is not None:
58
  node_colors = data.y.cpu().numpy()[:num_nodes]
59
+ unique_labels = np.unique(node_colors)
60
+ color_map = px.colors.qualitative.Set3[:len(unique_labels)]
61
  else:
62
  node_colors = [0] * num_nodes
63
+ color_map = ['lightblue']
64
+
65
+ # Node sizes based on degree
66
+ node_sizes = []
67
+ for node in G.nodes():
68
+ degree = G.degree(node)
69
+ node_sizes.append(max(5, min(20, 5 + degree)))
70
 
71
  # Create edge traces
72
  edge_x, edge_y = [], []
73
+ edge_info = []
74
+
75
  for edge in G.edges():
76
  if edge[0] in pos and edge[1] in pos:
77
  x0, y0 = pos[edge[0]]
78
  x1, y1 = pos[edge[1]]
79
  edge_x.extend([x0, x1, None])
80
  edge_y.extend([y0, y1, None])
81
+ edge_info.append(f"Edge: {edge[0]} - {edge[1]}")
82
 
83
  # Create node traces
84
+ node_x = []
85
+ node_y = []
86
+ node_text = []
87
+ node_info = []
88
+
89
+ for node in G.nodes():
90
+ if node in pos:
91
+ x, y = pos[node]
92
+ node_x.append(x)
93
+ node_y.append(y)
94
+
95
+ # Node info
96
+ degree = G.degree(node)
97
+ label = node_colors[node] if node < len(node_colors) else 0
98
+ node_text.append(f"Node {node}")
99
+ node_info.append(f"Node: {node}<br>Degree: {degree}<br>Label: {label}")
100
 
101
  fig = go.Figure()
102
 
 
104
  if edge_x:
105
  fig.add_trace(go.Scatter(
106
  x=edge_x, y=edge_y,
107
+ line=dict(width=0.8, color='rgba(125,125,125,0.5)'),
108
  hoverinfo='none',
109
  mode='lines',
110
+ name='Edges',
111
+ showlegend=False
112
  ))
113
 
114
  # Add nodes
 
116
  x=node_x, y=node_y,
117
  mode='markers',
118
  hoverinfo='text',
119
+ hovertext=node_info,
120
+ text=node_text,
121
  marker=dict(
122
+ size=node_sizes,
123
  color=node_colors[:len(node_x)],
124
  colorscale='Viridis',
125
+ line=dict(width=2, color='white'),
126
+ opacity=0.8
127
  ),
128
+ name='Nodes',
129
+ showlegend=False
130
  ))
131
 
132
  fig.update_layout(
133
+ title=dict(
134
+ text=f'Graph Visualization ({num_nodes} nodes, {len(edge_list)} edges)',
135
+ x=0.5,
136
+ font=dict(size=16)
137
+ ),
138
  showlegend=False,
139
  hovermode='closest',
140
  margin=dict(b=20, l=5, r=5, t=40),
141
+ annotations=[
142
+ dict(
143
+ text=f"Layout: {layout_algorithm.title()}",
144
+ showarrow=False,
145
+ xref="paper", yref="paper",
146
+ x=0.005, y=-0.002,
147
+ xanchor='left', yanchor='bottom',
148
+ font=dict(color="gray", size=10)
149
+ )
150
+ ],
151
  xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
152
  yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
153
+ plot_bgcolor='white',
154
+ width=800,
155
+ height=600
156
  )
157
 
158
  return fig
 
164
  text=f"Visualization error: {str(e)}",
165
  x=0.5, y=0.5,
166
  xref="paper", yref="paper",
167
+ showarrow=False,
168
+ font=dict(size=14, color="red")
169
+ )
170
+ fig.update_layout(
171
+ title="Graph Visualization Error",
172
+ xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
173
+ yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
174
+ plot_bgcolor='white'
175
  )
176
  return fig
177
 
178
  @staticmethod
179
  def create_metrics_plot(metrics):
180
+ """Create comprehensive metrics visualization"""
181
  try:
182
+ # Filter numeric metrics
183
  metric_names = []
184
  metric_values = []
185
 
186
  for key, value in metrics.items():
187
+ if isinstance(value, (int, float)) and key not in ['error', 'loss']:
188
+ if not (np.isnan(value) or np.isinf(value)):
189
+ metric_names.append(key.replace('_', ' ').title())
190
+ metric_values.append(value)
191
 
192
  if metric_names:
193
+ # Create subplots
194
+ fig = make_subplots(
195
+ rows=1, cols=2,
196
+ subplot_titles=('Performance Metrics', 'Metric Comparison'),
197
+ specs=[[{"type": "bar"}, {"type": "scatter"}]]
198
+ )
199
+
200
+ # Bar chart
201
+ colors = px.colors.qualitative.Set3[:len(metric_names)]
202
+
203
+ fig.add_trace(
204
  go.Bar(
205
  x=metric_names,
206
  y=metric_values,
207
+ marker_color=colors,
208
+ text=[f'{v:.3f}' for v in metric_values],
209
+ textposition='auto',
210
+ name='Metrics'
211
+ ),
212
+ row=1, col=1
213
+ )
214
+
215
+ # Radar chart data
216
+ fig.add_trace(
217
+ go.Scatterpolar(
218
+ r=metric_values,
219
+ theta=metric_names,
220
+ fill='toself',
221
+ name='Performance',
222
+ line=dict(color='blue')
223
+ ),
224
+ row=1, col=2
225
+ )
226
 
227
  fig.update_layout(
228
+ title=dict(
229
+ text='Model Performance Dashboard',
230
+ x=0.5,
231
+ font=dict(size=18)
232
+ ),
233
+ showlegend=False,
234
+ height=400
235
  )
236
+
237
+ # Update bar chart
238
+ fig.update_xaxes(title_text="Metrics", row=1, col=1)
239
+ fig.update_yaxes(title_text="Score", range=[0, 1], row=1, col=1)
240
+
241
+ # Update polar chart
242
+ fig.update_polars(
243
+ radialaxis=dict(range=[0, 1], showticklabels=True),
244
+ row=1, col=2
245
+ )
246
+
247
  else:
248
  fig = go.Figure()
249
  fig.add_annotation(
250
+ text="No valid metrics to display",
251
  x=0.5, y=0.5,
252
  xref="paper", yref="paper",
253
+ showarrow=False,
254
+ font=dict(size=14)
255
  )
256
+ fig.update_layout(title="Metrics Dashboard")
257
 
258
  return fig
259
 
 
263
  text=f"Metrics plot error: {str(e)}",
264
  x=0.5, y=0.5,
265
  xref="paper", yref="paper",
266
+ showarrow=False,
267
+ font=dict(size=14, color="red")
268
+ )
269
+ fig.update_layout(title="Metrics Error")
270
+ return fig
271
+
272
+ @staticmethod
273
+ def create_training_history_plot(history):
274
+ """Create training history visualization"""
275
+ try:
276
+ epochs = list(range(len(history['train_loss'])))
277
+
278
+ # Create subplots
279
+ fig = make_subplots(
280
+ rows=2, cols=2,
281
+ subplot_titles=('Training Loss', 'Training Accuracy', 'Learning Rate', 'Loss Comparison'),
282
+ specs=[[{"secondary_y": False}, {"secondary_y": False}],
283
+ [{"secondary_y": False}, {"secondary_y": False}]]
284
+ )
285
+
286
+ # Training loss
287
+ fig.add_trace(
288
+ go.Scatter(
289
+ x=epochs, y=history['train_loss'],
290
+ mode='lines', name='Train Loss',
291
+ line=dict(color='blue', width=2)
292
+ ),
293
+ row=1, col=1
294
+ )
295
+
296
+ if 'val_loss' in history:
297
+ fig.add_trace(
298
+ go.Scatter(
299
+ x=epochs, y=history['val_loss'],
300
+ mode='lines', name='Val Loss',
301
+ line=dict(color='red', width=2)
302
+ ),
303
+ row=1, col=1
304
+ )
305
+
306
+ # Training accuracy
307
+ fig.add_trace(
308
+ go.Scatter(
309
+ x=epochs, y=history['train_acc'],
310
+ mode='lines', name='Train Acc',
311
+ line=dict(color='green', width=2)
312
+ ),
313
+ row=1, col=2
314
+ )
315
+
316
+ if 'val_acc' in history:
317
+ fig.add_trace(
318
+ go.Scatter(
319
+ x=epochs, y=history['val_acc'],
320
+ mode='lines', name='Val Acc',
321
+ line=dict(color='orange', width=2)
322
+ ),
323
+ row=1, col=2
324
+ )
325
+
326
+ # Learning rate
327
+ if 'lr' in history:
328
+ fig.add_trace(
329
+ go.Scatter(
330
+ x=epochs, y=history['lr'],
331
+ mode='lines', name='Learning Rate',
332
+ line=dict(color='purple', width=2)
333
+ ),
334
+ row=2, col=1
335
+ )
336
+
337
+ # Loss comparison
338
+ if 'train_loss' in history and 'val_loss' in history:
339
+ fig.add_trace(
340
+ go.Scatter(
341
+ x=history['train_loss'], y=history['val_loss'],
342
+ mode='markers', name='Train vs Val Loss',
343
+ marker=dict(color=epochs, colorscale='Viridis', size=8),
344
+ text=[f'Epoch {i}' for i in epochs],
345
+ hovertemplate='Train Loss: %{x:.4f}<br>Val Loss: %{y:.4f}<br>%{text}'
346
+ ),
347
+ row=2, col=2
348
+ )
349
+
350
+ # Add diagonal line
351
+ min_loss = min(min(history['train_loss']), min(history['val_loss']))
352
+ max_loss = max(max(history['train_loss']), max(history['val_loss']))
353
+ fig.add_trace(
354
+ go.Scatter(
355
+ x=[min_loss, max_loss], y=[min_loss, max_loss],
356
+ mode='lines', name='Perfect Fit',
357
+ line=dict(color='gray', dash='dash'),
358
+ showlegend=False
359
+ ),
360
+ row=2, col=2
361
+ )
362
+
363
+ fig.update_layout(
364
+ title=dict(
365
+ text='Training History Dashboard',
366
+ x=0.5,
367
+ font=dict(size=18)
368
+ ),
369
+ height=600,
370
+ showlegend=True
371
+ )
372
+
373
+ # Update axes
374
+ fig.update_xaxes(title_text="Epoch", row=1, col=1)
375
+ fig.update_xaxes(title_text="Epoch", row=1, col=2)
376
+ fig.update_xaxes(title_text="Epoch", row=2, col=1)
377
+ fig.update_xaxes(title_text="Train Loss", row=2, col=2)
378
+
379
+ fig.update_yaxes(title_text="Loss", row=1, col=1)
380
+ fig.update_yaxes(title_text="Accuracy", row=1, col=2)
381
+ fig.update_yaxes(title_text="Learning Rate", type="log", row=2, col=1)
382
+ fig.update_yaxes(title_text="Val Loss", row=2, col=2)
383
+
384
+ return fig
385
+
386
+ except Exception as e:
387
+ fig = go.Figure()
388
+ fig.add_annotation(
389
+ text=f"Training history plot error: {str(e)}",
390
+ x=0.5, y=0.5,
391
+ xref="paper", yref="paper",
392
+ showarrow=False,
393
+ font=dict(size=14, color="red")
394
+ )
395
+ return fig
396
+
397
+ @staticmethod
398
+ def create_dataset_stats_plot(dataset_info):
399
+ """Create dataset statistics visualization"""
400
+ try:
401
+ # Prepare data
402
+ stats_data = []
403
+ for key, value in dataset_info.items():
404
+ if isinstance(value, (int, float)) and not np.isnan(value):
405
+ stats_data.append({
406
+ 'Metric': key.replace('_', ' ').title(),
407
+ 'Value': value
408
+ })
409
+
410
+ if not stats_data:
411
+ raise ValueError("No valid statistics to display")
412
+
413
+ df = pd.DataFrame(stats_data)
414
+
415
+ # Create subplots
416
+ fig = make_subplots(
417
+ rows=1, cols=2,
418
+ subplot_titles=('Dataset Overview', 'Graph Size Distribution'),
419
+ specs=[[{"type": "bar"}, {"type": "box"}]]
420
+ )
421
+
422
+ # Bar chart of statistics
423
+ fig.add_trace(
424
+ go.Bar(
425
+ x=df['Metric'],
426
+ y=df['Value'],
427
+ marker_color=px.colors.qualitative.Pastel1,
428
+ text=df['Value'],
429
+ texttemplate='%{text:,.0f}',
430
+ textposition='auto'
431
+ ),
432
+ row=1, col=1
433
+ )
434
+
435
+ # Box plot for size distribution (if multiple graphs)
436
+ if dataset_info.get('num_graphs', 1) > 1:
437
+ # Simulate distribution based on min/max/avg
438
+ avg_nodes = dataset_info.get('avg_nodes', 100)
439
+ min_nodes = dataset_info.get('min_nodes', avg_nodes * 0.5)
440
+ max_nodes = dataset_info.get('max_nodes', avg_nodes * 1.5)
441
+
442
+ # Generate synthetic distribution
443
+ np.random.seed(42)
444
+ node_dist = np.random.normal(avg_nodes, (max_nodes - min_nodes) / 4, 100)
445
+ node_dist = np.clip(node_dist, min_nodes, max_nodes)
446
+
447
+ fig.add_trace(
448
+ go.Box(
449
+ y=node_dist,
450
+ name='Node Count',
451
+ marker_color='lightblue'
452
+ ),
453
+ row=1, col=2
454
+ )
455
+ else:
456
+ # Single graph - show as point
457
+ fig.add_trace(
458
+ go.Scatter(
459
+ x=['Nodes'],
460
+ y=[dataset_info.get('avg_nodes', 0)],
461
+ mode='markers',
462
+ marker=dict(size=20, color='blue'),
463
+ name='Node Count'
464
+ ),
465
+ row=1, col=2
466
+ )
467
+
468
+ fig.update_layout(
469
+ title=dict(
470
+ text='Dataset Statistics Dashboard',
471
+ x=0.5,
472
+ font=dict(size=16)
473
+ ),
474
+ height=400,
475
+ showlegend=False
476
+ )
477
+
478
+ # Update axes
479
+ fig.update_xaxes(title_text="Metrics", tickangle=45, row=1, col=1)
480
+ fig.update_yaxes(title_text="Count", row=1, col=1)
481
+ fig.update_yaxes(title_text="Number of Nodes", row=1, col=2)
482
+
483
+ return fig
484
+
485
+ except Exception as e:
486
+ fig = go.Figure()
487
+ fig.add_annotation(
488
+ text=f"Dataset stats error: {str(e)}",
489
+ x=0.5, y=0.5,
490
+ xref="paper", yref="paper",
491
+ showarrow=False,
492
+ font=dict(size=14, color="red")
493
  )
494
  return fig