# src/knowledge_graph.py
import networkx as nx
from pyvis.network import Network
import json
from typing import Dict, List, Any, Optional, Set, Tuple
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from collections import defaultdict


class KnowledgeGraph:
    """
    Handles the construction and visualization of knowledge graphs
    based on the ontology data.
    """

    def __init__(self, ontology_manager=None):
        """
        Initialize the knowledge graph handler.

        Args:
            ontology_manager: Optional ontology manager instance. When given
                (and truthy), its ``graph`` attribute becomes ``self.graph``;
                otherwise ``self.graph`` is ``None``.
        """
        self.ontology_manager = ontology_manager
        self.graph = ontology_manager.graph if ontology_manager else None

    def _collect_nodes_within_distance(self, center, max_distance) -> Set:
        """
        Return ``center`` plus every node at most ``max_distance`` hops away
        in ``self.graph``, following edges in either direction.

        Breadth-first: each node is expanded at most once, and expansion stops
        early as soon as a layer discovers no new nodes.
        """
        visited = {center}
        frontier = {center}
        hops = 0
        while hops < max_distance and frontier:
            discovered = set()
            for node in frontier:
                # Direction is ignored: both out- and in-neighbours count.
                discovered.update(self.graph.successors(node))
                discovered.update(self.graph.predecessors(node))
            # Only genuinely new nodes are expanded on the next hop.
            frontier = discovered - visited
            visited |= frontier
            hops += 1
        return visited

    def build_visualization_graph(
        self,
        include_classes: bool = True,
        include_instances: bool = True,
        central_entity: Optional[str] = None,
        max_distance: int = 2,
        include_properties: bool = False
    ) -> nx.Graph:
        """
        Build a simplified graph for visualization purposes.

        Nodes are copied from ``self.graph`` (or, when ``central_entity`` is
        present in it, from the subgraph of nodes within ``max_distance`` hops
        of it, ignoring edge direction). Each node gets display attributes:
        ``id``, ``label`` (the instance's ``properties["name"]`` when
        available, else the node ID), ``title`` (tooltip), ``group`` and
        ``shape``. The central entity is additionally coloured and enlarged.

        Args:
            include_classes: Whether to include class nodes (and, when False,
                also drop ``subClassOf`` / ``instanceOf`` edges)
            include_instances: Whether to include instance nodes
            central_entity: Optional central entity to focus the graph on
            max_distance: Maximum distance from central entity to include
            include_properties: Whether to add one synthetic node per
                instance property, linked by a dashed ``has_property`` edge

        Returns:
            An undirected NetworkX graph suitable for visualization. An empty
            graph is returned when there is no source graph.
        """
        # `not` covers both "no graph" (None) and an empty NetworkX graph
        # (NetworkX graphs define __len__ as the node count).
        if not self.graph:
            return nx.Graph()

        # NOTE(review): the result is undirected, so the source/target
        # orientation of ontology edges is not retained even though
        # generate_html_visualization renders with directed=True.
        viz_graph = nx.Graph()

        if central_entity and central_entity in self.graph:
            subgraph = self.graph.subgraph(
                self._collect_nodes_within_distance(central_entity, max_distance)
            )
        else:
            subgraph = self.graph

        # IDs of the property nodes synthesized below; used to keep ontology
        # edges from being attached to them.
        property_node_ids = set()

        for node, data in subgraph.nodes(data=True):
            node_type = data.get("type")

            # Skip nodes based on configuration
            if node_type == "class" and not include_classes:
                continue
            if node_type == "instance" and not include_instances:
                continue

            # Prefer a human-readable name for instances
            if node_type == "instance" and "properties" in data:
                label = data["properties"].get("name", node)
            else:
                label = node

            viz_attrs = {
                "id": node,
                "label": label,
                "title": self._get_node_tooltip(node, data),
                "group": data.get("class_type", node_type),
                "shape": "dot" if node_type == "instance" else "diamond"
            }

            # Highlight central entity if specified
            if central_entity and node == central_entity:
                viz_attrs["color"] = "#ff7f0e"  # Orange for central entity
                viz_attrs["size"] = 25  # Larger size for central entity

            viz_graph.add_node(node, **viz_attrs)

            if include_properties and node_type == "instance" and "properties" in data:
                for prop_name, prop_value in data["properties"].items():
                    prop_node_id = f"{node}_{prop_name}"
                    property_node_ids.add(prop_node_id)

                    # Keep labels short; the full value stays in the tooltip.
                    prop_value_str = str(prop_value)
                    if len(prop_value_str) > 20:
                        prop_value_str = prop_value_str[:17] + "..."

                    viz_graph.add_node(
                        prop_node_id,
                        id=prop_node_id,
                        label=f"{prop_name}: {prop_value_str}",
                        title=f"{prop_name}: {prop_value}",
                        group="property",
                        shape="ellipse",
                        size=5
                    )
                    viz_graph.add_edge(node, prop_node_id, label="has_property", dashes=True)

        for source, target, data in subgraph.edges(data=True):
            # Only include edges between nodes that survived the filtering
            if source not in viz_graph or target not in viz_graph:
                continue

            # Property nodes are wired up manually above. Matching on the exact
            # synthesized IDs (instead of an ID-prefix heuristic) keeps real
            # edges such as "Person" -> "Person_1" and works for non-str IDs.
            if source in property_node_ids or target in property_node_ids:
                continue

            edge_type = data.get("type", "unknown")

            # Don't show subClassOf and instanceOf relationships if not explicitly requested
            if edge_type in ("subClassOf", "instanceOf") and not include_classes:
                continue

            viz_graph.add_edge(source, target, label=edge_type, title=edge_type)

        return viz_graph

    def _get_node_tooltip(self, node_id: str, data: Dict) -> str:
        """
        Generate a tooltip for a node.

        The tooltip is a sequence of lines, each terminated by ``<br>``:

        * always ``ID: <node_id>``;
        * ``Type: <type>`` when ``data["type"]`` is set;
        * for instances: ``Class: <class_type>`` (default ``unknown``) and, if
          present, ``Properties:`` followed by one ``- key: value`` line per
          entry of ``data["properties"]``;
        * for classes: ``Description: <description>`` (default empty) and, if
          present, ``Properties: a, b, ...``.

        Args:
            node_id: Identifier shown on the first line
            data: The node's attribute dict

        Returns:
            The tooltip string.
        """
        lines = [f"ID: {node_id}"]

        node_type = data.get("type")
        if node_type:
            lines.append(f"Type: {node_type}")

        if node_type == "instance":
            lines.append(f"Class: {data.get('class_type', 'unknown')}")
            if "properties" in data:
                lines.append("Properties:")
                lines.extend(f"- {key}: {value}" for key, value in data["properties"].items())
        elif node_type == "class":
            lines.append(f"Description: {data.get('description', '')}")
            if "properties" in data:
                # str() so non-string property entries don't make join() raise.
                lines.append("Properties: " + ", ".join(str(p) for p in data["properties"]))

        # NOTE(review): values are interpolated without HTML escaping; if the
        # ontology data can be untrusted, escape before rendering as HTML.
        return "".join(f"{line}<br>" for line in lines)

    def generate_html_visualization(
        self,
        include_classes: bool = True,
        include_instances: bool = True,
        central_entity: Optional[str] = None,
        max_distance: int = 2,
        include_properties: bool = False,
        height: str = "600px",
        width: str = "100%",
        bgcolor: str = "#ffffff",
        font_color: str = "#000000",
        layout_algorithm: str = "force-directed"
    ) -> str:
        """
        Generate an HTML visualization of the knowledge graph.

        Args:
            include_classes: Whether to include class nodes
            include_instances: Whether to include instance nodes
            central_entity: Optional central entity to focus the graph on
            max_distance: Maximum distance from central entity to include
            include_properties: Whether to include property nodes
            height: Height of the visualization
            width: Width of the visualization
            bgcolor: Background color
            font_color: Font color
            layout_algorithm: Algorithm for layout ('force-directed', 'hierarchical', 'radial', 'circular')

        Returns:
            HTML string containing the visualization
        """
        # Build the visualization graph
        viz_graph = self.build_visualization_graph(
            include_classes=include_classes,
            include_instances=include_instances,
            central_entity=central_entity,
            max_distance=max_distance,
            include_properties=include_properties
        )

        # Create a PyVis network
        net = Network(height=height, width=width, bgcolor=bgcolor, font_color=font_color, directed=True)

        # Configure physics based on the selected layout algorithm
        if layout_algorithm == "force-directed":
            physics_options = {
                "enabled": True,
                "solver": "forceAtlas2Based",
                "forceAtlas2Based": {
                    "gravitationalConstant": -50,
                    "centralGravity": 0.01,
                    "springLength": 100,
                    "springConstant": 0.08
                },
                "stabilization": {
                    "enabled": True,
                    "iterations": 100
                }
            }
        elif layout_algorithm == "hierarchical":
            physics_options = {
                "enabled": True,
                "hierarchicalRepulsion": {
                    "centralGravity": 0.0,
                    "springLength": 100,
                    "springConstant": 0.01,
                    "nodeDistance": 120
                },
                "solver": "hierarchicalRepulsion",
                "stabilization": {
                    "enabled": True,
                    "iterations": 100
                }
            }
            # Set hierarchical
