"""Build a knowledge graph with NetworkX and render it with Bokeh or Plotly."""

import networkx as nx
import numpy as np
import plotly.graph_objects as go
from bokeh.models import (
    BoxSelectTool,
    ColumnDataSource,
    HoverTool,
    LabelSet,
    MultiLine,
    NodesAndLinkedEdges,
    Plot,
    Range1d,
    Scatter,
    TapTool,
)
from bokeh.palettes import Spectral4
from bokeh.plotting import from_networkx


def create_graph(entities, relationships):
    """Build an undirected graph from entities and their relationships.

    Args:
        entities: Mapping of entity id to a dict that may carry ``"value"``
            and ``"type"`` keys; either falls back to ``"Unknown"``.
        relationships: Iterable of ``(source, relation, target)`` triples.

    Returns:
        An ``nx.Graph`` in which every node has a ``"label"`` attribute
        formatted as ``"<value> (<type>)"`` and every edge has a ``"label"``
        attribute holding the relation.
    """
    G = nx.Graph()
    for entity_id, entity_data in entities.items():
        value = entity_data.get("value", "Unknown")
        kind = entity_data.get("type", "Unknown")
        G.add_node(entity_id, label=f"{value} ({kind})")

    for source, relation, target in relationships:
        # add_edge() silently creates missing endpoints *without* a "label"
        # attribute, which the plotting functions read for every node. Give
        # undeclared endpoints a fallback label based on their id.
        for endpoint in (source, target):
            if endpoint not in G:
                G.add_node(endpoint, label=str(endpoint))
        G.add_edge(source, target, label=relation)
    return G


def improved_spectral_layout(G, scale=1, noise=0.1, seed=None):
    """Return a spectral layout with Gaussian jitter applied to each node.

    The jitter is added to keep nodes that the spectral embedding places at
    (nearly) the same coordinates from overlapping.

    Args:
        G: Graph to lay out.
        scale: Factor applied to both coordinates after the jitter is added.
        noise: Standard deviation of the zero-mean Gaussian jitter.
        seed: Optional seed for a dedicated random generator, making the
            jitter reproducible. When ``None`` the global NumPy random state
            is used.

    Returns:
        Dict mapping each node to an ``(x, y)`` tuple.
    """
    rng = np.random if seed is None else np.random.default_rng(seed)
    base = nx.spectral_layout(G)
    return {
        node: (
            (x + rng.normal(0, noise)) * scale,
            (y + rng.normal(0, noise)) * scale,
        )
        for node, (x, y) in base.items()
    }


# Layout name -> callable taking the graph and returning {node: (x, y)}.
_LAYOUTS = {
    "spring": lambda G: nx.spring_layout(G, k=0.5, iterations=50),
    "fruchterman_reingold": lambda G: nx.fruchterman_reingold_layout(
        G, k=0.5, iterations=50
    ),
    "circular": nx.circular_layout,
    "random": nx.random_layout,
    "spectral": improved_spectral_layout,
    "shell": nx.shell_layout,
}


def _compute_layout(G, layout_type):
    """Return ``{node: (x, y)}`` for *layout_type*; unknown names use spring."""
    layout = _LAYOUTS.get(layout_type, _LAYOUTS["spring"])
    # Normalise to plain float tuples so both plotting back ends receive the
    # same coordinate type regardless of which layout produced them.
    return {node: (float(x), float(y)) for node, (x, y) in layout(G).items()}


