Spaces:
Sleeping
Sleeping
File size: 6,001 Bytes
81dcb15 42a21e8 a4844a1 6840d6b a4844a1 6840d6b a4844a1 6840d6b a4844a1 b3feaa3 a4844a1 6840d6b a4844a1 b3feaa3 a4844a1 6840d6b a4844a1 33d7cfe b3feaa3 6840d6b b3feaa3 33d7cfe b3feaa3 6840d6b 0e3c4d0 a4844a1 6840d6b b3feaa3 6840d6b 0e3c4d0 6840d6b 02b756e 6840d6b a4844a1 b3feaa3 a4844a1 42a21e8 a4844a1 6840d6b b3feaa3 6840d6b 42a21e8 6840d6b a4844a1 b3feaa3 81dcb15 b3feaa3 81dcb15 a4844a1 81dcb15 42a21e8 81dcb15 a4844a1 81dcb15 42a21e8 81dcb15 a4844a1 42a21e8 bca33fd 711c0ff a4844a1 42a21e8 6840d6b 42a21e8 a4844a1 42a21e8 a4844a1 42a21e8 a4844a1 42a21e8 6840d6b 42a21e8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import plotly.graph_objects as go # Import Plotly for interactive plots
from mpl_toolkits.mplot3d import Axes3D # Not needed anymore, but you can keep it if you use it elsewhere
import networkx as nx
import numpy as np
import json
import sys
import random
def generate_tree(current_x, current_y, depth, max_depth, max_nodes, x_range, G, parent=None,
                  node_count_per_depth=None, rng=None):
    """Recursively grow a random tree in ``G`` and count the nodes added per depth.

    Each call adds between 1 and ``max_nodes`` children at level ``depth`` and
    recurses into every child until ``depth`` exceeds ``max_depth``. Every node
    gets a ``pos`` attribute ``(x, probability, depth)``:
      - ``x`` is spread over ``x_range`` starting at ``current_x``;
      - ``probability`` is drawn uniformly from [0, 1];
      - ``depth`` is the tree level.

    Args:
        current_x: x coordinate of the parent; the first child is placed here.
        current_y: incremented on each recursion but not used for positions
            (kept for interface compatibility).
        depth: level of the children created by this call.
        max_depth: deepest level that still receives nodes.
        max_nodes: inclusive upper bound on children per call; must be >= 1.
        x_range: horizontal span over which siblings are spread.
        G: graph receiving the nodes; needs ``nodes``, ``add_node`` and
            ``add_edge`` (e.g. a networkx graph). New node ids are ``len(G.nodes)``.
        parent: id of the parent node, or None for top-level nodes.
        node_count_per_depth: accumulator mapping depth -> node count;
            created when None.
        rng: optional randomness source providing ``randint`` and ``uniform``
            (e.g. ``random.Random(seed)``) for reproducible trees. Defaults to
            the global ``random`` module.

    Returns:
        The ``node_count_per_depth`` dict (the same object passed in, if any).

    Raises:
        ValueError: if ``max_nodes`` < 1 and ``depth`` <= ``max_depth``.
    """
    if node_count_per_depth is None:
        node_count_per_depth = {}
    if depth > max_depth:
        return node_count_per_depth
    if max_nodes < 1:
        # randint(1, max_nodes) would fail with an opaque "empty range" error.
        raise ValueError(f"max_nodes must be >= 1, got {max_nodes}")
    if rng is None:
        rng = random
    node_count_per_depth.setdefault(depth, 0)

    num_children = rng.randint(1, max_nodes)
    # i starts at 0, so the first child shares the parent's x coordinate.
    x_positions = [current_x + i * x_range / (num_children + 1) for i in range(num_children)]
    for x in x_positions:
        node_id = len(G.nodes)
        node_count_per_depth[depth] += 1
        prob = rng.uniform(0, 1)
        G.add_node(node_id, pos=(x, prob, depth))
        if parent is not None:
            G.add_edge(parent, node_id)
        generate_tree(x, current_y + 1, depth + 1, max_depth, max_nodes, x_range, G,
                      parent=node_id, node_count_per_depth=node_count_per_depth, rng=rng)
    return node_count_per_depth
