neuralworm commited on
Commit
81dcb15
1 Parent(s): eafed86

plotly instead of matplotlib

Browse files
Files changed (2) hide show
  1. app.py +10 -5
  2. psychohistory.py +35 -42
app.py CHANGED
@@ -15,14 +15,19 @@ with gr.Blocks(title="PSYCHOHISTORY") as app:
15
  outputs=mem_results
16
  )
17
 
18
- with gr.Row():
19
- img_output = gr.Image(label="Graph Visualization", type="filepath") # Add an Image component
20
 
21
- # Trigger graph generation after JSON is generated
22
- mem_results.change(
 
 
 
 
 
23
  psychohistory.main,
24
  inputs=[mem_results],
25
- outputs=img_output
26
  )
27
 
28
  if __name__ == "__main__":
 
15
  outputs=mem_results
16
  )
17
 
18
+ # with gr.Row():
19
+ # img_output = gr.Image(label="Graph Visualization", type="filepath") # Add an Image component
20
 
21
+ # # Trigger graph generation after JSON is generated
22
+ # mem_results.change(
23
+ # psychohistory.main,
24
+ # inputs=[mem_results],
25
+ # outputs=img_output
26
+ # )
27
+ mem_results.change(
28
  psychohistory.main,
29
  inputs=[mem_results],
30
+ outputs=None
31
  )
32
 
33
  if __name__ == "__main__":
psychohistory.py CHANGED
@@ -1,4 +1,4 @@
1
- import matplotlib.pyplot as plt
2
  from mpl_toolkits.mplot3d import Axes3D
3
  import networkx as nx
4
  import numpy as np
@@ -95,38 +95,34 @@ def find_paths(G):
95
 
96
  return best_path, best_mean_prob, worst_path, worst_mean_prob, longest_path, shortest_path
97
 
98
- def draw_path_3d(G, path, filename='path_plot_3d.png', highlight_color='blue'):
99
- """Draws a specific path in 3D."""
 
100
  H = G.subgraph(path).copy()
101
  pos = nx.get_node_attributes(G, 'pos')
102
  x_vals, y_vals, z_vals = zip(*[pos[node] for node in path])
103
 
104
- fig = plt.figure(figsize=(16, 12))
105
- ax = fig.add_subplot(111, projection='3d')
106
-
107
  node_colors = ['red' if prob < 0.33 else 'blue' if prob < 0.67 else 'green' for _, prob, _ in [pos[node] for node in path]]
108
- ax.scatter(x_vals, y_vals, z_vals, c=node_colors, s=700, edgecolors='black', alpha=0.7)
 
 
109
 
 
110
  for edge in H.edges():
111
  x_start, y_start, z_start = pos[edge[0]]
112
  x_end, y_end, z_end = pos[edge[1]]
113
- ax.plot([x_start, x_end], [y_start, y_end], [z_start, z_end], color=highlight_color, lw=2)
114
-
115
- for node, (x, y, z) in pos.items():
116
- if node in path:
117
- ax.text(x, y, z, str(node), fontsize=12, color='black')
118
 
119
- ax.set_xlabel('Time (weeks)')
120
- ax.set_ylabel('Event Probability')
121
- ax.set_zlabel('Event Number')
122
- ax.set_title('3D Event Tree - Path')
123
 
124
- plt.savefig(filename, bbox_inches='tight')
125
- plt.close()
126
 
127
-
128
- def draw_global_tree_3d(G, filename='global_tree.png'):
129
- """Draws the entire graph in 3D."""
130
  pos = nx.get_node_attributes(G, 'pos')
131
  labels = nx.get_node_attributes(G, 'label')
132
 
@@ -135,35 +131,32 @@ def draw_global_tree_3d(G, filename='global_tree.png'):
135
  return
136
 
137
  x_vals, y_vals, z_vals = zip(*pos.values())
138
- fig = plt.figure(figsize=(16, 12))
139
- ax = fig.add_subplot(111, projection='3d')
140
 
141
  node_colors = ['red' if prob < 0.33 else 'blue' if prob < 0.67 else 'green' for _, prob, _ in pos.values()]
142
- ax.scatter(x_vals, y_vals, z_vals, c=node_colors, s=700, edgecolors='black', alpha=0.7)
 
 
143
 
 
144
  for edge in G.edges():
145
  x_start, y_start, z_start = pos[edge[0]]
146
  x_end, y_end, z_end = pos[edge[1]]
147
- ax.plot([x_start, x_end], [y_start, y_end], [z_start, z_end], color='gray', lw=2)
148
-
149
- for node, (x, y, z) in pos.items():
150
- label = labels.get(node, f"{node}")
151
- ax.text(x, y, z, label, fontsize=12, color='black')
152
-
153
- ax.set_xlabel('Time')
154
- ax.set_ylabel('Probability')
155
- ax.set_zlabel('Event Number')
156
- ax.set_title('3D Event Tree')
157
 
158
- plt.savefig(filename, bbox_inches='tight')
159
- plt.close()
 
 
160
 
161
 
162
  def main(json_data):
163
  G = nx.DiGraph()
164
  build_graph_from_json(json_data, G) # Build graph from the provided JSON data
165
 
166
- draw_global_tree_3d(G, filename='global_tree.png')
 
167
 
168
  best_path, best_mean_prob, worst_path, worst_mean_prob, longest_path, shortest_path = find_paths(G)
169
 
