# src/knowledge_graph.py
import networkx as nx
from pyvis.network import Network
import json
from typing import Dict, List, Any, Optional, Set, Tuple
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from collections import defaultdict
class KnowledgeGraph:
"""
Handles the construction and visualization of knowledge graphs
based on the ontology data.
"""
def __init__(self, ontology_manager=None):
"""
Initialize the knowledge graph handler.
Args:
ontology_manager: Optional ontology manager instance
"""
self.ontology_manager = ontology_manager
self.graph = None
if ontology_manager:
self.graph = ontology_manager.graph
def build_visualization_graph(
    self,
    include_classes: bool = True,
    include_instances: bool = True,
    central_entity: Optional[str] = None,
    max_distance: int = 2,
    include_properties: bool = False
) -> nx.Graph:
    """
    Build a simplified, undirected graph for visualization purposes.

    Args:
        include_classes: Whether to include class nodes. When False,
            "subClassOf" and "instanceOf" edges are also omitted.
        include_instances: Whether to include instance nodes.
        central_entity: Optional central entity to focus the graph on. If it
            is not a node of ``self.graph``, the whole graph is used.
        max_distance: Maximum number of hops from the central entity to
            include. Edges are followed in both directions.
        include_properties: Whether to add one synthetic node per instance
            property, linked to its instance by a "has_property" edge.

    Returns:
        A NetworkX graph whose node/edge attributes (label, title, group,
        shape, ...) are ready for PyVis. An empty graph is returned when
        ``self.graph`` is unset or empty.
    """
    if not self.graph:
        return nx.Graph()

    # NOTE(review): the result is undirected, so the source graph's edge
    # direction is not preserved. Opposite edges between the same pair also
    # collapse into one (the last one wins). Keep this in mind when
    # rendering arrows.
    viz_graph = nx.Graph()

    if central_entity and central_entity in self.graph:
        # Breadth-first expansion around the central entity, ignoring edge
        # direction. Only newly discovered nodes form the next frontier, so
        # each node is expanded at most once.
        nodes_to_include = {central_entity}
        frontier = {central_entity}
        for _ in range(max_distance):
            discovered = set()
            for node in frontier:
                discovered.update(self.graph.successors(node))
                discovered.update(self.graph.predecessors(node))
            frontier = discovered - nodes_to_include
            if not frontier:
                break  # neighborhood exhausted before max_distance
            nodes_to_include |= frontier
        subgraph = self.graph.subgraph(nodes_to_include)
    else:
        subgraph = self.graph

    # Instance/property pairs we link ourselves. A source-graph edge between
    # the same pair must not overwrite the synthetic "has_property" edge.
    property_edges = set()

    for node, data in subgraph.nodes(data=True):
        node_type = data.get("type")

        # Skip nodes based on configuration
        if node_type == "class" and not include_classes:
            continue
        if node_type == "instance" and not include_instances:
            continue

        # Instances are labelled by their "name" property when present.
        if node_type == "instance" and "properties" in data:
            label = data["properties"].get("name", node)
        else:
            label = node

        viz_attrs = {
            "id": node,
            "label": label,
            "title": self._get_node_tooltip(node, data),
            "group": data.get("class_type", node_type),
            "shape": "dot" if node_type == "instance" else "diamond"
        }
        if central_entity and node == central_entity:
            viz_attrs["color"] = "#ff7f0e"  # Orange for central entity
            viz_attrs["size"] = 25  # Larger size for central entity
        viz_graph.add_node(node, **viz_attrs)

        if include_properties and node_type == "instance" and "properties" in data:
            for prop_name, prop_value in data["properties"].items():
                prop_node_id = f"{node}_{prop_name}"
                prop_value_str = str(prop_value)
                # Truncate long values in the label; the tooltip keeps the full value.
                if len(prop_value_str) > 20:
                    prop_value_str = prop_value_str[:17] + "..."
                viz_graph.add_node(
                    prop_node_id,
                    id=prop_node_id,
                    label=f"{prop_name}: {prop_value_str}",
                    title=f"{prop_name}: {prop_value}",
                    group="property",
                    shape="ellipse",
                    size=5
                )
                viz_graph.add_edge(node, prop_node_id, label="has_property", dashes=True)
                property_edges.add(frozenset((node, prop_node_id)))

    for source, target, edge_data in subgraph.edges(data=True):
        # Only include edges between nodes that made it into viz_graph.
        if source not in viz_graph or target not in viz_graph:
            continue
        # Keep our synthetic has_property edge. This only triggers when a
        # real node id collides with a generated property node id. Unrelated
        # edges whose ids merely share a prefix are no longer dropped.
        if frozenset((source, target)) in property_edges:
            continue
        edge_type = edge_data.get("type", "unknown")
        # Hide hierarchy edges when classes are not requested.
        if edge_type in ("subClassOf", "instanceOf") and not include_classes:
            continue
        viz_graph.add_edge(source, target, label=edge_type, title=edge_type)

    return viz_graph
