kfoughali commited on
Commit
58746a0
·
verified ·
1 Parent(s): 850d736

Create visualization.py

Browse files
Files changed (1) hide show
  1. utils/visualization.py +153 -0
utils/visualization.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import plotly.graph_objects as go
2
+ import plotly.express as px
3
+ import networkx as nx
4
+ import torch
5
+ import numpy as np
6
+
7
+ class GraphVisualizer:
8
+ """Graph visualization utilities"""
9
+
10
+ @staticmethod
11
+ def create_graph_plot(data, max_nodes=500):
12
+ """Create interactive graph visualization"""
13
+ try:
14
+ # Limit nodes for performance
15
+ num_nodes = min(data.num_nodes, max_nodes)
16
+
17
+ # Create NetworkX graph
18
+ G = nx.Graph()
19
+ edge_list = data.edge_index.t().cpu().numpy()
20
+
21
+ # Filter edges to include only first max_nodes
22
+ edge_list = edge_list[
23
+ (edge_list[:, 0] < num_nodes) & (edge_list[:, 1] < num_nodes)
24
+ ]
25
+
26
+ if len(edge_list) > 0:
27
+ G.add_edges_from(edge_list)
28
+
29
+ # Add isolated nodes
30
+ G.add_nodes_from(range(num_nodes))
31
+
32
+ # Layout
33
+ if len(G.nodes()) > 100:
34
+ pos = nx.spring_layout(G, k=0.5, iterations=20)
35
+ else:
36
+ pos = nx.spring_layout(G, k=1, iterations=50)
37
+
38
+ # Node colors
39
+ if hasattr(data, 'y') and data.y is not None:
40
+ node_colors = data.y.cpu().numpy()[:num_nodes]
41
+ else:
42
+ node_colors = [0] * num_nodes
43
+
44
+ # Create edge traces
45
+ edge_x, edge_y = [], []
46
+ for edge in G.edges():
47
+ if edge[0] in pos and edge[1] in pos:
48
+ x0, y0 = pos[edge[0]]
49
+ x1, y1 = pos[edge[1]]
50
+ edge_x.extend([x0, x1, None])
51
+ edge_y.extend([y0, y1, None])
52
+
53
+ # Create node traces
54
+ node_x = [pos[node][0] for node in G.nodes() if node in pos]
55
+ node_y = [pos[node][1] for node in G.nodes() if node in pos]
56
+
57
+ fig = go.Figure()
58
+
59
+ # Add edges
60
+ if edge_x:
61
+ fig.add_trace(go.Scatter(
62
+ x=edge_x, y=edge_y,
63
+ line=dict(width=0.5, color='#888'),
64
+ hoverinfo='none',
65
+ mode='lines',
66
+ name='Edges'
67
+ ))
68
+
69
+ # Add nodes
70
+ fig.add_trace(go.Scatter(
71
+ x=node_x, y=node_y,
72
+ mode='markers',
73
+ hoverinfo='text',
74
+ text=[f'Node {i}' for i in range(len(node_x))],
75
+ marker=dict(
76
+ size=8,
77
+ color=node_colors[:len(node_x)],
78
+ colorscale='Viridis',
79
+ line=dict(width=1)
80
+ ),
81
+ name='Nodes'
82
+ ))
83
+
84
+ fig.update_layout(
85
+ title=f'Graph Visualization ({num_nodes} nodes)',
86
+ showlegend=False,
87
+ hovermode='closest',
88
+ margin=dict(b=20, l=5, r=5, t=40),
89
+ xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
90
+ yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
91
+ plot_bgcolor='white'
92
+ )
93
+
94
+ return fig
95
+
96
+ except Exception as e:
97
+ # Return error plot
98
+ fig = go.Figure()
99
+ fig.add_annotation(
100
+ text=f"Visualization error: {str(e)}",
101
+ x=0.5, y=0.5,
102
+ xref="paper", yref="paper",
103
+ showarrow=False
104
+ )
105
+ return fig
106
+
107
+ @staticmethod
108
+ def create_metrics_plot(metrics):
109
+ """Create metrics visualization"""
110
+ try:
111
+ metric_names = []
112
+ metric_values = []
113
+
114
+ for key, value in metrics.items():
115
+ if isinstance(value, (int, float)) and key != 'error':
116
+ metric_names.append(key.replace('_', ' ').title())
117
+ metric_values.append(value)
118
+
119
+ if metric_names:
120
+ fig = go.Figure(data=[
121
+ go.Bar(
122
+ x=metric_names,
123
+ y=metric_values,
124
+ marker_color='lightblue'
125
+ )
126
+ ])
127
+
128
+ fig.update_layout(
129
+ title='Model Performance Metrics',
130
+ xaxis_title='Metric',
131
+ yaxis_title='Value',
132
+ yaxis=dict(range=[0, 1])
133
+ )
134
+ else:
135
+ fig = go.Figure()
136
+ fig.add_annotation(
137
+ text="No metrics to display",
138
+ x=0.5, y=0.5,
139
+ xref="paper", yref="paper",
140
+ showarrow=False
141
+ )
142
+
143
+ return fig
144
+
145
+ except Exception as e:
146
+ fig = go.Figure()
147
+ fig.add_annotation(
148
+ text=f"Metrics plot error: {str(e)}",
149
+ x=0.5, y=0.5,
150
+ xref="paper", yref="paper",
151
+ showarrow=False
152
+ )
153
+ return fig