layout net.set_options(""" var options = { "layout": { "hierarchical": { "enabled": true, "direction": "UD", "sortMethod": "directed", "nodeSpacing": 150, "treeSpacing": 200 } } } """) elif layout_algorithm == "radial": physics_options = { "enabled": True, "solver": "repulsion", "repulsion": { "nodeDistance": 120, "centralGravity": 0.2, "springLength": 200, "springConstant": 0.05 }, "stabilization": { "enabled": True, "iterations": 100 } } elif layout_algorithm == "circular": physics_options = { "enabled": False } # Compute circular layout and set fixed positions pos = nx.circular_layout(viz_graph) for node_id, coords in pos.items(): if node_id in viz_graph.nodes: x, y = coords viz_graph.nodes[node_id]['x'] = float(x) * 500 viz_graph.nodes[node_id]['y'] = float(y) * 500 viz_graph.nodes[node_id]['physics'] = False # Configure other options options = { "nodes": { "font": {"size": 12}, "scaling": {"min": 10, "max": 30} }, "edges": { "color": {"inherit": True}, "smooth": {"enabled": True, "type": "dynamic"}, "arrows": {"to": {"enabled": True, "scaleFactor": 0.5}}, "font": {"size": 10, "align": "middle"} }, "physics": physics_options, "interaction": { "hover": True, "navigationButtons": True, "keyboard": True, "tooltipDelay": 100 } } # Set options and create the network net.options = options net.from_nx(viz_graph) # Add custom CSS for better visualization custom_css = """ """ # Generate the HTML and add custom CSS html = net.generate_html() html = html.replace("