Update utils/visualization.py
Browse files- utils/visualization.py +223 -674
utils/visualization.py
CHANGED
@@ -11,31 +11,15 @@ import logging
|
|
11 |
logger = logging.getLogger(__name__)
|
12 |
|
13 |
class GraphVisualizer:
|
14 |
-
"""
|
15 |
-
Advanced graph visualization utilities
|
16 |
-
Enterprise-grade with comprehensive error handling
|
17 |
-
"""
|
18 |
|
19 |
@staticmethod
|
20 |
def create_graph_plot(data, max_nodes=500, layout_algorithm='spring', node_size_factor=1.0):
|
21 |
-
"""
|
22 |
-
Create interactive graph visualization with robust error handling
|
23 |
-
|
24 |
-
Args:
|
25 |
-
data: PyTorch Geometric data object
|
26 |
-
max_nodes: Maximum number of nodes to display
|
27 |
-
layout_algorithm: Layout algorithm ('spring', 'circular', 'kamada_kawai', 'spectral')
|
28 |
-
node_size_factor: Factor to scale node sizes
|
29 |
-
|
30 |
-
Returns:
|
31 |
-
Plotly figure object
|
32 |
-
"""
|
33 |
try:
|
34 |
-
# Validate inputs
|
35 |
if not hasattr(data, 'edge_index') or not hasattr(data, 'num_nodes'):
|
36 |
raise ValueError("Data must have edge_index and num_nodes attributes")
|
37 |
|
38 |
-
# Limit nodes for performance
|
39 |
num_nodes = min(data.num_nodes, max_nodes)
|
40 |
if num_nodes <= 0:
|
41 |
raise ValueError("No nodes to visualize")
|
@@ -45,8 +29,6 @@ class GraphVisualizer:
|
|
45 |
|
46 |
if data.edge_index.size(1) > 0:
|
47 |
edge_list = data.edge_index.t().cpu().numpy()
|
48 |
-
|
49 |
-
# Filter edges to include only first max_nodes
|
50 |
edge_list = edge_list[
|
51 |
(edge_list[:, 0] < num_nodes) & (edge_list[:, 1] < num_nodes)
|
52 |
]
|
@@ -54,25 +36,30 @@ class GraphVisualizer:
|
|
54 |
if len(edge_list) > 0:
|
55 |
G.add_edges_from(edge_list)
|
56 |
|
57 |
-
# Add isolated nodes
|
58 |
G.add_nodes_from(range(num_nodes))
|
59 |
|
60 |
-
#
|
61 |
-
pos =
|
62 |
|
63 |
-
# Node colors
|
64 |
-
|
65 |
-
|
|
|
|
|
66 |
|
67 |
# Create edge traces
|
68 |
-
edge_x, edge_y
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
|
70 |
# Create node traces
|
71 |
-
node_x
|
72 |
-
|
73 |
-
)
|
74 |
|
75 |
-
# Create figure
|
76 |
fig = go.Figure()
|
77 |
|
78 |
# Add edges
|
@@ -82,7 +69,6 @@ class GraphVisualizer:
|
|
82 |
line=dict(width=0.8, color='rgba(125,125,125,0.5)'),
|
83 |
hoverinfo='none',
|
84 |
mode='lines',
|
85 |
-
name='Edges',
|
86 |
showlegend=False
|
87 |
))
|
88 |
|
@@ -90,41 +76,23 @@ class GraphVisualizer:
|
|
90 |
fig.add_trace(go.Scatter(
|
91 |
x=node_x, y=node_y,
|
92 |
mode='markers',
|
93 |
-
hoverinfo='text',
|
94 |
-
hovertext=node_info,
|
95 |
-
text=node_text,
|
96 |
marker=dict(
|
97 |
-
size=
|
98 |
color=node_colors,
|
99 |
colorscale='Viridis',
|
100 |
line=dict(width=2, color='white'),
|
101 |
-
opacity=0.8
|
102 |
-
colorbar=dict(title="Node Label") if hasattr(data, 'y') and data.y is not None else None
|
103 |
),
|
104 |
-
|
|
|
105 |
showlegend=False
|
106 |
))
|
107 |
|
108 |
-
# Update layout
|
109 |
fig.update_layout(
|
110 |
-
title=
|
111 |
-
text=f'Graph Visualization ({num_nodes} nodes, {len(edge_x)//3 if edge_x else 0} edges)',
|
112 |
-
x=0.5,
|
113 |
-
font=dict(size=16)
|
114 |
-
),
|
115 |
showlegend=False,
|
116 |
hovermode='closest',
|
117 |
margin=dict(b=20, l=5, r=5, t=40),
|
118 |
-
annotations=[
|
119 |
-
dict(
|
120 |
-
text=f"Layout: {layout_algorithm.title()}",
|
121 |
-
showarrow=False,
|
122 |
-
xref="paper", yref="paper",
|
123 |
-
x=0.005, y=-0.002,
|
124 |
-
xanchor='left', yanchor='bottom',
|
125 |
-
font=dict(color="gray", size=10)
|
126 |
-
)
|
127 |
-
],
|
128 |
xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
|
129 |
yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
|
130 |
plot_bgcolor='white',
|
@@ -138,126 +106,10 @@ class GraphVisualizer:
|
|
138 |
logger.error(f"Graph visualization error: {e}")
|
139 |
return GraphVisualizer._create_error_figure(f"Visualization error: {str(e)}")
|
140 |
|
141 |
-
@staticmethod
|
142 |
-
def _compute_layout(G, algorithm):
|
143 |
-
"""Compute graph layout with fallback options"""
|
144 |
-
try:
|
145 |
-
if algorithm == 'spring':
|
146 |
-
if len(G.nodes()) > 100:
|
147 |
-
return nx.spring_layout(G, k=0.5, iterations=20, seed=42)
|
148 |
-
else:
|
149 |
-
return nx.spring_layout(G, k=1, iterations=50, seed=42)
|
150 |
-
elif algorithm == 'circular':
|
151 |
-
return nx.circular_layout(G)
|
152 |
-
elif algorithm == 'kamada_kawai':
|
153 |
-
if len(G.nodes()) <= 500: # Too slow for large graphs
|
154 |
-
return nx.kamada_kawai_layout(G)
|
155 |
-
else:
|
156 |
-
raise ValueError("Too many nodes for Kamada-Kawai")
|
157 |
-
elif algorithm == 'spectral':
|
158 |
-
if len(G.edges()) > 0:
|
159 |
-
return nx.spectral_layout(G)
|
160 |
-
else:
|
161 |
-
return nx.circular_layout(G)
|
162 |
-
else:
|
163 |
-
return nx.spring_layout(G, seed=42)
|
164 |
-
except Exception as e:
|
165 |
-
logger.warning(f"Layout algorithm {algorithm} failed: {e}, using spring layout")
|
166 |
-
return nx.spring_layout(G, seed=42)
|
167 |
-
|
168 |
-
@staticmethod
|
169 |
-
def _get_node_colors(data, num_nodes):
|
170 |
-
"""Get node colors based on labels"""
|
171 |
-
try:
|
172 |
-
if hasattr(data, 'y') and data.y is not None:
|
173 |
-
node_colors = data.y.cpu().numpy()[:num_nodes]
|
174 |
-
unique_labels = np.unique(node_colors)
|
175 |
-
color_map = px.colors.qualitative.Set3[:len(unique_labels)]
|
176 |
-
else:
|
177 |
-
node_colors = [0] * num_nodes
|
178 |
-
color_map = ['lightblue']
|
179 |
-
|
180 |
-
return node_colors, color_map
|
181 |
-
except Exception as e:
|
182 |
-
logger.warning(f"Node color computation failed: {e}")
|
183 |
-
return [0] * num_nodes, ['lightblue']
|
184 |
-
|
185 |
-
@staticmethod
|
186 |
-
def _get_node_sizes(G, size_factor):
|
187 |
-
"""Get node sizes based on degree"""
|
188 |
-
try:
|
189 |
-
node_sizes = []
|
190 |
-
for node in G.nodes():
|
191 |
-
degree = G.degree(node)
|
192 |
-
size = max(5, min(20, 5 + degree * 2)) * size_factor
|
193 |
-
node_sizes.append(size)
|
194 |
-
return node_sizes
|
195 |
-
except Exception as e:
|
196 |
-
logger.warning(f"Node size computation failed: {e}")
|
197 |
-
return [8] * len(G.nodes())
|
198 |
-
|
199 |
-
@staticmethod
|
200 |
-
def _create_edge_traces(G, pos):
|
201 |
-
"""Create edge traces for plotting"""
|
202 |
-
try:
|
203 |
-
edge_x, edge_y, edge_info = [], [], []
|
204 |
-
|
205 |
-
for edge in G.edges():
|
206 |
-
if edge[0] in pos and edge[1] in pos:
|
207 |
-
x0, y0 = pos[edge[0]]
|
208 |
-
x1, y1 = pos[edge[1]]
|
209 |
-
edge_x.extend([x0, x1, None])
|
210 |
-
edge_y.extend([y0, y1, None])
|
211 |
-
edge_info.append(f"Edge: {edge[0]} - {edge[1]}")
|
212 |
-
|
213 |
-
return edge_x, edge_y, edge_info
|
214 |
-
except Exception as e:
|
215 |
-
logger.warning(f"Edge trace creation failed: {e}")
|
216 |
-
return [], [], []
|
217 |
-
|
218 |
-
@staticmethod
|
219 |
-
def _create_node_traces(G, pos, node_colors, data):
|
220 |
-
"""Create node traces for plotting"""
|
221 |
-
try:
|
222 |
-
node_x, node_y, node_text, node_info = [], [], [], []
|
223 |
-
|
224 |
-
for node in G.nodes():
|
225 |
-
if node in pos:
|
226 |
-
x, y = pos[node]
|
227 |
-
node_x.append(x)
|
228 |
-
node_y.append(y)
|
229 |
-
|
230 |
-
# Node info
|
231 |
-
degree = G.degree(node)
|
232 |
-
label = node_colors[node] if node < len(node_colors) else 0
|
233 |
-
node_text.append(f"Node {node}")
|
234 |
-
|
235 |
-
# Enhanced node info
|
236 |
-
info = f"Node: {node}<br>Degree: {degree}<br>Label: {label}"
|
237 |
-
if hasattr(data, 'x') and data.x is not None and node < data.x.size(0):
|
238 |
-
feature_sum = data.x[node].sum().item()
|
239 |
-
info += f"<br>Feature Sum: {feature_sum:.2f}"
|
240 |
-
|
241 |
-
node_info.append(info)
|
242 |
-
|
243 |
-
return node_x, node_y, node_text, node_info
|
244 |
-
except Exception as e:
|
245 |
-
logger.warning(f"Node trace creation failed: {e}")
|
246 |
-
return [], [], [], []
|
247 |
-
|
248 |
@staticmethod
|
249 |
def create_metrics_plot(metrics):
|
250 |
-
"""
|
251 |
-
Create comprehensive metrics visualization
|
252 |
-
|
253 |
-
Args:
|
254 |
-
metrics: Dictionary of metrics
|
255 |
-
|
256 |
-
Returns:
|
257 |
-
Plotly figure object
|
258 |
-
"""
|
259 |
try:
|
260 |
-
# Filter and validate numeric metrics
|
261 |
metric_names = []
|
262 |
metric_values = []
|
263 |
|
@@ -270,505 +122,202 @@ class GraphVisualizer:
|
|
270 |
if not metric_names:
|
271 |
return GraphVisualizer._create_error_figure("No valid metrics to display")
|
272 |
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
411 |
-
|
412 |
-
|
413 |
-
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
-
|
423 |
-
|
424 |
-
|
425 |
-
|
426 |
-
|
427 |
-
|
428 |
-
|
429 |
-
|
430 |
-
|
431 |
-
|
432 |
-
|
433 |
-
|
434 |
-
|
435 |
-
|
436 |
-
|
437 |
-
|
438 |
-
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
|
451 |
-
|
452 |
-
|
453 |
-
|
454 |
-
|
455 |
-
|
456 |
-
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
|
466 |
-
|
467 |
-
|
468 |
-
|
469 |
-
|
470 |
-
|
471 |
-
|
472 |
-
return fig
|
473 |
-
|
474 |
-
except Exception as e:
|
475 |
-
logger.error(f"Training history plot error: {e}")
|
476 |
-
return GraphVisualizer._create_error_figure(f"Training history plot error: {str(e)}")
|
477 |
-
|
478 |
-
@staticmethod
|
479 |
-
def create_dataset_stats_plot(dataset_info):
|
480 |
-
"""
|
481 |
-
Create dataset statistics visualization
|
482 |
-
|
483 |
-
Args:
|
484 |
-
dataset_info: Dictionary containing dataset statistics
|
485 |
-
|
486 |
-
Returns:
|
487 |
-
Plotly figure object
|
488 |
-
"""
|
489 |
-
try:
|
490 |
-
if not isinstance(dataset_info, dict) or not dataset_info:
|
491 |
-
return GraphVisualizer._create_error_figure("No dataset information available")
|
492 |
-
|
493 |
-
# Prepare data
|
494 |
-
stats_data = []
|
495 |
-
for key, value in dataset_info.items():
|
496 |
-
if isinstance(value, (int, float)) and not np.isnan(value) and np.isfinite(value):
|
497 |
-
stats_data.append({
|
498 |
-
'Metric': key.replace('_', ' ').title(),
|
499 |
-
'Value': value
|
500 |
-
})
|
501 |
-
|
502 |
-
if not stats_data:
|
503 |
-
return GraphVisualizer._create_error_figure("No valid statistics to display")
|
504 |
-
|
505 |
-
df = pd.DataFrame(stats_data)
|
506 |
-
|
507 |
-
# Create subplots
|
508 |
-
fig = make_subplots(
|
509 |
-
rows=2, cols=2,
|
510 |
-
subplot_titles=('Dataset Overview', 'Graph Size Distribution', 'Feature Statistics', 'Connectivity'),
|
511 |
-
specs=[[{"type": "bar"}, {"type": "histogram"}],
|
512 |
-
[{"type": "bar"}, {"type": "bar"}]]
|
513 |
-
)
|
514 |
-
|
515 |
-
# Main statistics bar chart
|
516 |
-
main_metrics = ['Num Features', 'Num Classes', 'Num Graphs', 'Total Nodes', 'Total Edges']
|
517 |
-
main_data = df[df['Metric'].isin(main_metrics)]
|
518 |
-
|
519 |
-
if not main_data.empty:
|
520 |
-
fig.add_trace(
|
521 |
-
go.Bar(
|
522 |
-
x=main_data['Metric'],
|
523 |
-
y=main_data['Value'],
|
524 |
-
marker_color=px.colors.qualitative.Pastel1[:len(main_data)],
|
525 |
-
text=[f'{int(v):,}' if v >= 1 else f'{v:.2f}' for v in main_data['Value']],
|
526 |
-
textposition='auto',
|
527 |
-
showlegend=False
|
528 |
-
),
|
529 |
-
row=1, col=1
|
530 |
-
)
|
531 |
-
|
532 |
-
# Graph size distribution
|
533 |
-
if 'num_graphs' in dataset_info and dataset_info['num_graphs'] > 1:
|
534 |
-
# Simulate distribution based on min/max/avg
|
535 |
-
avg_nodes = dataset_info.get('avg_nodes', 100)
|
536 |
-
min_nodes = dataset_info.get('min_nodes', avg_nodes * 0.5)
|
537 |
-
max_nodes = dataset_info.get('max_nodes', avg_nodes * 1.5)
|
538 |
-
|
539 |
-
# Generate realistic distribution
|
540 |
-
np.random.seed(42)
|
541 |
-
if max_nodes > min_nodes:
|
542 |
-
node_dist = np.random.lognormal(
|
543 |
-
mean=np.log(avg_nodes),
|
544 |
-
sigma=0.5,
|
545 |
-
size=min(100, int(dataset_info['num_graphs']))
|
546 |
-
)
|
547 |
-
node_dist = np.clip(node_dist, min_nodes, max_nodes)
|
548 |
-
else:
|
549 |
-
node_dist = [avg_nodes] * min(100, int(dataset_info['num_graphs']))
|
550 |
-
|
551 |
-
fig.add_trace(
|
552 |
-
go.Histogram(
|
553 |
-
x=node_dist,
|
554 |
-
nbinsx=20,
|
555 |
-
marker_color='lightblue',
|
556 |
-
opacity=0.7,
|
557 |
-
showlegend=False
|
558 |
-
),
|
559 |
-
row=1, col=2
|
560 |
-
)
|
561 |
-
else:
|
562 |
-
# Single graph - show as point
|
563 |
-
fig.add_trace(
|
564 |
-
go.Scatter(
|
565 |
-
x=['Nodes'],
|
566 |
-
y=[dataset_info.get('avg_nodes', 0)],
|
567 |
-
mode='markers',
|
568 |
-
marker=dict(size=20, color='blue'),
|
569 |
-
showlegend=False
|
570 |
-
),
|
571 |
-
row=1, col=2
|
572 |
-
)
|
573 |
-
|
574 |
-
# Feature statistics
|
575 |
-
feature_metrics = ['Avg Nodes', 'Avg Edges', 'Avg Degree']
|
576 |
-
feature_data = df[df['Metric'].isin(feature_metrics)]
|
577 |
-
|
578 |
-
if not feature_data.empty:
|
579 |
-
fig.add_trace(
|
580 |
-
go.Bar(
|
581 |
-
x=feature_data['Metric'],
|
582 |
-
y=feature_data['Value'],
|
583 |
-
marker_color=['lightgreen', 'lightcoral', 'lightyellow'],
|
584 |
-
text=[f'{v:.1f}' for v in feature_data['Value']],
|
585 |
-
textposition='auto',
|
586 |
-
showlegend=False
|
587 |
-
),
|
588 |
-
row=2, col=1
|
589 |
-
)
|
590 |
-
|
591 |
-
# Connectivity analysis
|
592 |
-
connectivity_data = []
|
593 |
-
if 'total_nodes' in dataset_info and 'total_edges' in dataset_info:
|
594 |
-
total_nodes = dataset_info['total_nodes']
|
595 |
-
total_edges = dataset_info['total_edges']
|
596 |
-
|
597 |
-
if total_nodes > 0:
|
598 |
-
max_possible_edges = total_nodes * (total_nodes - 1) / 2
|
599 |
-
density = total_edges / max_possible_edges if max_possible_edges > 0 else 0
|
600 |
-
avg_degree = dataset_info.get('avg_degree', 0)
|
601 |
-
|
602 |
-
connectivity_data = [
|
603 |
-
{'Metric': 'Graph Density', 'Value': density},
|
604 |
-
{'Metric': 'Avg Degree', 'Value': avg_degree / total_nodes if total_nodes > 0 else 0},
|
605 |
-
{'Metric': 'Edge Ratio', 'Value': total_edges / total_nodes if total_nodes > 0 else 0}
|
606 |
-
]
|
607 |
-
|
608 |
-
if connectivity_data:
|
609 |
-
conn_df = pd.DataFrame(connectivity_data)
|
610 |
-
fig.add_trace(
|
611 |
-
go.Bar(
|
612 |
-
x=conn_df['Metric'],
|
613 |
-
y=conn_df['Value'],
|
614 |
-
marker_color=['lightpink', 'lightsteelblue', 'lightgoldenrodyellow'],
|
615 |
-
text=[f'{v:.3f}' for v in conn_df['Value']],
|
616 |
-
textposition='auto',
|
617 |
-
showlegend=False
|
618 |
-
),
|
619 |
-
row=2, col=2
|
620 |
-
)
|
621 |
-
|
622 |
-
fig.update_layout(
|
623 |
-
title=dict(
|
624 |
-
text='Dataset Statistics Dashboard',
|
625 |
-
x=0.5,
|
626 |
-
font=dict(size=16)
|
627 |
-
),
|
628 |
-
height=600,
|
629 |
-
showlegend=False
|
630 |
-
)
|
631 |
-
|
632 |
-
# Update axes
|
633 |
-
fig.update_xaxes(title_text="Metrics", tickangle=45, row=1, col=1)
|
634 |
-
fig.update_xaxes(title_text="Number of Nodes", row=1, col=2)
|
635 |
-
fig.update_xaxes(title_text="Statistics", tickangle=45, row=2, col=1)
|
636 |
-
fig.update_xaxes(title_text="Connectivity Metrics", tickangle=45, row=2, col=2)
|
637 |
-
|
638 |
-
fig.update_yaxes(title_text="Count", row=1, col=1)
|
639 |
-
fig.update_yaxes(title_text="Frequency", row=1, col=2)
|
640 |
-
fig.update_yaxes(title_text="Average Value", row=2, col=1)
|
641 |
-
fig.update_yaxes(title_text="Ratio", row=2, col=2)
|
642 |
-
|
643 |
-
return fig
|
644 |
-
|
645 |
-
except Exception as e:
|
646 |
-
logger.error(f"Dataset stats plot error: {e}")
|
647 |
-
return GraphVisualizer._create_error_figure(f"Dataset stats error: {str(e)}")
|
648 |
-
|
649 |
-
@staticmethod
|
650 |
-
def _create_error_figure(error_message):
|
651 |
-
"""Create an error figure with message"""
|
652 |
-
fig = go.Figure()
|
653 |
-
fig.add_annotation(
|
654 |
-
text=error_message,
|
655 |
-
x=0.5, y=0.5,
|
656 |
-
xref="paper", yref="paper",
|
657 |
-
showarrow=False,
|
658 |
-
font=dict(size=14, color="red"),
|
659 |
-
bgcolor="rgba(255,255,255,0.8)",
|
660 |
-
bordercolor="red",
|
661 |
-
borderwidth=1
|
662 |
-
)
|
663 |
-
fig.update_layout(
|
664 |
-
title="Visualization Error",
|
665 |
-
xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
|
666 |
-
yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
|
667 |
-
plot_bgcolor='white',
|
668 |
-
width=600,
|
669 |
-
height=400
|
670 |
-
)
|
671 |
-
return fig
|
672 |
-
|
673 |
-
@staticmethod
|
674 |
-
def create_comparison_plot(results_dict):
|
675 |
-
"""
|
676 |
-
Create model comparison visualization
|
677 |
-
|
678 |
-
Args:
|
679 |
-
results_dict: Dictionary mapping model names to their results
|
680 |
-
|
681 |
-
Returns:
|
682 |
-
Plotly figure object
|
683 |
-
"""
|
684 |
-
try:
|
685 |
-
if not isinstance(results_dict, dict) or not results_dict:
|
686 |
-
return GraphVisualizer._create_error_figure("No comparison data available")
|
687 |
-
|
688 |
-
# Extract metrics for comparison
|
689 |
-
models = []
|
690 |
-
accuracies = []
|
691 |
-
f1_scores = []
|
692 |
-
losses = []
|
693 |
-
|
694 |
-
for model_name, metrics in results_dict.items():
|
695 |
-
if isinstance(metrics, dict):
|
696 |
-
models.append(model_name)
|
697 |
-
accuracies.append(metrics.get('accuracy', 0))
|
698 |
-
f1_scores.append(metrics.get('f1_macro', 0))
|
699 |
-
losses.append(metrics.get('loss', float('inf')))
|
700 |
-
|
701 |
-
if not models:
|
702 |
-
return GraphVisualizer._create_error_figure("No valid model results to compare")
|
703 |
-
|
704 |
-
# Create comparison figure
|
705 |
-
fig = make_subplots(
|
706 |
-
rows=1, cols=3,
|
707 |
-
subplot_titles=('Accuracy Comparison', 'F1 Score Comparison', 'Loss Comparison')
|
708 |
-
)
|
709 |
-
|
710 |
-
# Accuracy comparison
|
711 |
-
fig.add_trace(
|
712 |
-
go.Bar(
|
713 |
-
x=models,
|
714 |
-
y=accuracies,
|
715 |
-
name='Accuracy',
|
716 |
-
marker_color='lightblue',
|
717 |
-
text=[f'{acc:.3f}' for acc in accuracies],
|
718 |
-
textposition='auto',
|
719 |
-
showlegend=False
|
720 |
-
),
|
721 |
-
row=1, col=1
|
722 |
-
)
|
723 |
-
|
724 |
-
# F1 Score comparison
|
725 |
-
fig.add_trace(
|
726 |
-
go.Bar(
|
727 |
-
x=models,
|
728 |
-
y=f1_scores,
|
729 |
-
name='F1 Score',
|
730 |
-
marker_color='lightgreen',
|
731 |
-
text=[f'{f1:.3f}' for f1 in f1_scores],
|
732 |
-
textposition='auto',
|
733 |
-
showlegend=False
|
734 |
-
),
|
735 |
-
row=1, col=2
|
736 |
-
)
|
737 |
-
|
738 |
-
# Loss comparison (filter out infinite values)
|
739 |
-
finite_losses = [loss if np.isfinite(loss) else 0 for loss in losses]
|
740 |
-
fig.add_trace(
|
741 |
-
go.Bar(
|
742 |
-
x=models,
|
743 |
-
y=finite_losses,
|
744 |
-
name='Loss',
|
745 |
-
marker_color='lightcoral',
|
746 |
-
text=[f'{loss:.3f}' if np.isfinite(loss) else 'inf' for loss in losses],
|
747 |
-
textposition='auto',
|
748 |
-
showlegend=False
|
749 |
-
),
|
750 |
-
row=1, col=3
|
751 |
-
)
|
752 |
-
|
753 |
-
fig.update_layout(
|
754 |
-
title=dict(
|
755 |
-
text='Model Performance Comparison',
|
756 |
-
x=0.5,
|
757 |
-
font=dict(size=18)
|
758 |
-
),
|
759 |
-
height=400
|
760 |
-
)
|
761 |
-
|
762 |
-
# Update axes
|
763 |
-
fig.update_xaxes(tickangle=45, row=1, col=1)
|
764 |
-
fig.update_xaxes(tickangle=45, row=1, col=2)
|
765 |
-
fig.update_xaxes(tickangle=45, row=1, col=3)
|
766 |
-
|
767 |
-
fig.update_yaxes(range=[0, 1], row=1, col=1)
|
768 |
-
fig.update_yaxes(range=[0, 1], row=1, col=2)
|
769 |
-
|
770 |
-
return fig
|
771 |
-
|
772 |
-
except Exception as e:
|
773 |
-
logger.error(f"Comparison plot error: {e}")
|
774 |
-
return GraphVisualizer._create_error_figure(f"Comparison plot error: {str(e)}")
|
|
|
11 |
logger = logging.getLogger(__name__)
|
12 |
|
13 |
class GraphVisualizer:
|
14 |
+
"""Advanced graph visualization utilities"""
|
|
|
|
|
|
|
15 |
|
16 |
@staticmethod
|
17 |
def create_graph_plot(data, max_nodes=500, layout_algorithm='spring', node_size_factor=1.0):
|
18 |
+
"""Create interactive graph visualization"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
try:
|
|
|
20 |
if not hasattr(data, 'edge_index') or not hasattr(data, 'num_nodes'):
|
21 |
raise ValueError("Data must have edge_index and num_nodes attributes")
|
22 |
|
|
|
23 |
num_nodes = min(data.num_nodes, max_nodes)
|
24 |
if num_nodes <= 0:
|
25 |
raise ValueError("No nodes to visualize")
|
|
|
29 |
|
30 |
if data.edge_index.size(1) > 0:
|
31 |
edge_list = data.edge_index.t().cpu().numpy()
|
|
|
|
|
32 |
edge_list = edge_list[
|
33 |
(edge_list[:, 0] < num_nodes) & (edge_list[:, 1] < num_nodes)
|
34 |
]
|
|
|
36 |
if len(edge_list) > 0:
|
37 |
G.add_edges_from(edge_list)
|
38 |
|
|
|
39 |
G.add_nodes_from(range(num_nodes))
|
40 |
|
41 |
+
# Layout
|
42 |
+
pos = nx.spring_layout(G, seed=42)
|
43 |
|
44 |
+
# Node colors
|
45 |
+
if hasattr(data, 'y') and data.y is not None:
|
46 |
+
node_colors = data.y.cpu().numpy()[:num_nodes]
|
47 |
+
else:
|
48 |
+
node_colors = [0] * num_nodes
|
49 |
|
50 |
# Create edge traces
|
51 |
+
edge_x, edge_y = [], []
|
52 |
+
for edge in G.edges():
|
53 |
+
if edge[0] in pos and edge[1] in pos:
|
54 |
+
x0, y0 = pos[edge[0]]
|
55 |
+
x1, y1 = pos[edge[1]]
|
56 |
+
edge_x.extend([x0, x1, None])
|
57 |
+
edge_y.extend([y0, y1, None])
|
58 |
|
59 |
# Create node traces
|
60 |
+
node_x = [pos[node][0] for node in G.nodes()]
|
61 |
+
node_y = [pos[node][1] for node in G.nodes()]
|
|
|
62 |
|
|
|
63 |
fig = go.Figure()
|
64 |
|
65 |
# Add edges
|
|
|
69 |
line=dict(width=0.8, color='rgba(125,125,125,0.5)'),
|
70 |
hoverinfo='none',
|
71 |
mode='lines',
|
|
|
72 |
showlegend=False
|
73 |
))
|
74 |
|
|
|
76 |
fig.add_trace(go.Scatter(
|
77 |
x=node_x, y=node_y,
|
78 |
mode='markers',
|
|
|
|
|
|
|
79 |
marker=dict(
|
80 |
+
size=8,
|
81 |
color=node_colors,
|
82 |
colorscale='Viridis',
|
83 |
line=dict(width=2, color='white'),
|
84 |
+
opacity=0.8
|
|
|
85 |
),
|
86 |
+
text=[f"Node {i}" for i in range(len(node_x))],
|
87 |
+
hoverinfo='text',
|
88 |
showlegend=False
|
89 |
))
|
90 |
|
|
|
91 |
fig.update_layout(
|
92 |
+
title=f'Graph Visualization ({num_nodes} nodes)',
|
|
|
|
|
|
|
|
|
93 |
showlegend=False,
|
94 |
hovermode='closest',
|
95 |
margin=dict(b=20, l=5, r=5, t=40),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
|
97 |
yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
|
98 |
plot_bgcolor='white',
|
|
|
106 |
logger.error(f"Graph visualization error: {e}")
|
107 |
return GraphVisualizer._create_error_figure(f"Visualization error: {str(e)}")
|
108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
@staticmethod
|
110 |
def create_metrics_plot(metrics):
|
111 |
+
"""Create comprehensive metrics visualization"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
try:
|
|
|
113 |
metric_names = []
|
114 |
metric_values = []
|
115 |
|
|
|
122 |
if not metric_names:
|
123 |
return GraphVisualizer._create_error_figure("No valid metrics to display")
|
124 |
|
125 |
+
fig = make_subplots(
|
126 |
+
rows=1, cols=2,
|
127 |
+
subplot_titles=('Performance Metrics', 'Metric Radar Chart'),
|
128 |
+
specs=[[{"type": "bar"}, {"type": "polar"}]]
|
129 |
+
)
|
130 |
+
|
131 |
+
colors = px.colors.qualitative.Set3[:len(metric_names)]
|
132 |
+
|
133 |
+
fig.add_trace(
|
134 |
+
go.Bar(
|
135 |
+
x=metric_names,
|
136 |
+
y=metric_values,
|
137 |
+
marker_color=colors,
|
138 |
+
text=[f'{v:.3f}' for v in metric_values],
|
139 |
+
textposition='auto',
|
140 |
+
showlegend=False
|
141 |
+
),
|
142 |
+
row=1, col=1
|
143 |
+
)
|
144 |
+
|
145 |
+
fig.add_trace(
|
146 |
+
go.Scatterpolar(
|
147 |
+
r=metric_values + [metric_values[0]],
|
148 |
+
theta=metric_names + [metric_names[0]],
|
149 |
+
fill='toself',
|
150 |
+
line=dict(color='blue'),
|
151 |
+
marker=dict(size=8),
|
152 |
+
showlegend=False
|
153 |
+
),
|
154 |
+
row=1, col=2
|
155 |
+
)
|
156 |
+
|
157 |
+
fig.update_layout(
|
158 |
+
title='Model Performance Dashboard',
|
159 |
+
height=400,
|
160 |
+
showlegend=False
|
161 |
+
)
|
162 |
+
|
163 |
+
fig.update_xaxes(title_text="Metrics", tickangle=45, row=1, col=1)
|
164 |
+
fig.update_yaxes(title_text="Score", range=[0, 1], row=1, col=1)
|
165 |
+
|
166 |
+
fig.update_polars(
|
167 |
+
radialaxis=dict(range=[0, 1], showticklabels=True),
|
168 |
+
row=1, col=2
|
169 |
+
)
|
170 |
+
|
171 |
+
return fig
|
172 |
+
|
173 |
+
except Exception as e:
|
174 |
+
logger.error(f"Metrics plot error: {e}")
|
175 |
+
return GraphVisualizer._create_error_figure(f"Metrics plot error: {str(e)}")
|
176 |
+
|
177 |
+
@staticmethod
|
178 |
+
def create_training_history_plot(history):
|
179 |
+
"""Create comprehensive training history visualization"""
|
180 |
+
try:
|
181 |
+
if not isinstance(history, dict) or not history:
|
182 |
+
return GraphVisualizer._create_error_figure("No training history available")
|
183 |
+
|
184 |
+
required_keys = ['train_loss', 'train_acc']
|
185 |
+
for key in required_keys:
|
186 |
+
if key not in history or not history[key]:
|
187 |
+
return GraphVisualizer._create_error_figure(f"Missing {key} in training history")
|
188 |
+
|
189 |
+
epochs = list(range(len(history['train_loss'])))
|
190 |
+
|
191 |
+
fig = make_subplots(
|
192 |
+
rows=2, cols=2,
|
193 |
+
subplot_titles=('Loss Over Time', 'Accuracy Over Time', 'Learning Rate', 'Training Progress'),
|
194 |
+
specs=[[{"secondary_y": False}, {"secondary_y": False}],
|
195 |
+
[{"secondary_y": False}, {"secondary_y": False}]]
|
196 |
+
)
|
197 |
+
|
198 |
+
# Training loss
|
199 |
+
fig.add_trace(
|
200 |
+
go.Scatter(
|
201 |
+
x=epochs, y=history['train_loss'],
|
202 |
+
mode='lines', name='Train Loss',
|
203 |
+
line=dict(color='blue', width=2),
|
204 |
+
showlegend=False
|
205 |
+
),
|
206 |
+
row=1, col=1
|
207 |
+
)
|
208 |
+
|
209 |
+
if 'val_loss' in history and history['val_loss']:
|
210 |
+
fig.add_trace(
|
211 |
+
go.Scatter(
|
212 |
+
x=epochs, y=history['val_loss'],
|
213 |
+
mode='lines', name='Val Loss',
|
214 |
+
line=dict(color='red', width=2),
|
215 |
+
showlegend=False
|
216 |
+
),
|
217 |
+
row=1, col=1
|
218 |
+
)
|
219 |
+
|
220 |
+
# Training accuracy
|
221 |
+
fig.add_trace(
|
222 |
+
go.Scatter(
|
223 |
+
x=epochs, y=history['train_acc'],
|
224 |
+
mode='lines', name='Train Acc',
|
225 |
+
line=dict(color='green', width=2),
|
226 |
+
showlegend=False
|
227 |
+
),
|
228 |
+
row=1, col=2
|
229 |
+
)
|
230 |
+
|
231 |
+
if 'val_acc' in history and history['val_acc']:
|
232 |
+
fig.add_trace(
|
233 |
+
go.Scatter(
|
234 |
+
x=epochs, y=history['val_acc'],
|
235 |
+
mode='lines', name='Val Acc',
|
236 |
+
line=dict(color='orange', width=2),
|
237 |
+
showlegend=False
|
238 |
+
),
|
239 |
+
row=1, col=2
|
240 |
+
)
|
241 |
+
|
242 |
+
# Learning rate
|
243 |
+
if 'lr' in history and history['lr']:
|
244 |
+
fig.add_trace(
|
245 |
+
go.Scatter(
|
246 |
+
x=epochs, y=history['lr'],
|
247 |
+
mode='lines', name='Learning Rate',
|
248 |
+
line=dict(color='purple', width=2),
|
249 |
+
showlegend=False
|
250 |
+
),
|
251 |
+
row=2, col=1
|
252 |
+
)
|
253 |
+
|
254 |
+
# Training progress summary
|
255 |
+
final_metrics = {
|
256 |
+
'Final Train Acc': history['train_acc'][-1] if history['train_acc'] else 0,
|
257 |
+
'Final Train Loss': history['train_loss'][-1] if history['train_loss'] else 0,
|
258 |
+
}
|
259 |
+
|
260 |
+
if 'val_acc' in history and history['val_acc']:
|
261 |
+
final_metrics['Final Val Acc'] = history['val_acc'][-1]
|
262 |
+
final_metrics['Best Val Acc'] = max(history['val_acc'])
|
263 |
+
|
264 |
+
metric_names = list(final_metrics.keys())
|
265 |
+
metric_values = list(final_metrics.values())
|
266 |
+
|
267 |
+
fig.add_trace(
|
268 |
+
go.Bar(
|
269 |
+
x=metric_names,
|
270 |
+
y=metric_values,
|
271 |
+
marker_color=['lightblue', 'lightcoral', 'lightgreen', 'gold'],
|
272 |
+
text=[f'{v:.3f}' for v in metric_values],
|
273 |
+
textposition='auto',
|
274 |
+
showlegend=False
|
275 |
+
),
|
276 |
+
row=2, col=2
|
277 |
+
)
|
278 |
+
|
279 |
+
fig.update_layout(
|
280 |
+
title='Training History Dashboard',
|
281 |
+
height=600,
|
282 |
+
showlegend=True
|
283 |
+
)
|
284 |
+
|
285 |
+
fig.update_xaxes(title_text="Epoch", row=1, col=1)
|
286 |
+
fig.update_xaxes(title_text="Epoch", row=1, col=2)
|
287 |
+
fig.update_xaxes(title_text="Epoch", row=2, col=1)
|
288 |
+
fig.update_xaxes(title_text="Metric", tickangle=45, row=2, col=2)
|
289 |
+
|
290 |
+
fig.update_yaxes(title_text="Loss", row=1, col=1)
|
291 |
+
fig.update_yaxes(title_text="Accuracy", range=[0, 1], row=1, col=2)
|
292 |
+
fig.update_yaxes(title_text="Learning Rate", type="log", row=2, col=1)
|
293 |
+
fig.update_yaxes(title_text="Value", row=2, col=2)
|
294 |
+
|
295 |
+
return fig
|
296 |
+
|
297 |
+
except Exception as e:
|
298 |
+
logger.error(f"Training history plot error: {e}")
|
299 |
+
return GraphVisualizer._create_error_figure(f"Training history plot error: {str(e)}")
|
300 |
+
|
301 |
+
@staticmethod
|
302 |
+
def _create_error_figure(error_message):
|
303 |
+
"""Create an error figure with message"""
|
304 |
+
fig = go.Figure()
|
305 |
+
fig.add_annotation(
|
306 |
+
text=error_message,
|
307 |
+
x=0.5, y=0.5,
|
308 |
+
xref="paper", yref="paper",
|
309 |
+
showarrow=False,
|
310 |
+
font=dict(size=14, color="red"),
|
311 |
+
bgcolor="rgba(255,255,255,0.8)",
|
312 |
+
bordercolor="red",
|
313 |
+
borderwidth=1
|
314 |
+
)
|
315 |
+
fig.update_layout(
|
316 |
+
title="Visualization Error",
|
317 |
+
xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
|
318 |
+
yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
|
319 |
+
plot_bgcolor='white',
|
320 |
+
width=600,
|
321 |
+
height=400
|
322 |
+
)
|
323 |
+
return fig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|