def build_graph_from_json(json_data, G):
    """Add every event, and its nested subevents, from ``json_data`` to ``G``.

    Input shape read by this function::

        {"events": {<key>: event, ...}}
        event = {"name": ..., "probability": <0-100>, "event_number": ...,
                 "subevents": {"event": event | [event, ...]}}

    Each event becomes a node with ``label=name`` and
    ``pos=(depth, probability / 100, event_number)``. Depth is 0 for top-level
    events and grows by one per nesting level. Each subevent gets an edge from
    its parent.

    Missing or null ``events`` / ``subevents`` / ``event`` entries are treated
    as empty. ``probability`` may be a number or a numeric string, as produced
    by XML-to-dict converters.

    Args:
        json_data: parsed event data (dict).
        G: graph populated in place; needs ``nodes``, ``add_node`` and
            ``add_edge``. Node ids are assigned as ``len(G.nodes)``.
    """
    def add_event(parent_id, event_data, depth):
        node_id = len(G.nodes)
        # float() accepts both numbers and numeric strings.
        prob = float(event_data['probability']) / 100.0
        # Use event_number as the z-coordinate for better visualization
        pos = (depth, prob, event_data['event_number'])
        G.add_node(node_id, pos=pos, label=event_data['name'])
        if parent_id is not None:
            G.add_edge(parent_id, node_id)  # Connect to parent
        # `or` guards against keys that are present but null (empty XML elements).
        subevents = (event_data.get('subevents') or {}).get('event') or []
        if not isinstance(subevents, list):
            # A single subevent is given as a bare object rather than a list.
            subevents = [subevents]
        for subevent in subevents:
            add_event(node_id, subevent, depth + 1)

    # Iterate through all top-level events, each added as a root node.
    for event_data in (json_data.get('events') or {}).values():
        add_event(None, event_data, 0)
def find_paths(G):
    """Pick extreme paths among the shortest paths between distinct nodes of ``G``.

    Only paths in which every node has a ``pos`` attribute ``(x, probability, z)``
    are considered. For each such path, the function computes:
      - the mean node probability (``pos[1]``);
      - the spread of the x coordinates (``max - min`` of ``pos[0]``), used as
        a "duration".

    Returns:
        A 6-tuple ``(best_path, best_mean_prob, worst_path, worst_mean_prob,
        longest_path, shortest_path)`` where:
          - best/worst path has the highest/lowest mean probability (the first
            one found wins ties);
          - longest path has the largest x spread;
          - shortest path has the smallest *positive* x spread (zero-spread
            paths are ignored).
        Paths are lists of node ids. When no path qualifies, the paths are
        ``None`` and the means keep their sentinels ``-1`` and ``float('inf')``.
    """
    best_path, worst_path = None, None
    longest_path, shortest_path = None, None
    best_mean_prob, worst_mean_prob = -1, float('inf')
    max_duration, min_duration = -1, float('inf')

    node_attrs = G.nodes
    # Stream (source, {target: path}) pairs from the generator instead of
    # materialising the whole all-pairs dict (O(V^2) paths held at once).
    for source, paths_from_source in nx.all_pairs_shortest_path(G):
        for target, path in paths_from_source.items():
            if source == target or not all('pos' in node_attrs[node] for node in path):
                continue
            positions = [node_attrs[node]['pos'] for node in path]

            mean_prob = np.mean([p[1] for p in positions])
            if mean_prob > best_mean_prob:
                best_mean_prob = mean_prob
                best_path = path
            if mean_prob < worst_mean_prob:
                worst_mean_prob = mean_prob
                worst_path = path

            x_positions = [p[0] for p in positions]
            duration = max(x_positions) - min(x_positions)
            if duration > max_duration:
                max_duration = duration
                longest_path = path
            if 0 < duration < min_duration:  # Avoid paths with 0 duration
                min_duration = duration
                shortest_path = path

    return best_path, best_mean_prob, worst_path, worst_mean_prob, longest_path, shortest_path