@@ -181,12 +174,12 @@ def main(json_data):
181
  print(f"Duration: {max(G.nodes[node]['pos'][0] for node in shortest_path) - min(G.nodes[node]['pos'][0] for node in shortest_path):.2f}")
182
 
183
  if best_path:
184
- draw_path_3d(G, best_path, 'best_path.png', 'blue')
185
  if worst_path:
186
- draw_path_3d(G, worst_path, 'worst_path.png', 'red')
187
  if longest_path:
188
- draw_path_3d(G, longest_path, 'longest_duration_path.png', 'green')
189
  if shortest_path:
190
- draw_path_3d(G, shortest_path, 'shortest_duration_path.png', 'purple')
191
 
192
- return 'global_tree.png' # Return the filename of the global tree
 
1
+ import plotly.graph_objects as go # Import Plotly for interactive plots
2
  from mpl_toolkits.mplot3d import Axes3D
3
  import networkx as nx
4
  import numpy as np
 
95
 
96
  return best_path, best_mean_prob, worst_path, worst_mean_prob, longest_path, shortest_path
97
 
98
+
99
+ def draw_path_3d_interactive(G, path, highlight_color='blue'):
100
+ """Draws a specific path in 3D using Plotly for interactivity."""
101
  H = G.subgraph(path).copy()
102
  pos = nx.get_node_attributes(G, 'pos')
103
  x_vals, y_vals, z_vals = zip(*[pos[node] for node in path])
104
 
 
 
 
105
  node_colors = ['red' if prob < 0.33 else 'blue' if prob < 0.67 else 'green' for _, prob, _ in [pos[node] for node in path]]
106
+ node_trace = go.Scatter3d(x=x_vals, y=y_vals, z=z_vals, mode='markers+text',
107
+ marker=dict(size=10, color=node_colors, line=dict(width=1, color='black')),
108
+ text=list(map(str, path)), textposition='top center', hoverinfo='text')
109
 
110
+ edge_traces = []
111
  for edge in H.edges():
112
  x_start, y_start, z_start = pos[edge[0]]
113
  x_end, y_end, z_end = pos[edge[1]]
114
+ edge_trace = go.Scatter3d(x=[x_start, x_end], y=[y_start, y_end], z=[z_start, z_end],
115
+ mode='lines', line=dict(width=2, color=highlight_color), hoverinfo='none')
116
+ edge_traces.append(edge_trace)
 
 
117
 
118
+ layout = go.Layout(scene=dict(xaxis_title='Time (weeks)', yaxis_title='Event Probability', zaxis_title='Event Number'),
119
+ title='3D Event Tree - Path')
120
+ fig = go.Figure(data=[node_trace] + edge_traces, layout=layout)
121
+ fig.show()
122
 
 
 
123
 
124
+ def draw_global_tree_3d_interactive(G):
125
+ """Draws the entire graph in 3D using Plotly for interactivity."""
 
126
  pos = nx.get_node_attributes(G, 'pos')
127
  labels = nx.get_node_attributes(G, 'label')
128
 
 
131
  return
132
 
133
  x_vals, y_vals, z_vals = zip(*pos.values())
 
 
134
 
135
  node_colors = ['red' if prob < 0.33 else 'blue' if prob < 0.67 else 'green' for _, prob, _ in pos.values()]
136
+ node_trace = go.Scatter3d(x=x_vals, y=y_vals, z=z_vals, mode='markers+text',
137
+ marker=dict(size=10, color=node_colors, line=dict(width=1, color='black')),
138
+ text=list(labels.values()), textposition='top center', hoverinfo='text')
139
 
140
+ edge_traces = []
141
  for edge in G.edges():
142
  x_start, y_start, z_start = pos[edge[0]]
143
  x_end, y_end, z_end = pos[edge[1]]
144
+ edge_trace = go.Scatter3d(x=[x_start, x_end], y=[y_start, y_end], z=[z_start, z_end],
145
+ mode='lines', line=dict(width=2, color='gray'), hoverinfo='none')
146
+ edge_traces.append(edge_trace)
 
 
 
 
 
 
 
147
 
148
+ layout = go.Layout(scene=dict(xaxis_title='Time', yaxis_title='Probability', zaxis_title='Event Number'),
149
+ title='3D Event Tree')
150
+ fig = go.Figure(data=[node_trace] + edge_traces, layout=layout)
151
+ fig.show()
152
 
153
 
154
  def main(json_data):
155
  G = nx.DiGraph()
156
  build_graph_from_json(json_data, G) # Build graph from the provided JSON data
157
 
158
+ # Draw the interactive graph using Plotly
159
+ draw_global_tree_3d_interactive(G)
160
 
161
  best_path, best_mean_prob, worst_path, worst_mean_prob, longest_path, shortest_path = find_paths(G)
162
 
 
174
  print(f"Duration: {max(G.nodes[node]['pos'][0] for node in shortest_path) - min(G.nodes[node]['pos'][0] for node in shortest_path):.2f}")
175
 
176
  if best_path:
177
+ draw_path_3d_interactive(G, best_path, 'blue')
178
  if worst_path:
179
+ draw_path_3d_interactive(G, worst_path, 'red')
180
  if longest_path:
181
+ draw_path_3d_interactive(G, longest_path, 'green')
182
  if shortest_path:
183
+ draw_path_3d_interactive(G, shortest_path, 'purple')
184
 
185
+ return None # No need to return a filename for interactive plot