def _get_node_tooltip(self, node_id: str, data: Dict) -> str:
"""Generate a tooltip for a node."""
tooltip = f"ID: {node_id}
"
node_type = data.get("type")
if node_type:
tooltip += f"Type: {node_type}
"
if node_type == "instance":
tooltip += f"Class: {data.get('class_type', 'unknown')}
"
# Add properties
if "properties" in data:
tooltip += "Properties:
"
for key, value in data["properties"].items():
tooltip += f"- {key}: {value}
"
elif node_type == "class":
tooltip += f"Description: {data.get('description', '')}
"
# Add properties if available
if "properties" in data:
tooltip += "Properties: " + ", ".join(data["properties"]) + "
"
return tooltip
def generate_html_visualization(
self,
include_classes: bool = True,
include_instances: bool = True,
central_entity: Optional[str] = None,
max_distance: int = 2,
include_properties: bool = False,
height: str = "600px",
width: str = "100%",
bgcolor: str = "#ffffff",
font_color: str = "#000000",
layout_algorithm: str = "force-directed"
) -> str:
"""
Generate an HTML visualization of the knowledge graph.
Args:
include_classes: Whether to include class nodes
include_instances: Whether to include instance nodes
central_entity: Optional central entity to focus the graph on
max_distance: Maximum distance from central entity to include
include_properties: Whether to include property nodes
height: Height of the visualization
width: Width of the visualization
bgcolor: Background color
font_color: Font color
layout_algorithm: Algorithm for layout ('force-directed', 'hierarchical', 'radial', 'circular')
Returns:
HTML string containing the visualization
"""
# Build the visualization graph
viz_graph = self.build_visualization_graph(
include_classes=include_classes,
include_instances=include_instances,
central_entity=central_entity,
max_distance=max_distance,
include_properties=include_properties
)
# Create a PyVis network
net = Network(height=height, width=width, bgcolor=bgcolor, font_color=font_color, directed=True)
# Configure physics based on the selected layout algorithm
if layout_algorithm == "force-directed":
physics_options = {
"enabled": True,
"solver": "forceAtlas2Based",
"forceAtlas2Based": {
"gravitationalConstant": -50,
"centralGravity": 0.01,
"springLength": 100,
"springConstant": 0.08
},
"stabilization": {
"enabled": True,
"iterations": 100
}
}
elif layout_algorithm == "hierarchical":
physics_options = {
"enabled": True,
"hierarchicalRepulsion": {
"centralGravity": 0.0,
"springLength": 100,
"springConstant": 0.01,
"nodeDistance": 120
},
"solver": "hierarchicalRepulsion",
"stabilization": {
"enabled": True,
"iterations": 100
}
}
# Set hierarchical layout
net.set_options("""
var options = {
"layout": {
"hierarchical": {
"enabled": true,
"direction": "UD",
"sortMethod": "directed",
"nodeSpacing": 150,
"treeSpacing": 200
}
}
}
""")
elif layout_algorithm == "radial":
physics_options = {
"enabled": True,
"solver": "repulsion",
"repulsion": {
"nodeDistance": 120,
"centralGravity": 0.2,
"springLength": 200,
"springConstant": 0.05
},
"stabilization": {
"enabled": True,
"iterations": 100
}
}
elif layout_algorithm == "circular":
physics_options = {
"enabled": False
}
# Compute circular layout and set fixed positions
pos = nx.circular_layout(viz_graph)
for node_id, coords in pos.items():
if node_id in viz_graph.nodes:
x, y = coords
viz_graph.nodes[node_id]['x'] = float(x) * 500
viz_graph.nodes[node_id]['y'] = float(y) * 500
viz_graph.nodes[node_id]['physics'] = False
# Configure other options
options = {
"nodes": {
"font": {"size": 12},
"scaling": {"min": 10, "max": 30}
},
"edges": {
"color": {"inherit": True},
"smooth": {"enabled": True, "type": "dynamic"},
"arrows": {"to": {"enabled": True, "scaleFactor": 0.5}},
"font": {"size": 10, "align": "middle"}
},
"physics": physics_options,
"interaction": {
"hover": True,
"navigationButtons": True,
"keyboard": True,
"tooltipDelay": 100
}
}
# Set options and create the network
net.options = options
net.from_nx(viz_graph)
# Add custom CSS for better visualization
custom_css = """
"""
# Generate the HTML and add custom CSS
html = net.generate_html()
html = html.replace("