import plotly.graph_objects as go
import plotly.express as px
import plotly.figure_factory as ff
from plotly.subplots import make_subplots
import networkx as nx
import torch
import numpy as np
import pandas as pd
class GraphVisualizer:
    """Static helpers that build Plotly figures for graphs and training results.

    Every public method returns a ``plotly.graph_objects.Figure``. Failures
    inside a method are caught and rendered as a figure carrying the error
    message in red, so the public methods do not propagate ``Exception``.
    """

    @staticmethod
    def _is_finite_number(value):
        """Return True when *value* is a real scalar that is not NaN or infinite.

        Accepts Python ``int``/``float`` and NumPy integer/floating scalars.
        """
        if isinstance(value, (int, np.integer)):
            # Integers are always finite; skipping np.isfinite also avoids a
            # TypeError for Python ints too large for a NumPy dtype.
            return True
        if isinstance(value, (float, np.floating)):
            return bool(np.isfinite(value))
        return False

    @staticmethod
    def _cycle_colors(palette, count):
        """Return *count* colors, repeating *palette* when it is too short."""
        if not palette:
            return []
        return [palette[i % len(palette)] for i in range(count)]

    @staticmethod
    def _message_figure(text, title=None, color=None, blank_canvas=False):
        """Build a figure that only shows a centered text message.

        Args:
            text: Message placed in the middle of the figure.
            title: Optional figure title.
            color: Optional font color for the message (e.g. ``"red"``).
            blank_canvas: When True, hide grid, zero lines and tick labels
                and use a white plot background.

        Returns:
            The message figure.
        """
        font = dict(size=14)
        if color is not None:
            font['color'] = color

        fig = go.Figure()
        fig.add_annotation(
            text=text,
            x=0.5, y=0.5,
            xref="paper", yref="paper",
            showarrow=False,
            font=font
        )

        layout = {}
        if title is not None:
            layout['title'] = title
        if blank_canvas:
            layout['xaxis'] = dict(showgrid=False, zeroline=False, showticklabels=False)
            layout['yaxis'] = dict(showgrid=False, zeroline=False, showticklabels=False)
            layout['plot_bgcolor'] = 'white'
        if layout:
            fig.update_layout(**layout)
        return fig

    @staticmethod
    def _compute_layout(graph, layout_algorithm):
        """Return node positions for *graph* using the named layout.

        ``'kamada_kawai'`` and ``'spectral'`` fall back to a spring layout
        when the underlying call raises; unknown names use a default spring
        layout.
        """
        if layout_algorithm == 'spring':
            # Larger graphs get fewer iterations to keep the layout fast.
            if graph.number_of_nodes() > 100:
                return nx.spring_layout(graph, k=0.5, iterations=20)
            return nx.spring_layout(graph, k=1, iterations=50)
        if layout_algorithm == 'circular':
            return nx.circular_layout(graph)
        if layout_algorithm in ('kamada_kawai', 'spectral'):
            if layout_algorithm == 'kamada_kawai':
                layout_fn = nx.kamada_kawai_layout
            else:
                layout_fn = nx.spectral_layout
            try:
                return layout_fn(graph)
            except Exception:
                # Best effort: any layout failure degrades to spring layout,
                # but KeyboardInterrupt/SystemExit are no longer swallowed.
                return nx.spring_layout(graph)
        return nx.spring_layout(graph)

    @staticmethod
    def create_graph_plot(data, max_nodes=500, layout_algorithm='spring'):
        """Create an interactive node-link plot of a graph.

        Args:
            data: Object exposing ``num_nodes``, ``edge_index`` (a tensor that
                supports ``.t().cpu().numpy()``) and optionally ``y``.
                NOTE(review): this looks like a PyTorch Geometric ``Data``
                object with ``edge_index`` of shape (2, E) - confirm with
                the callers.
            max_nodes: Upper bound on the number of nodes drawn; only nodes
                with index below ``min(data.num_nodes, max_nodes)`` are kept.
            layout_algorithm: ``'spring'``, ``'circular'``, ``'kamada_kawai'``
                or ``'spectral'``; any other value uses a spring layout.

        Returns:
            A Plotly figure. Marker size grows with node degree (clamped to
            5..20) and marker color follows ``data.y`` when present. On any
            exception an error figure is returned instead.
        """
        try:
            # Limit nodes for performance.
            num_nodes = min(data.num_nodes, max_nodes)

            # Keep only edges whose two endpoints are inside the node limit.
            edge_list = data.edge_index.t().cpu().numpy()
            in_range = (edge_list[:, 0] < num_nodes) & (edge_list[:, 1] < num_nodes)
            edge_list = edge_list[in_range]

            # Insert nodes first so G.nodes() iterates in index order
            # (0..num_nodes-1); isolated nodes are included this way too.
            G = nx.Graph()
            G.add_nodes_from(range(num_nodes))
            if len(edge_list) > 0:
                G.add_edges_from(edge_list.tolist())

            pos = GraphVisualizer._compute_layout(G, layout_algorithm)

            # Per-node labels used for coloring; default to a single class.
            raw_labels = getattr(data, 'y', None)
            if raw_labels is not None:
                labels = np.atleast_1d(raw_labels.cpu().numpy())[:num_nodes]
            else:
                labels = [0] * num_nodes

            # Edge segments, separated by None so one trace draws them all.
            edge_x, edge_y = [], []
            for source, target in G.edges():
                x0, y0 = pos[source]
                x1, y1 = pos[target]
                edge_x.extend([x0, x1, None])
                edge_y.extend([y0, y1, None])

            # Build every per-node attribute in the same loop so positions,
            # sizes, colors and hover text always refer to the same node.
            node_x, node_y = [], []
            node_text, node_info = [], []
            node_sizes, marker_colors = [], []
            for node in G.nodes():
                x, y = pos[node]
                degree = G.degree(node)
                # Fall back to label 0 when y has fewer entries than nodes.
                label = labels[node] if node < len(labels) else 0
                node_x.append(x)
                node_y.append(y)
                node_text.append(f"Node {node}")
                node_info.append(
                    f"Node: {node}<br>Degree: {degree}<br>Label: {label}"
                )
                node_sizes.append(max(5, min(20, 5 + degree)))
                marker_colors.append(label)

            fig = go.Figure()

            # Add edges
            if edge_x:
                fig.add_trace(go.Scatter(
                    x=edge_x, y=edge_y,
                    line=dict(width=0.8, color='rgba(125,125,125,0.5)'),
                    hoverinfo='none',
                    mode='lines',
                    name='Edges',
                    showlegend=False
                ))

            # Add nodes
            fig.add_trace(go.Scatter(
                x=node_x, y=node_y,
                mode='markers',
                hoverinfo='text',
                hovertext=node_info,
                text=node_text,
                marker=dict(
                    size=node_sizes,
                    color=np.asarray(marker_colors),
                    colorscale='Viridis',
                    line=dict(width=2, color='white'),
                    opacity=0.8
                ),
                name='Nodes',
                showlegend=False
            ))

            # Report the edges actually drawn: nx.Graph collapses duplicate
            # and reversed pairs, so the raw edge_list length can overcount.
            num_edges = G.number_of_edges()
            fig.update_layout(
                title=dict(
                    text=f'Graph Visualization ({num_nodes} nodes, {num_edges} edges)',
                    x=0.5,
                    font=dict(size=16)
                ),
                showlegend=False,
                hovermode='closest',
                margin=dict(b=20, l=5, r=5, t=40),
                annotations=[
                    dict(
                        text=f"Layout: {layout_algorithm.title()}",
                        showarrow=False,
                        xref="paper", yref="paper",
                        x=0.005, y=-0.002,
                        xanchor='left', yanchor='bottom',
                        font=dict(color="gray", size=10)
                    )
                ],
                xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                plot_bgcolor='white',
                width=800,
                height=600
            )
            return fig
        except Exception as e:
            # Return error plot
            return GraphVisualizer._message_figure(
                f"Visualization error: {str(e)}",
                title="Graph Visualization Error",
                color="red",
                blank_canvas=True
            )

    @staticmethod
    def create_metrics_plot(metrics):
        """Create a bar chart and a radar chart of scalar metrics.

        Args:
            metrics: Mapping of metric name to value. Entries named
                ``'error'`` or ``'loss'`` and values that are not finite
                real numbers are skipped.

        Returns:
            A two-panel Plotly figure, a placeholder figure when no metric
            qualifies, or an error figure when an exception occurs. Both
            panels use a fixed 0..1 value range.
        """
        try:
            # Filter numeric metrics
            metric_names = []
            metric_values = []
            for key, value in metrics.items():
                if key in ('error', 'loss'):
                    continue
                if GraphVisualizer._is_finite_number(value):
                    metric_names.append(key.replace('_', ' ').title())
                    metric_values.append(value)

            if not metric_names:
                return GraphVisualizer._message_figure(
                    "No valid metrics to display",
                    title="Metrics Dashboard"
                )

            # The second cell must be a polar subplot: a Scatterpolar trace
            # cannot be added to a cartesian ("scatter") cell.
            fig = make_subplots(
                rows=1, cols=2,
                subplot_titles=('Performance Metrics', 'Metric Comparison'),
                specs=[[{"type": "bar"}, {"type": "polar"}]]
            )

            # Bar chart; the palette is cycled so every bar gets a color.
            colors = GraphVisualizer._cycle_colors(
                px.colors.qualitative.Set3, len(metric_names)
            )
            fig.add_trace(
                go.Bar(
                    x=metric_names,
                    y=metric_values,
                    marker_color=colors,
                    text=[f'{v:.3f}' for v in metric_values],
                    textposition='auto',
                    name='Metrics'
                ),
                row=1, col=1
            )

            # Radar chart
            fig.add_trace(
                go.Scatterpolar(
                    r=metric_values,
                    theta=metric_names,
                    fill='toself',
                    name='Performance',
                    line=dict(color='blue')
                ),
                row=1, col=2
            )

            fig.update_layout(
                title=dict(
                    text='Model Performance Dashboard',
                    x=0.5,
                    font=dict(size=18)
                ),
                showlegend=False,
                height=400
            )

            # Update bar chart
            fig.update_xaxes(title_text="Metrics", row=1, col=1)
            fig.update_yaxes(title_text="Score", range=[0, 1], row=1, col=1)

            # Update polar chart
            fig.update_polars(
                radialaxis=dict(range=[0, 1], showticklabels=True),
                row=1, col=2
            )
            return fig
        except Exception as e:
            return GraphVisualizer._message_figure(
                f"Metrics plot error: {str(e)}",
                title="Metrics Error",
                color="red"
            )

    @staticmethod
    def create_training_history_plot(history):
        """Create a 2x2 dashboard of training curves.

        Args:
            history: Mapping with a required ``'train_loss'`` sequence and
                optional ``'val_loss'``, ``'train_acc'``, ``'val_acc'`` and
                ``'lr'`` sequences. The x axis of every curve is the index
                range of ``'train_loss'``.

        Returns:
            A Plotly figure with loss, accuracy, learning rate (log axis) and
            a train-vs-validation loss scatter, or an error figure when an
            exception occurs (for example when ``'train_loss'`` is missing).
        """
        try:
            train_loss = history['train_loss']
            epochs = list(range(len(train_loss)))

            # Create subplots
            fig = make_subplots(
                rows=2, cols=2,
                subplot_titles=(
                    'Training Loss', 'Training Accuracy',
                    'Learning Rate', 'Loss Comparison'
                )
            )

            # (history key, legend name, line color, row, col); a series is
            # drawn only when its key is present.
            line_series = (
                ('train_loss', 'Train Loss', 'blue', 1, 1),
                ('val_loss', 'Val Loss', 'red', 1, 1),
                ('train_acc', 'Train Acc', 'green', 1, 2),
                ('val_acc', 'Val Acc', 'orange', 1, 2),
                ('lr', 'Learning Rate', 'purple', 2, 1),
            )
            for key, name, color, row, col in line_series:
                if key not in history:
                    continue
                fig.add_trace(
                    go.Scatter(
                        x=epochs, y=history[key],
                        mode='lines', name=name,
                        line=dict(color=color, width=2)
                    ),
                    row=row, col=col
                )

            # Loss comparison; skipped for empty series because min()/max()
            # would raise on them.
            val_loss = history['val_loss'] if 'val_loss' in history else None
            if val_loss is not None and len(train_loss) > 0 and len(val_loss) > 0:
                fig.add_trace(
                    go.Scatter(
                        x=train_loss, y=val_loss,
                        mode='markers', name='Train vs Val Loss',
                        marker=dict(color=epochs, colorscale='Viridis', size=8),
                        text=[f'Epoch {i}' for i in epochs],
                        hovertemplate='Train Loss: %{x:.4f}<br>Val Loss: %{y:.4f}<br>%{text}'
                    ),
                    row=2, col=2
                )

                # Diagonal reference line where train loss equals val loss.
                min_loss = min(min(train_loss), min(val_loss))
                max_loss = max(max(train_loss), max(val_loss))
                fig.add_trace(
                    go.Scatter(
                        x=[min_loss, max_loss], y=[min_loss, max_loss],
                        mode='lines', name='Perfect Fit',
                        line=dict(color='gray', dash='dash'),
                        showlegend=False
                    ),
                    row=2, col=2
                )

            fig.update_layout(
                title=dict(
                    text='Training History Dashboard',
                    x=0.5,
                    font=dict(size=18)
                ),
                height=600,
                showlegend=True
            )

            # Update axes
            fig.update_xaxes(title_text="Epoch", row=1, col=1)
            fig.update_xaxes(title_text="Epoch", row=1, col=2)
            fig.update_xaxes(title_text="Epoch", row=2, col=1)
            fig.update_xaxes(title_text="Train Loss", row=2, col=2)
            fig.update_yaxes(title_text="Loss", row=1, col=1)
            fig.update_yaxes(title_text="Accuracy", row=1, col=2)
            fig.update_yaxes(title_text="Learning Rate", type="log", row=2, col=1)
            fig.update_yaxes(title_text="Val Loss", row=2, col=2)
            return fig
        except Exception as e:
            return GraphVisualizer._message_figure(
                f"Training history plot error: {str(e)}",
                color="red"
            )

    @staticmethod
    def create_dataset_stats_plot(dataset_info):
        """Create a dataset statistics dashboard.

        Args:
            dataset_info: Mapping of statistic name to value. Every finite
                numeric entry becomes a bar. ``'num_graphs'``,
                ``'avg_nodes'``, ``'min_nodes'`` and ``'max_nodes'`` are also
                read (with defaults) for the right-hand panel.

        Returns:
            A two-panel Plotly figure, or an error figure when an exception
            occurs or no numeric statistic is available. The box plot is a
            synthetic sample derived from avg/min/max node counts, not the
            real per-graph distribution.
        """
        try:
            # Prepare data
            stats_data = [
                {'Metric': key.replace('_', ' ').title(), 'Value': value}
                for key, value in dataset_info.items()
                if GraphVisualizer._is_finite_number(value)
            ]
            if not stats_data:
                raise ValueError("No valid statistics to display")
            df = pd.DataFrame(stats_data)

            # Create subplots
            fig = make_subplots(
                rows=1, cols=2,
                subplot_titles=('Dataset Overview', 'Graph Size Distribution'),
                specs=[[{"type": "bar"}, {"type": "box"}]]
            )

            # Bar chart of statistics; the palette is cycled so every bar
            # gets a color.
            fig.add_trace(
                go.Bar(
                    x=df['Metric'],
                    y=df['Value'],
                    marker_color=GraphVisualizer._cycle_colors(
                        px.colors.qualitative.Pastel1, len(df)
                    ),
                    text=df['Value'],
                    texttemplate='%{text:,.0f}',
                    textposition='auto'
                ),
                row=1, col=1
            )

            # Box plot for size distribution (if multiple graphs)
            if dataset_info.get('num_graphs', 1) > 1:
                # Simulate distribution based on min/max/avg
                avg_nodes = dataset_info.get('avg_nodes', 100)
                min_nodes = dataset_info.get('min_nodes', avg_nodes * 0.5)
                max_nodes = dataset_info.get('max_nodes', avg_nodes * 1.5)

                # A private, fixed-seed generator keeps the picture
                # reproducible without reseeding NumPy's global RNG.
                rng = np.random.RandomState(42)
                node_dist = rng.normal(avg_nodes, (max_nodes - min_nodes) / 4, 100)
                node_dist = np.clip(node_dist, min_nodes, max_nodes)
                fig.add_trace(
                    go.Box(
                        y=node_dist,
                        name='Node Count',
                        marker_color='lightblue'
                    ),
                    row=1, col=2
                )
            else:
                # Single graph - show as point
                fig.add_trace(
                    go.Scatter(
                        x=['Nodes'],
                        y=[dataset_info.get('avg_nodes', 0)],
                        mode='markers',
                        marker=dict(size=20, color='blue'),
                        name='Node Count'
                    ),
                    row=1, col=2
                )

            fig.update_layout(
                title=dict(
                    text='Dataset Statistics Dashboard',
                    x=0.5,
                    font=dict(size=16)
                ),
                height=400,
                showlegend=False
            )

            # Update axes
            fig.update_xaxes(title_text="Metrics", tickangle=45, row=1, col=1)
            fig.update_yaxes(title_text="Count", row=1, col=1)
            fig.update_yaxes(title_text="Number of Nodes", row=1, col=2)
            return fig
        except Exception as e:
            return GraphVisualizer._message_figure(
                f"Dataset stats error: {str(e)}",
                color="red"
            )