def draw_graph_plotly(G, title="3D Event Tree", highlight_color='gray'):
    """Render ``G`` as an interactive 3D Plotly figure and return it as an HTML fragment.

    Nodes with a ``pos`` attribute ``(x, probability, z)`` are plotted at that
    position and coloured by probability: red below 0.33, blue below 0.67,
    green otherwise. A node's ``label`` attribute is shown as text, or an
    empty string when the label is missing. Edges whose endpoints both have a
    position are drawn as lines in ``highlight_color``.

    Args:
        G: networkx graph whose nodes carry ``pos`` and optionally ``label``.
        title: figure title.
        highlight_color: line colour used for the edges.

    Returns:
        An HTML fragment (``full_html=False``) that loads plotly.js from the
        CDN, or ``""`` if no node has a position.
    """
    pos = nx.get_node_attributes(G, 'pos')
    labels = nx.get_node_attributes(G, 'label')
    if not pos:
        print("Graph is empty. No nodes to visualize.")
        return ""

    x_vals, y_vals, z_vals = zip(*pos.values())
    node_colors = ['red' if prob < 0.33 else 'blue' if prob < 0.67 else 'green' for prob in y_vals]
    # Look labels up per positioned node so the text stays aligned with the
    # markers even when some nodes have no label.
    node_text = [labels.get(node, '') for node in pos]
    node_trace = go.Scatter3d(x=x_vals, y=y_vals, z=z_vals, mode='markers+text',
                              marker=dict(size=10, color=node_colors, line=dict(width=1, color='black')),
                              text=node_text, textposition='top center', hoverinfo='text')

    # Put all edges in one trace: a None entry breaks the line between
    # segments. This is much cheaper than one trace per edge.
    edge_x, edge_y, edge_z = [], [], []
    for start, end in G.edges():
        if start not in pos or end not in pos:
            continue  # an edge cannot be drawn without both endpoint positions
        (x0, y0, z0), (x1, y1, z1) = pos[start], pos[end]
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]
        edge_z += [z0, z1, None]

    traces = [node_trace]
    if edge_x:
        traces.append(go.Scatter3d(x=edge_x, y=edge_y, z=edge_z, mode='lines',
                                   line=dict(width=2, color=highlight_color), hoverinfo='none'))

    layout = go.Layout(scene=dict(xaxis_title='Time', yaxis_title='Probability', zaxis_title='Event Number'),
                       title=title)
    fig = go.Figure(data=traces, layout=layout)
    # Convert Plotly figure to an embeddable HTML fragment
    return fig.to_html(full_html=False, include_plotlyjs='cdn')
def main(json_data):
    """Build the event graph from ``json_data`` and return Plotly HTML for it.

    The output starts with the full event tree. It then adds one sub-plot each
    for the best, worst, longest and shortest paths reported by ``find_paths``;
    any sub-plot whose path is missing is skipped.

    Args:
        json_data: parsed event data in the shape ``build_graph_from_json`` reads.

    Returns:
        Concatenated HTML fragments, or ``""`` if the graph has no nodes.
    """
    G = nx.DiGraph()
    build_graph_from_json(json_data, G)
    html_graph = draw_graph_plotly(G)

    # The highlighted paths must be computed before they can be drawn.
    best_path, _, worst_path, _, longest_path, shortest_path = find_paths(G)
    highlights = (
        (best_path, "Best Path", 'blue'),
        (worst_path, "Worst Path", 'red'),
        (longest_path, "Longest Path", 'purple'),
        (shortest_path, "Shortest Path", 'orange'),
    )
    for path, path_title, color in highlights:
        if path:
            html_graph += draw_graph_plotly(G.subgraph(path), title=path_title, highlight_color=color)
    return html_graph  # Return the HTML string
|