Spaces:
Sleeping
Sleeping
neuralworm
commited on
Commit
•
42a21e8
1
Parent(s):
82f89ae
plotly instead of matplotlib
Browse files- app.py +3 -12
- psychohistory.py +19 -60
app.py
CHANGED
@@ -9,25 +9,16 @@ with gr.Blocks(title="PSYCHOHISTORY") as app:
|
|
9 |
btn_search = gr.Button("Look", scale=1)
|
10 |
with gr.Row():
|
11 |
mem_results = gr.JSON(label="Results")
|
|
|
12 |
btn_search.click(
|
13 |
gen.generate,
|
14 |
inputs=[txt_search],
|
15 |
outputs=mem_results
|
16 |
)
|
17 |
-
|
18 |
-
# with gr.Row():
|
19 |
-
# img_output = gr.Image(label="Graph Visualization", type="filepath") # Add an Image component
|
20 |
-
|
21 |
-
# # Trigger graph generation after JSON is generated
|
22 |
-
# mem_results.change(
|
23 |
-
# psychohistory.main,
|
24 |
-
# inputs=[mem_results],
|
25 |
-
# outputs=img_output
|
26 |
-
# )
|
27 |
-
mem_results.change(
|
28 |
psychohistory.main,
|
29 |
inputs=[mem_results],
|
30 |
-
outputs=
|
31 |
)
|
32 |
|
33 |
if __name__ == "__main__":
|
|
|
9 |
btn_search = gr.Button("Look", scale=1)
|
10 |
with gr.Row():
|
11 |
mem_results = gr.JSON(label="Results")
|
12 |
+
html_output = gr.HTML(label="Graph Visualization") # Use HTML component
|
13 |
btn_search.click(
|
14 |
gen.generate,
|
15 |
inputs=[txt_search],
|
16 |
outputs=mem_results
|
17 |
)
|
18 |
+
mem_results.change(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
psychohistory.main,
|
20 |
inputs=[mem_results],
|
21 |
+
outputs=html_output # Output to the HTML component
|
22 |
)
|
23 |
|
24 |
if __name__ == "__main__":
|
psychohistory.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import plotly.graph_objects as go # Import Plotly for interactive plots
|
2 |
-
from mpl_toolkits.mplot3d import Axes3D
|
3 |
import networkx as nx
|
4 |
import numpy as np
|
5 |
import json
|
@@ -32,8 +32,6 @@ def generate_tree(current_x, current_y, depth, max_depth, max_nodes, x_range, G,
|
|
32 |
return node_count_per_depth
|
33 |
|
34 |
|
35 |
-
|
36 |
-
|
37 |
def build_graph_from_json(json_data, G):
|
38 |
"""Builds a graph from JSON data, handling subevents recursively."""
|
39 |
|
@@ -59,7 +57,6 @@ def build_graph_from_json(json_data, G):
|
|
59 |
add_event(None, event_data, 0) # Add each event as a root node
|
60 |
|
61 |
|
62 |
-
|
63 |
def find_paths(G):
|
64 |
"""Finds paths with highest/lowest probability and longest/shortest durations."""
|
65 |
best_path, worst_path = None, None
|
@@ -95,40 +92,14 @@ def find_paths(G):
|
|
95 |
|
96 |
return best_path, best_mean_prob, worst_path, worst_mean_prob, longest_path, shortest_path
|
97 |
|
98 |
-
|
99 |
-
|
100 |
-
"""Draws a specific path in 3D using Plotly for interactivity."""
|
101 |
-
H = G.subgraph(path).copy()
|
102 |
-
pos = nx.get_node_attributes(G, 'pos')
|
103 |
-
x_vals, y_vals, z_vals = zip(*[pos[node] for node in path])
|
104 |
-
|
105 |
-
node_colors = ['red' if prob < 0.33 else 'blue' if prob < 0.67 else 'green' for _, prob, _ in [pos[node] for node in path]]
|
106 |
-
node_trace = go.Scatter3d(x=x_vals, y=y_vals, z=z_vals, mode='markers+text',
|
107 |
-
marker=dict(size=10, color=node_colors, line=dict(width=1, color='black')),
|
108 |
-
text=list(map(str, path)), textposition='top center', hoverinfo='text')
|
109 |
-
|
110 |
-
edge_traces = []
|
111 |
-
for edge in H.edges():
|
112 |
-
x_start, y_start, z_start = pos[edge[0]]
|
113 |
-
x_end, y_end, z_end = pos[edge[1]]
|
114 |
-
edge_trace = go.Scatter3d(x=[x_start, x_end], y=[y_start, y_end], z=[z_start, z_end],
|
115 |
-
mode='lines', line=dict(width=2, color=highlight_color), hoverinfo='none')
|
116 |
-
edge_traces.append(edge_trace)
|
117 |
-
|
118 |
-
layout = go.Layout(scene=dict(xaxis_title='Time (weeks)', yaxis_title='Event Probability', zaxis_title='Event Number'),
|
119 |
-
title='3D Event Tree - Path')
|
120 |
-
fig = go.Figure(data=[node_trace] + edge_traces, layout=layout)
|
121 |
-
fig.show()
|
122 |
-
|
123 |
-
|
124 |
-
def draw_global_tree_3d_interactive(G):
|
125 |
-
"""Draws the entire graph in 3D using Plotly for interactivity."""
|
126 |
pos = nx.get_node_attributes(G, 'pos')
|
127 |
labels = nx.get_node_attributes(G, 'label')
|
128 |
|
129 |
if not pos:
|
130 |
print("Graph is empty. No nodes to visualize.")
|
131 |
-
return
|
132 |
|
133 |
x_vals, y_vals, z_vals = zip(*pos.values())
|
134 |
|
@@ -142,44 +113,32 @@ def draw_global_tree_3d_interactive(G):
|
|
142 |
x_start, y_start, z_start = pos[edge[0]]
|
143 |
x_end, y_end, z_end = pos[edge[1]]
|
144 |
edge_trace = go.Scatter3d(x=[x_start, x_end], y=[y_start, y_end], z=[z_start, z_end],
|
145 |
-
mode='lines', line=dict(width=2, color=
|
146 |
edge_traces.append(edge_trace)
|
147 |
|
148 |
layout = go.Layout(scene=dict(xaxis_title='Time', yaxis_title='Probability', zaxis_title='Event Number'),
|
149 |
-
title=
|
150 |
fig = go.Figure(data=[node_trace] + edge_traces, layout=layout)
|
151 |
-
fig.show()
|
152 |
|
|
|
|
|
|
|
153 |
|
154 |
def main(json_data):
|
155 |
G = nx.DiGraph()
|
156 |
-
build_graph_from_json(json_data, G)
|
157 |
-
|
158 |
-
# Draw the interactive graph using Plotly
|
159 |
-
draw_global_tree_3d_interactive(G)
|
160 |
|
161 |
-
|
|
|
162 |
|
163 |
-
|
164 |
-
print(f"\nPath with the highest average probability: {' -> '.join(map(str, best_path))}")
|
165 |
-
print(f"Average probability: {best_mean_prob:.2f}")
|
166 |
-
if worst_path:
|
167 |
-
print(f"\nPath with the lowest average probability: {' -> '.join(map(str, worst_path))}")
|
168 |
-
print(f"Average probability: {worst_mean_prob:.2f}")
|
169 |
-
if longest_path:
|
170 |
-
print(f"\nPath with the longest duration: {' -> '.join(map(str, longest_path))}")
|
171 |
-
print(f"Duration: {max(G.nodes[node]['pos'][0] for node in longest_path) - min(G.nodes[node]['pos'][0] for node in longest_path):.2f}")
|
172 |
-
if shortest_path:
|
173 |
-
print(f"\nPath with the shortest duration: {' -> '.join(map(str, shortest_path))}")
|
174 |
-
print(f"Duration: {max(G.nodes[node]['pos'][0] for node in shortest_path) - min(G.nodes[node]['pos'][0] for node in shortest_path):.2f}")
|
175 |
|
176 |
if best_path:
|
177 |
-
|
|
|
178 |
if worst_path:
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
if shortest_path:
|
183 |
-
draw_path_3d_interactive(G, shortest_path, 'purple')
|
184 |
|
185 |
-
return
|
|
|
1 |
import plotly.graph_objects as go # Import Plotly for interactive plots
|
2 |
+
from mpl_toolkits.mplot3d import Axes3D # Not needed anymore, but you can keep it if you use it elsewhere
|
3 |
import networkx as nx
|
4 |
import numpy as np
|
5 |
import json
|
|
|
32 |
return node_count_per_depth
|
33 |
|
34 |
|
|
|
|
|
35 |
def build_graph_from_json(json_data, G):
|
36 |
"""Builds a graph from JSON data, handling subevents recursively."""
|
37 |
|
|
|
57 |
add_event(None, event_data, 0) # Add each event as a root node
|
58 |
|
59 |
|
|
|
60 |
def find_paths(G):
|
61 |
"""Finds paths with highest/lowest probability and longest/shortest durations."""
|
62 |
best_path, worst_path = None, None
|
|
|
92 |
|
93 |
return best_path, best_mean_prob, worst_path, worst_mean_prob, longest_path, shortest_path
|
94 |
|
95 |
+
def draw_graph_plotly(G, title="3D Event Tree", highlight_color='gray'):
|
96 |
+
"""Draws the graph in 3D using Plotly and returns the HTML string."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
pos = nx.get_node_attributes(G, 'pos')
|
98 |
labels = nx.get_node_attributes(G, 'label')
|
99 |
|
100 |
if not pos:
|
101 |
print("Graph is empty. No nodes to visualize.")
|
102 |
+
return ""
|
103 |
|
104 |
x_vals, y_vals, z_vals = zip(*pos.values())
|
105 |
|
|
|
113 |
x_start, y_start, z_start = pos[edge[0]]
|
114 |
x_end, y_end, z_end = pos[edge[1]]
|
115 |
edge_trace = go.Scatter3d(x=[x_start, x_end], y=[y_start, y_end], z=[z_start, z_end],
|
116 |
+
mode='lines', line=dict(width=2, color=highlight_color), hoverinfo='none')
|
117 |
edge_traces.append(edge_trace)
|
118 |
|
119 |
layout = go.Layout(scene=dict(xaxis_title='Time', yaxis_title='Probability', zaxis_title='Event Number'),
|
120 |
+
title=title)
|
121 |
fig = go.Figure(data=[node_trace] + edge_traces, layout=layout)
|
|
|
122 |
|
123 |
+
# Convert Plotly figure to HTML string
|
124 |
+
html_str = fig.to_html(full_html=False, include_plotlyjs='cdn')
|
125 |
+
return html_str
|
126 |
|
127 |
def main(json_data):
|
128 |
G = nx.DiGraph()
|
129 |
+
build_graph_from_json(json_data, G)
|
|
|
|
|
|
|
130 |
|
131 |
+
# Generate the HTML string for the Plotly graph
|
132 |
+
html_graph = draw_graph_plotly(G)
|
133 |
|
134 |
+
# ... (Rest of your code for finding paths)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
|
136 |
if best_path:
|
137 |
+
best_path_graph = draw_graph_plotly(G.subgraph(best_path), title="Best Path", highlight_color='blue')
|
138 |
+
html_graph += best_path_graph
|
139 |
if worst_path:
|
140 |
+
worst_path_graph = draw_graph_plotly(G.subgraph(worst_path), title="Worst Path", highlight_color='red')
|
141 |
+
html_graph += worst_path_graph
|
142 |
+
# ... (Similar for longest_path and shortest_path)
|
|
|
|
|
143 |
|
144 |
+
return html_graph # Return the HTML string
|