neuralworm commited on
Commit
42a21e8
1 Parent(s): 82f89ae

plotly instead of matplotlib

Browse files
Files changed (2) hide show
  1. app.py +3 -12
  2. psychohistory.py +19 -60
app.py CHANGED
@@ -9,25 +9,16 @@ with gr.Blocks(title="PSYCHOHISTORY") as app:
9
  btn_search = gr.Button("Look", scale=1)
10
  with gr.Row():
11
  mem_results = gr.JSON(label="Results")
 
12
  btn_search.click(
13
  gen.generate,
14
  inputs=[txt_search],
15
  outputs=mem_results
16
  )
17
-
18
- # with gr.Row():
19
- # img_output = gr.Image(label="Graph Visualization", type="filepath") # Add an Image component
20
-
21
- # # Trigger graph generation after JSON is generated
22
- # mem_results.change(
23
- # psychohistory.main,
24
- # inputs=[mem_results],
25
- # outputs=img_output
26
- # )
27
- mem_results.change(
28
  psychohistory.main,
29
  inputs=[mem_results],
30
- outputs=None
31
  )
32
 
33
  if __name__ == "__main__":
 
9
  btn_search = gr.Button("Look", scale=1)
10
  with gr.Row():
11
  mem_results = gr.JSON(label="Results")
12
+ html_output = gr.HTML(label="Graph Visualization") # Use HTML component
13
  btn_search.click(
14
  gen.generate,
15
  inputs=[txt_search],
16
  outputs=mem_results
17
  )
18
+ mem_results.change(
 
 
 
 
 
 
 
 
 
 
19
  psychohistory.main,
20
  inputs=[mem_results],
21
+ outputs=html_output # Output to the HTML component
22
  )
23
 
24
  if __name__ == "__main__":
psychohistory.py CHANGED
@@ -1,5 +1,5 @@
1
  import plotly.graph_objects as go # Import Plotly for interactive plots
2
- from mpl_toolkits.mplot3d import Axes3D
3
  import networkx as nx
4
  import numpy as np
5
  import json