def create_bokeh_plot(G, layout_type='spring'):
    """Render *G* as an interactive Bokeh plot with node and edge labels.

    Args:
        G: Graph whose nodes and edges carry a ``"label"`` attribute. Nodes
            without one are labelled with ``str(node)``; edges without one
            get an empty label.
        layout_type: One of ``"spring"``, ``"fruchterman_reingold"``,
            ``"circular"``, ``"random"``, ``"spectral"`` or ``"shell"``.
            Any other value falls back to the spring layout.

    Returns:
        A 600x600 ``bokeh.models.Plot`` containing the graph renderer and two
        ``LabelSet`` renderers (node labels and edge labels).
    """
    plot = Plot(
        width=600,
        height=600,
        x_range=Range1d(-1.2, 1.2),
        y_range=Range1d(-1.2, 1.2),
    )
    plot.title.text = "Knowledge Graph Interaction"

    # NOTE(review): with NodesAndLinkedEdges as the inspection policy the
    # second hover tool presumably resolves "@label" against node data as
    # well, so "Relation" may show a node label - confirm in the browser.
    node_hover = HoverTool(tooltips=[("Entity", "@label")])
    edge_hover = HoverTool(tooltips=[("Relation", "@label")])
    plot.add_tools(node_hover, edge_hover, TapTool(), BoxSelectTool())

    pos = _compute_layout(G, layout_type)

    # A precomputed position dict is used as-is by from_networkx; extra
    # keyword arguments are only forwarded to callable layout functions.
    graph_renderer = from_networkx(G, pos)

    node_renderer = graph_renderer.node_renderer
    node_renderer.glyph = Scatter(size=15, fill_color=Spectral4[0])
    node_renderer.selection_glyph = Scatter(size=15, fill_color=Spectral4[2])
    node_renderer.hover_glyph = Scatter(size=15, fill_color=Spectral4[1])

    edge_renderer = graph_renderer.edge_renderer
    edge_renderer.glyph = MultiLine(
        line_color="#000", line_alpha=0.9, line_width=3
    )
    edge_renderer.selection_glyph = MultiLine(
        line_color=Spectral4[2], line_width=4
    )
    edge_renderer.hover_glyph = MultiLine(
        line_color=Spectral4[1], line_width=3
    )

    graph_renderer.selection_policy = NodesAndLinkedEdges()
    graph_renderer.inspection_policy = NodesAndLinkedEdges()
    plot.renderers.append(graph_renderer)

    # Node labels: look every node up explicitly so coordinates and text stay
    # aligned, and so an empty graph yields empty columns instead of failing.
    nodes = list(G.nodes())
    node_label_source = ColumnDataSource({
        'x': [pos[node][0] for node in nodes],
        'y': [pos[node][1] for node in nodes],
        'label': [G.nodes[node].get('label', str(node)) for node in nodes],
    })
    node_labels = LabelSet(
        x='x',
        y='y',
        text='label',
        source=node_label_source,
        background_fill_color='white',
        text_font_size='8pt',
        background_fill_alpha=0.7,
    )
    plot.renderers.append(node_labels)

    # Edge labels are placed at the midpoint of each edge.
    edge_x, edge_y, edge_text = [], [], []
    for start_node, end_node, label in G.edges(data='label'):
        start_x, start_y = pos[start_node]
        end_x, end_y = pos[end_node]
        edge_x.append((start_x + end_x) / 2)
        edge_y.append((start_y + end_y) / 2)
        edge_text.append('' if label is None else label)

    edge_label_source = ColumnDataSource(
        {'x': edge_x, 'y': edge_y, 'label': edge_text}
    )
    edge_labels = LabelSet(
        x='x',
        y='y',
        text='label',
        source=edge_label_source,
        background_fill_color='white',
        text_font_size='8pt',
        background_fill_alpha=0.7,
    )
    plot.renderers.append(edge_labels)

    return plot


def create_plotly_plot(G, layout_type='spring'):
    """Render *G* as a Plotly figure with node and edge labels.

    Args:
        G: Graph whose nodes and edges carry a ``"label"`` attribute. Nodes
            without one are labelled with ``str(node)``; edges without one
            get an empty label.
        layout_type: One of ``"spring"``, ``"fruchterman_reingold"``,
            ``"circular"``, ``"random"``, ``"spectral"`` or ``"shell"``.
            Any other value falls back to the spring layout.

    Returns:
        An 800x600 ``go.Figure`` with three traces, in order: edge lines,
        node markers (coloured by neighbour count) and edge-label text.
    """
    pos = _compute_layout(G, layout_type)

    # Collect coordinates in plain lists and build each trace once, rather
    # than re-assigning ever-growing tuples on the trace objects.
    edge_x, edge_y = [], []
    label_x, label_y, label_text = [], [], []
    for start_node, end_node, label in G.edges(data="label"):
        x0, y0 = pos[start_node]
        x1, y1 = pos[end_node]
        # None separates consecutive segments inside a single line trace.
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]
        label_x.append((x0 + x1) / 2)
        label_y.append((y0 + y1) / 2)
        label_text.append("" if label is None else label)

    nodes = list(G.nodes())
    node_x = [pos[node][0] for node in nodes]
    node_y = [pos[node][1] for node in nodes]
    node_text = [G.nodes[node].get("label", str(node)) for node in nodes]
    neighbor_counts = [len(G.adj[node]) for node in nodes]

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        mode="lines",
        line=dict(width=1, color="#888"),
        hoverinfo="none",  # the line trace carries no hover text
    )
    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        mode="markers+text",
        hoverinfo="text",
        text=node_text,
        textposition="top center",
        marker=dict(
            showscale=True,
            colorscale="Viridis",
            reversescale=True,
            color=neighbor_counts,
            size=15,
            # Nested title dict: the flat "titleside" attribute is
            # deprecated and rejected by Plotly 6.
            colorbar=dict(
                thickness=15,
                title=dict(text="Node Connections", side="right"),
                xanchor="left",
            ),
            line=dict(width=2),
        ),
    )
    # One text trace for all edge labels instead of one trace per edge.
    edge_label_trace = go.Scatter(
        x=label_x,
        y=label_y,
        mode="text",
        text=label_text,
        textposition="middle center",
        hoverinfo="none",
        showlegend=False,
        textfont=dict(size=8),
    )

    hidden_axis = dict(showgrid=False, zeroline=False, showticklabels=False)
    layout = go.Layout(
        # Nested title dict: "titlefont" is deprecated and rejected by
        # Plotly 6.
        title=dict(text="Knowledge Graph", font=dict(size=16)),
        showlegend=False,
        hovermode="closest",
        margin=dict(b=20, l=5, r=5, t=40),
        annotations=[],
        xaxis=dict(hidden_axis),
        # A single constraint locks the aspect ratio to 1:1; anchoring both
        # axes to each other forms a loop in which one constraint is ignored.
        yaxis=dict(hidden_axis, scaleanchor="x", scaleratio=1),
        newshape=dict(line=dict(color="#009900")),
        width=800,
        height=600,
    )
    return go.Figure(
        data=[edge_trace, node_trace, edge_label_trace], layout=layout
    )