File size: 5,174 Bytes
58746a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import plotly.graph_objects as go
import plotly.express as px
import networkx as nx
import torch
import numpy as np

class GraphVisualizer:
    """Plotly figure builders for graph data and model metrics.

    Every public method returns a ``plotly.graph_objects.Figure``. On failure
    the methods return a figure holding a text annotation that describes the
    error instead of raising, so callers can always render the result.
    """

    @staticmethod
    def _message_figure(text):
        """Return an empty figure with ``text`` centred as an annotation."""
        fig = go.Figure()
        fig.add_annotation(
            text=text,
            x=0.5, y=0.5,
            xref="paper", yref="paper",
            showarrow=False
        )
        return fig

    @staticmethod
    def _truncated_edges(edge_index, num_nodes):
        """Return the edges whose endpoints both lie in ``[0, num_nodes)``.

        Args:
            edge_index: Tensor of edge endpoints. It is transposed so that
                each row is one ``(source, target)`` pair, i.e. it is
                expected to have shape ``(2, num_edges)`` (PyG-style; confirm
                against callers). ``None`` is treated as "no edges".
            num_nodes: Exclusive upper bound on the node ids to keep.

        Returns:
            List of ``(source, target)`` tuples of Python ints.
        """
        if edge_index is None:
            return []
        edges = edge_index.t().cpu().numpy()
        keep = (edges[:, 0] < num_nodes) & (edges[:, 1] < num_nodes)
        return [tuple(edge) for edge in edges[keep].tolist()]

    @staticmethod
    def _node_colors(data, num_nodes):
        """Return one colour value per node id ``0..num_nodes-1``.

        ``data.y`` is used only when it holds one value per node: a 1-D
        vector, or an ``(N, 1)`` column, with at least ``num_nodes`` entries.
        In every other case every node gets colour ``0``. That covers a
        missing or ``None`` ``y``, a graph-level target, and multi-column
        targets, none of which map one-to-one onto nodes.
        """
        y = getattr(data, 'y', None)
        if y is None:
            return [0] * num_nodes
        if isinstance(y, torch.Tensor):
            y = y.detach().cpu()
        labels = np.asarray(y)
        if labels.ndim == 2 and labels.shape[1] == 1:
            labels = labels[:, 0]
        if labels.ndim != 1 or labels.shape[0] < num_nodes:
            return [0] * num_nodes
        return labels[:num_nodes].tolist()

    @staticmethod
    def create_graph_plot(data, max_nodes=500, seed=None):
        """Create an interactive spring-layout plot of (part of) a graph.

        Only node ids ``0..min(data.num_nodes, max_nodes)-1`` and the edges
        among them are drawn. Nodes are coloured by ``data.y`` when it is a
        per-node vector.

        Args:
            data: Graph object exposing ``num_nodes``, an ``edge_index``
                tensor and optionally ``y`` (presumably a PyG ``Data``;
                confirm against callers).
            max_nodes: Maximum number of nodes to draw, for performance.
            seed: Optional random seed forwarded to ``nx.spring_layout`` for
                a reproducible layout; ``None`` keeps it random.

        Returns:
            A ``go.Figure``. If anything fails, the figure instead carries a
            ``"Visualization error: ..."`` annotation.
        """
        try:
            # Limit nodes for performance: only ids 0..num_nodes-1 are drawn.
            num_nodes = min(data.num_nodes, max_nodes)
            nodes = range(num_nodes)

            G = nx.Graph()
            # Insert nodes before edges so the graph's node order is the id
            # order used for positions, colours and hover labels below.
            G.add_nodes_from(nodes)
            G.add_edges_from(
                GraphVisualizer._truncated_edges(
                    getattr(data, 'edge_index', None), num_nodes
                )
            )

            # Fewer, coarser layout iterations keep large graphs responsive.
            if G.number_of_nodes() > 100:
                pos = nx.spring_layout(G, k=0.5, iterations=20, seed=seed)
            else:
                pos = nx.spring_layout(G, k=1, iterations=50, seed=seed)

            # nx.Graph is undirected, so a pair stored in both directions is
            # drawn once; None breaks the polyline between segments.
            edge_x, edge_y = [], []
            for u, v in G.edges():
                (x0, y0), (x1, y1) = pos[u], pos[v]
                edge_x.extend([x0, x1, None])
                edge_y.extend([y0, y1, None])

            node_x = [pos[n][0] for n in nodes]
            node_y = [pos[n][1] for n in nodes]
            node_colors = GraphVisualizer._node_colors(data, num_nodes)

            fig = go.Figure()

            if edge_x:
                fig.add_trace(go.Scatter(
                    x=edge_x, y=edge_y,
                    line=dict(width=0.5, color='#888'),
                    hoverinfo='none',
                    mode='lines',
                    name='Edges'
                ))

            fig.add_trace(go.Scatter(
                x=node_x, y=node_y,
                mode='markers',
                hoverinfo='text',
                text=[f'Node {n}' for n in nodes],
                marker=dict(
                    size=8,
                    color=node_colors,
                    colorscale='Viridis',
                    line=dict(width=1)
                ),
                name='Nodes'
            ))

            fig.update_layout(
                title=f'Graph Visualization ({num_nodes} nodes)',
                showlegend=False,
                hovermode='closest',
                margin=dict(b=20, l=5, r=5, t=40),
                xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                plot_bgcolor='white'
            )

            return fig

        except Exception as e:
            # Deliberate UI boundary: report the failure in the figure itself.
            return GraphVisualizer._message_figure(f"Visualization error: {str(e)}")

    @staticmethod
    def _numeric_metrics(metrics):
        """Split ``metrics`` into display names and float values.

        Keeps entries whose value is a Python or NumPy int/float. Booleans,
        non-numeric values and the ``'error'`` key are skipped. Keys are
        prettified, e.g. ``'f1_score'`` -> ``'F1 Score'``.
        """
        names, values = [], []
        for key, value in metrics.items():
            # bool subclasses int, but a flag is not a plottable metric.
            if key == 'error' or isinstance(value, bool):
                continue
            if isinstance(value, (int, float, np.integer, np.floating)):
                names.append(key.replace('_', ' ').title())
                values.append(float(value))
        return names, values

    @staticmethod
    def _metric_axis_range(values):
        """Return ``[0, 1]`` if every value fits in it, else ``None``.

        ``None`` means "let plotly autorange", so metrics outside ``[0, 1]``
        (losses, counts) are not drawn off-chart.
        """
        if values and all(0 <= v <= 1 for v in values):
            return [0, 1]
        return None

    @staticmethod
    def create_metrics_plot(metrics):
        """Create a bar chart of the numeric entries of ``metrics``.

        Args:
            metrics: Mapping of metric name to value; non-numeric values,
                booleans and the ``'error'`` key are ignored.

        Returns:
            A ``go.Figure``. The y-axis is fixed to ``[0, 1]`` when all
            values fit, otherwise it autoranges. If there is nothing numeric,
            the figure carries a ``"No metrics to display"`` annotation. If
            plotting fails, it carries a ``"Metrics plot error: ..."``
            annotation.
        """
        try:
            metric_names, metric_values = GraphVisualizer._numeric_metrics(metrics)
            if not metric_names:
                return GraphVisualizer._message_figure("No metrics to display")

            fig = go.Figure(data=[
                go.Bar(
                    x=metric_names,
                    y=metric_values,
                    marker_color='lightblue'
                )
            ])

            fig.update_layout(
                title='Model Performance Metrics',
                xaxis_title='Metric',
                yaxis_title='Value'
            )
            y_range = GraphVisualizer._metric_axis_range(metric_values)
            if y_range is not None:
                fig.update_yaxes(range=y_range)

            return fig

        except Exception as e:
            # Deliberate UI boundary: report the failure in the figure itself.
            return GraphVisualizer._message_figure(f"Metrics plot error: {str(e)}")