@@ -32,8 +32,6 @@ def generate_tree(current_x, current_y, depth, max_depth, max_nodes, x_range, G,
32
  return node_count_per_depth
33
 
34
 
35
-
36
-
37
  def build_graph_from_json(json_data, G):
38
  """Builds a graph from JSON data, handling subevents recursively."""
39
 
@@ -59,7 +57,6 @@ def build_graph_from_json(json_data, G):
59
  add_event(None, event_data, 0) # Add each event as a root node
60
 
61
 
62
-
63
  def find_paths(G):
64
  """Finds paths with highest/lowest probability and longest/shortest durations."""
65
  best_path, worst_path = None, None
@@ -95,40 +92,14 @@ def find_paths(G):
95
 
96
  return best_path, best_mean_prob, worst_path, worst_mean_prob, longest_path, shortest_path
97
 
98
-
99
- def draw_path_3d_interactive(G, path, highlight_color='blue'):
100
- """Draws a specific path in 3D using Plotly for interactivity."""
101
- H = G.subgraph(path).copy()
102
- pos = nx.get_node_attributes(G, 'pos')
103
- x_vals, y_vals, z_vals = zip(*[pos[node] for node in path])
104
-
105
- node_colors = ['red' if prob < 0.33 else 'blue' if prob < 0.67 else 'green' for _, prob, _ in [pos[node] for node in path]]
106
- node_trace = go.Scatter3d(x=x_vals, y=y_vals, z=z_vals, mode='markers+text',
107
- marker=dict(size=10, color=node_colors, line=dict(width=1, color='black')),
108
- text=list(map(str, path)), textposition='top center', hoverinfo='text')
109
-
110
- edge_traces = []
111
- for edge in H.edges():
112
- x_start, y_start, z_start = pos[edge[0]]
113
- x_end, y_end, z_end = pos[edge[1]]
114
- edge_trace = go.Scatter3d(x=[x_start, x_end], y=[y_start, y_end], z=[z_start, z_end],
115
- mode='lines', line=dict(width=2, color=highlight_color), hoverinfo='none')
116
- edge_traces.append(edge_trace)
117
-
118
- layout = go.Layout(scene=dict(xaxis_title='Time (weeks)', yaxis_title='Event Probability', zaxis_title='Event Number'),
119
- title='3D Event Tree - Path')
120
- fig = go.Figure(data=[node_trace] + edge_traces, layout=layout)
121
- fig.show()
122
-
123
-
124
- def draw_global_tree_3d_interactive(G):
125
- """Draws the entire graph in 3D using Plotly for interactivity."""
126
  pos = nx.get_node_attributes(G, 'pos')
127
  labels = nx.get_node_attributes(G, 'label')
128
 
129
  if not pos:
130
  print("Graph is empty. No nodes to visualize.")
131
- return
132
 
133
  x_vals, y_vals, z_vals = zip(*pos.values())
134
 
@@ -142,44 +113,32 @@ def draw_global_tree_3d_interactive(G):
142
  x_start, y_start, z_start = pos[edge[0]]
143
  x_end, y_end, z_end = pos[edge[1]]
144
  edge_trace = go.Scatter3d(x=[x_start, x_end], y=[y_start, y_end], z=[z_start, z_end],
145
- mode='lines', line=dict(width=2, color='gray'), hoverinfo='none')
146
  edge_traces.append(edge_trace)
147
 
148
  layout = go.Layout(scene=dict(xaxis_title='Time', yaxis_title='Probability', zaxis_title='Event Number'),
149
- title='3D Event Tree')
150
  fig = go.Figure(data=[node_trace] + edge_traces, layout=layout)
151
- fig.show()
152
 
 
 
 
153
 
154
  def main(json_data):
155
  G = nx.DiGraph()
156
- build_graph_from_json(json_data, G) # Build graph from the provided JSON data
157
-
158
- # Draw the interactive graph using Plotly
159
- draw_global_tree_3d_interactive(G)
160
 
161
- best_path, best_mean_prob, worst_path, worst_mean_prob, longest_path, shortest_path = find_paths(G)
 
162
 
163
- if best_path:
164
- print(f"\nPath with the highest average probability: {' -> '.join(map(str, best_path))}")
165
- print(f"Average probability: {best_mean_prob:.2f}")
166
- if worst_path:
167
- print(f"\nPath with the lowest average probability: {' -> '.join(map(str, worst_path))}")
168
- print(f"Average probability: {worst_mean_prob:.2f}")
169
- if longest_path:
170
- print(f"\nPath with the longest duration: {' -> '.join(map(str, longest_path))}")
171
- print(f"Duration: {max(G.nodes[node]['pos'][0] for node in longest_path) - min(G.nodes[node]['pos'][0] for node in longest_path):.2f}")
172
- if shortest_path:
173
- print(f"\nPath with the shortest duration: {' -> '.join(map(str, shortest_path))}")
174
- print(f"Duration: {max(G.nodes[node]['pos'][0] for node in shortest_path) - min(G.nodes[node]['pos'][0] for node in shortest_path):.2f}")
175
 
176
  if best_path:
177
- draw_path_3d_interactive(G, best_path, 'blue')
 
178
  if worst_path:
179
- draw_path_3d_interactive(G, worst_path, 'red')
180
- if longest_path:
181
- draw_path_3d_interactive(G, longest_path, 'green')
182
- if shortest_path:
183
- draw_path_3d_interactive(G, shortest_path, 'purple')
184
 
185
- return None # No need to return a filename for interactive plot
 
1
  import plotly.graph_objects as go # Import Plotly for interactive plots
2
+ from mpl_toolkits.mplot3d import Axes3D # Not needed anymore, but you can keep it if you use it elsewhere
3
  import networkx as nx
4
  import numpy as np
5
  import json
 
32
  return node_count_per_depth
33
 
34
 
 
 
35
  def build_graph_from_json(json_data, G):
36
  """Builds a graph from JSON data, handling subevents recursively."""
37
 
 
57
  add_event(None, event_data, 0) # Add each event as a root node
58
 
59
 
 
60
  def find_paths(G):
61
  """Finds paths with highest/lowest probability and longest/shortest durations."""
62
  best_path, worst_path = None, None
 
92
 
93
  return best_path, best_mean_prob, worst_path, worst_mean_prob, longest_path, shortest_path
94
 
95
+ def draw_graph_plotly(G, title="3D Event Tree", highlight_color='gray'):
96
+ """Draws the graph in 3D using Plotly and returns the HTML string."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  pos = nx.get_node_attributes(G, 'pos')
98
  labels = nx.get_node_attributes(G, 'label')
99
 
100
  if not pos:
101
  print("Graph is empty. No nodes to visualize.")
102
+ return ""
103
 
104
  x_vals, y_vals, z_vals = zip(*pos.values())
105
 
 
113
  x_start, y_start, z_start = pos[edge[0]]
114
  x_end, y_end, z_end = pos[edge[1]]
115
  edge_trace = go.Scatter3d(x=[x_start, x_end], y=[y_start, y_end], z=[z_start, z_end],
116
+ mode='lines', line=dict(width=2, color=highlight_color), hoverinfo='none')
117
  edge_traces.append(edge_trace)
118
 
119
  layout = go.Layout(scene=dict(xaxis_title='Time', yaxis_title='Probability', zaxis_title='Event Number'),
120
+ title=title)
121
  fig = go.Figure(data=[node_trace] + edge_traces, layout=layout)
 
122
 
123
+ # Convert Plotly figure to HTML string
124
+ html_str = fig.to_html(full_html=False, include_plotlyjs='cdn')
125
+ return html_str
126
 
127
  def main(json_data):
128
  G = nx.DiGraph()
129
+ build_graph_from_json(json_data, G)
 
 
 
130
 
131
+ # Generate the HTML string for the Plotly graph
132
+ html_graph = draw_graph_plotly(G)
133
 
134
+ # ... (Rest of your code for finding paths)
 
 
 
 
 
 
 
 
 
 
 
135
 
136
  if best_path:
137
+ best_path_graph = draw_graph_plotly(G.subgraph(best_path), title="Best Path", highlight_color='blue')
138
+ html_graph += best_path_graph
139
  if worst_path:
140
+ worst_path_graph = draw_graph_plotly(G.subgraph(worst_path), title="Worst Path", highlight_color='red')
141
+ html_graph += worst_path_graph
142
+ # ... (Similar for longest_path and shortest_path)
 
 
143
 
144
+ return html_graph # Return the HTML string