import plotly.graph_objects as go
import plotly.express as px
import plotly.figure_factory as ff
from plotly.subplots import make_subplots
import networkx as nx
import torch
import numpy as np
import pandas as pd


class GraphVisualizer:
    """Plotly-based visualization helpers for graph data and training results.

    Every public method is a static method that returns a
    ``plotly.graph_objects.Figure``. Failures inside a method are caught and
    rendered as a figure carrying the error message, so callers always get a
    figure back instead of an exception.
    """

    @staticmethod
    def _error_figure(message, **layout):
        """Return an empty figure showing ``message`` in red at its centre.

        Args:
            message: Text of the annotation.
            **layout: Optional keyword arguments forwarded to
                ``fig.update_layout`` (e.g. ``title``).
        """
        fig = go.Figure()
        fig.add_annotation(
            text=message,
            x=0.5, y=0.5,
            xref="paper", yref="paper",
            showarrow=False,
            font=dict(size=14, color="red"),
        )
        if layout:
            fig.update_layout(**layout)
        return fig

    @staticmethod
    def _compute_layout(G, layout_algorithm):
        """Return a ``{node: (x, y)}`` mapping for ``G``.

        ``'spring'``, ``'circular'``, ``'kamada_kawai'`` and ``'spectral'``
        are recognised; any other value uses the default spring layout.
        Kamada-Kawai and spectral layouts fall back to the default spring
        layout if they raise.
        """
        if layout_algorithm == 'spring':
            # Fewer iterations on larger graphs to keep layout time bounded.
            if len(G.nodes()) > 100:
                return nx.spring_layout(G, k=0.5, iterations=20)
            return nx.spring_layout(G, k=1, iterations=50)

        if layout_algorithm == 'circular':
            return nx.circular_layout(G)

        fallible_layouts = {
            'kamada_kawai': nx.kamada_kawai_layout,
            'spectral': nx.spectral_layout,
        }
        layout_fn = fallible_layouts.get(layout_algorithm)
        if layout_fn is not None:
            try:
                return layout_fn(G)
            except Exception:
                # Catch Exception rather than using a bare ``except`` so that
                # KeyboardInterrupt / SystemExit still propagate.
                return nx.spring_layout(G)

        return nx.spring_layout(G)

    @staticmethod
    def create_graph_plot(data, max_nodes=500, layout_algorithm='spring'):
        """Create an interactive node-link plot of a graph.

        Args:
            data: Object exposing ``num_nodes``, ``edge_index`` (a tensor on
                which ``.t().cpu().numpy()`` yields one ``(source, target)``
                row per edge) and optionally ``y`` (per-node labels used for
                colouring).
                NOTE(review): looks like a torch_geometric ``Data`` object -
                confirm against callers.
            max_nodes: Only nodes with index below
                ``min(data.num_nodes, max_nodes)`` are drawn; edges touching
                any other node are dropped.
            layout_algorithm: ``'spring'``, ``'circular'``, ``'kamada_kawai'``
                or ``'spectral'``; anything else uses a spring layout.

        Returns:
            A ``go.Figure`` with an edge trace (when there are edges) and a
            node trace, or an error figure if anything fails.
        """
        try:
            # Limit nodes for performance
            num_nodes = min(data.num_nodes, max_nodes)

            # Keep only edges whose two endpoints are both displayed.
            edge_list = data.edge_index.t().cpu().numpy()
            edge_list = edge_list[
                (edge_list[:, 0] < num_nodes) & (edge_list[:, 1] < num_nodes)
            ]

            G = nx.Graph()
            if len(edge_list) > 0:
                G.add_edges_from(edge_list.tolist())
            # Add isolated nodes
            G.add_nodes_from(range(num_nodes))

            pos = GraphVisualizer._compute_layout(G, layout_algorithm)

            labels = None
            if getattr(data, 'y', None) is not None:
                labels = data.y.cpu().numpy()[:num_nodes]

            # Edge segments; ``None`` separates consecutive segments so one
            # Scatter trace can draw all of them.
            edge_x, edge_y = [], []
            for src, dst in G.edges():
                if src in pos and dst in pos:
                    x0, y0 = pos[src]
                    x1, y1 = pos[dst]
                    edge_x.extend([x0, x1, None])
                    edge_y.extend([y0, y1, None])

            # Build every per-node attribute in one pass over G.nodes() so
            # position, size, colour and hover text describe the same node.
            # G.nodes() follows insertion order (edge endpoints first), which
            # is generally not 0..n-1, so colours must not be taken by slice.
            node_x, node_y = [], []
            node_text, node_info = [], []
            node_sizes, node_colors = [], []
            for node in G.nodes():
                if node not in pos:
                    continue
                x, y = pos[node]
                degree = G.degree(node)
                # Nodes without a label (no ``y``, or ``y`` shorter than the
                # node count) are shown with label 0.
                if labels is not None and node < len(labels):
                    label = labels[node]
                else:
                    label = 0

                node_x.append(x)
                node_y.append(y)
                node_sizes.append(max(5, min(20, 5 + degree)))
                node_colors.append(label)
                node_text.append(f"Node {node}")
                # Plotly hover text uses <br> for line breaks.
                node_info.append(
                    f"Node: {node}<br>Degree: {degree}<br>Label: {label}"
                )

            fig = go.Figure()

            # Add edges
            if edge_x:
                fig.add_trace(go.Scatter(
                    x=edge_x,
                    y=edge_y,
                    line=dict(width=0.8, color='rgba(125,125,125,0.5)'),
                    hoverinfo='none',
                    mode='lines',
                    name='Edges',
                    showlegend=False,
                ))

            # Add nodes
            fig.add_trace(go.Scatter(
                x=node_x,
                y=node_y,
                mode='markers',
                hoverinfo='text',
                hovertext=node_info,
                text=node_text,
                marker=dict(
                    size=node_sizes,
                    color=np.asarray(node_colors),
                    colorscale='Viridis',
                    line=dict(width=2, color='white'),
                    opacity=0.8,
                ),
                name='Nodes',
                showlegend=False,
            ))

            # Report the edges actually drawn: the raw edge list may contain
            # both directions of an edge or duplicates, which nx.Graph merges.
            num_edges = G.number_of_edges()
            fig.update_layout(
                title=dict(
                    text=f'Graph Visualization ({num_nodes} nodes, {num_edges} edges)',
                    x=0.5,
                    font=dict(size=16),
                ),
                showlegend=False,
                hovermode='closest',
                margin=dict(b=20, l=5, r=5, t=40),
                annotations=[
                    dict(
                        text=f"Layout: {layout_algorithm.title()}",
                        showarrow=False,
                        xref="paper", yref="paper",
                        x=0.005, y=-0.002,
                        xanchor='left', yanchor='bottom',
                        font=dict(color="gray", size=10),
                    )
                ],
                xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                plot_bgcolor='white',
                width=800,
                height=600,
            )

            return fig

        except Exception as e:
            # Return error plot
            return GraphVisualizer._error_figure(
                f"Visualization error: {str(e)}",
                title="Graph Visualization Error",
                xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                plot_bgcolor='white',
            )

    @staticmethod
    def create_metrics_plot(metrics):
        """Create a bar chart plus radar chart of numeric metrics.

        Args:
            metrics: Mapping of metric name to value. Only ``int``/``float``
                values that are finite are plotted; the keys ``'error'`` and
                ``'loss'`` are skipped.

        Returns:
            A two-panel ``go.Figure``; a placeholder figure if no metric
            qualifies; or an error figure if anything fails.
        """
        try:
            # Filter numeric metrics
            metric_names = []
            metric_values = []
            for key, value in metrics.items():
                if not isinstance(value, (int, float)) or key in ('error', 'loss'):
                    continue
                if np.isnan(value) or np.isinf(value):
                    continue
                metric_names.append(key.replace('_', ' ').title())
                metric_values.append(value)

            if not metric_names:
                fig = go.Figure()
                fig.add_annotation(
                    text="No valid metrics to display",
                    x=0.5, y=0.5,
                    xref="paper", yref="paper",
                    showarrow=False,
                    font=dict(size=14),
                )
                fig.update_layout(title="Metrics Dashboard")
                return fig

            # The second cell must be a polar subplot: Scatterpolar traces
            # cannot be placed in a cartesian ("scatter"/xy) cell.
            fig = make_subplots(
                rows=1, cols=2,
                subplot_titles=('Performance Metrics', 'Metric Comparison'),
                specs=[[{"type": "bar"}, {"type": "polar"}]],
            )

            # Bar chart
            colors = px.colors.qualitative.Set3[:len(metric_names)]
            fig.add_trace(
                go.Bar(
                    x=metric_names,
                    y=metric_values,
                    marker_color=colors,
                    text=[f'{v:.3f}' for v in metric_values],
                    textposition='auto',
                    name='Metrics',
                ),
                row=1, col=1,
            )

            # Radar chart
            fig.add_trace(
                go.Scatterpolar(
                    r=metric_values,
                    theta=metric_names,
                    fill='toself',
                    name='Performance',
                    line=dict(color='blue'),
                ),
                row=1, col=2,
            )

            fig.update_layout(
                title=dict(
                    text='Model Performance Dashboard',
                    x=0.5,
                    font=dict(size=18),
                ),
                showlegend=False,
                height=400,
            )

            # Update bar chart
            fig.update_xaxes(title_text="Metrics", row=1, col=1)
            fig.update_yaxes(title_text="Score", range=[0, 1], row=1, col=1)

            # Update polar chart
            fig.update_polars(
                radialaxis=dict(range=[0, 1], showticklabels=True),
                row=1, col=2,
            )

            return fig

        except Exception as e:
            return GraphVisualizer._error_figure(
                f"Metrics plot error: {str(e)}",
                title="Metrics Error",
            )

    @staticmethod
    def create_training_history_plot(history):
        """Create a 2x2 dashboard of training curves.

        Args:
            history: Mapping with a required ``'train_loss'`` sequence and
                optional ``'val_loss'``, ``'train_acc'``, ``'val_acc'`` and
                ``'lr'`` sequences. The x axis of the curve panels is
                ``range(len(history['train_loss']))``.

        Returns:
            A ``go.Figure`` with loss, accuracy, learning-rate and
            train-vs-validation-loss panels (optional series are skipped when
            absent), or an error figure if anything fails.
        """
        try:
            epochs = list(range(len(history['train_loss'])))

            # Create subplots
            fig = make_subplots(
                rows=2, cols=2,
                subplot_titles=('Training Loss', 'Training Accuracy',
                                'Learning Rate', 'Loss Comparison'),
                specs=[[{"secondary_y": False}, {"secondary_y": False}],
                       [{"secondary_y": False}, {"secondary_y": False}]],
            )

            def add_curve(key, name, color, row, col):
                """Add ``history[key]`` as a line over epochs, if present."""
                if key not in history:
                    return
                fig.add_trace(
                    go.Scatter(
                        x=epochs,
                        y=history[key],
                        mode='lines',
                        name=name,
                        line=dict(color=color, width=2),
                    ),
                    row=row, col=col,
                )

            # Loss curves
            add_curve('train_loss', 'Train Loss', 'blue', 1, 1)
            add_curve('val_loss', 'Val Loss', 'red', 1, 1)

            # Accuracy curves ('train_acc' is optional so a loss-only history
            # still renders instead of turning into an error figure).
            add_curve('train_acc', 'Train Acc', 'green', 1, 2)
            add_curve('val_acc', 'Val Acc', 'orange', 1, 2)

            # Learning rate
            add_curve('lr', 'Learning Rate', 'purple', 2, 1)

            # Loss comparison
            if 'val_loss' in history:
                train_loss = history['train_loss']
                val_loss = history['val_loss']
                fig.add_trace(
                    go.Scatter(
                        x=train_loss,
                        y=val_loss,
                        mode='markers',
                        name='Train vs Val Loss',
                        marker=dict(color=epochs, colorscale='Viridis', size=8),
                        text=[f'Epoch {i}' for i in epochs],
                        # Plotly hover templates use <br> for line breaks.
                        hovertemplate='Train Loss: %{x:.4f}<br>Val Loss: %{y:.4f}<br>%{text}',
                    ),
                    row=2, col=2,
                )

                # Diagonal reference line; skipped when either series is
                # empty because min()/max() would raise on it.
                if len(train_loss) > 0 and len(val_loss) > 0:
                    min_loss = min(min(train_loss), min(val_loss))
                    max_loss = max(max(train_loss), max(val_loss))
                    fig.add_trace(
                        go.Scatter(
                            x=[min_loss, max_loss],
                            y=[min_loss, max_loss],
                            mode='lines',
                            name='Perfect Fit',
                            line=dict(color='gray', dash='dash'),
                            showlegend=False,
                        ),
                        row=2, col=2,
                    )

            fig.update_layout(
                title=dict(
                    text='Training History Dashboard',
                    x=0.5,
                    font=dict(size=18),
                ),
                height=600,
                showlegend=True,
            )

            # Update axes
            fig.update_xaxes(title_text="Epoch", row=1, col=1)
            fig.update_xaxes(title_text="Epoch", row=1, col=2)
            fig.update_xaxes(title_text="Epoch", row=2, col=1)
            fig.update_xaxes(title_text="Train Loss", row=2, col=2)

            fig.update_yaxes(title_text="Loss", row=1, col=1)
            fig.update_yaxes(title_text="Accuracy", row=1, col=2)
            fig.update_yaxes(title_text="Learning Rate", type="log", row=2, col=1)
            fig.update_yaxes(title_text="Val Loss", row=2, col=2)

            return fig

        except Exception as e:
            return GraphVisualizer._error_figure(
                f"Training history plot error: {str(e)}"
            )

    @staticmethod
    def create_dataset_stats_plot(dataset_info):
        """Create a dataset statistics dashboard.

        Args:
            dataset_info: Mapping of statistic name to value. Every
                ``int``/``float`` entry that is not NaN becomes a bar. The
                keys ``'num_graphs'``, ``'avg_nodes'``, ``'min_nodes'`` and
                ``'max_nodes'`` additionally drive the right-hand panel.

        Returns:
            A two-panel ``go.Figure``. The right-hand panel is a box plot of
            a *synthetic* node-count distribution when ``num_graphs > 1``,
            otherwise a single marker at ``avg_nodes``. An error figure is
            returned if anything fails, including when no statistic
            qualifies.
        """
        try:
            # Prepare data
            stats_data = [
                {'Metric': key.replace('_', ' ').title(), 'Value': value}
                for key, value in dataset_info.items()
                if isinstance(value, (int, float)) and not np.isnan(value)
            ]

            if not stats_data:
                raise ValueError("No valid statistics to display")

            df = pd.DataFrame(stats_data)

            # Create subplots
            fig = make_subplots(
                rows=1, cols=2,
                subplot_titles=('Dataset Overview', 'Graph Size Distribution'),
                specs=[[{"type": "bar"}, {"type": "box"}]],
            )

            # Bar chart of statistics
            fig.add_trace(
                go.Bar(
                    x=df['Metric'],
                    y=df['Value'],
                    marker_color=px.colors.qualitative.Pastel1,
                    text=df['Value'],
                    texttemplate='%{text:,.0f}',
                    textposition='auto',
                ),
                row=1, col=1,
            )

            if dataset_info.get('num_graphs', 1) > 1:
                # Simulate a distribution from min/max/avg; the real per-graph
                # sizes are not available here.
                avg_nodes = dataset_info.get('avg_nodes', 100)
                min_nodes = dataset_info.get('min_nodes', avg_nodes * 0.5)
                max_nodes = dataset_info.get('max_nodes', avg_nodes * 1.5)

                # A local RandomState(42) draws the same samples as seeding
                # the global generator with 42, without resetting the global
                # NumPy RNG state for the rest of the process.
                rng = np.random.RandomState(42)
                node_dist = rng.normal(avg_nodes, (max_nodes - min_nodes) / 4, 100)
                node_dist = np.clip(node_dist, min_nodes, max_nodes)

                fig.add_trace(
                    go.Box(
                        y=node_dist,
                        name='Node Count',
                        marker_color='lightblue',
                    ),
                    row=1, col=2,
                )
            else:
                # Single graph - show as point
                fig.add_trace(
                    go.Scatter(
                        x=['Nodes'],
                        y=[dataset_info.get('avg_nodes', 0)],
                        mode='markers',
                        marker=dict(size=20, color='blue'),
                        name='Node Count',
                    ),
                    row=1, col=2,
                )

            fig.update_layout(
                title=dict(
                    text='Dataset Statistics Dashboard',
                    x=0.5,
                    font=dict(size=16),
                ),
                height=400,
                showlegend=False,
            )

            # Update axes
            fig.update_xaxes(title_text="Metrics", tickangle=45, row=1, col=1)
            fig.update_yaxes(title_text="Count", row=1, col=1)
            fig.update_yaxes(title_text="Number of Nodes", row=1, col=2)

            return fig

        except Exception as e:
            return GraphVisualizer._error_figure(
                f"Dataset stats error: {str(e)}"
            )