|  | import plotly.graph_objects as go | 
					
						
						|  | import plotly.express as px | 
					
						
						|  | import plotly.figure_factory as ff | 
					
						
						|  | from plotly.subplots import make_subplots | 
					
						
						|  | import networkx as nx | 
					
						
						|  | import torch | 
					
						
						|  | import numpy as np | 
					
						
						|  | import pandas as pd | 
					
						
						|  | import logging | 
					
						
						|  |  | 
					
						
						|  | logger = logging.getLogger(__name__) | 
					
						
						|  |  | 
					
						
						|  | class GraphVisualizer: | 
					
						
						|  | """Advanced graph visualization utilities""" | 
					
						
						|  |  | 
					
						
						|  | @staticmethod | 
					
						
						|  | def create_graph_plot(data, max_nodes=500, layout_algorithm='spring', node_size_factor=1.0): | 
					
						
						|  | """Create interactive graph visualization""" | 
					
						
						|  | try: | 
					
						
						|  | if not hasattr(data, 'edge_index') or not hasattr(data, 'num_nodes'): | 
					
						
						|  | raise ValueError("Data must have edge_index and num_nodes attributes") | 
					
						
						|  |  | 
					
						
						|  | num_nodes = min(data.num_nodes, max_nodes) | 
					
						
						|  | if num_nodes <= 0: | 
					
						
						|  | raise ValueError("No nodes to visualize") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | G = nx.Graph() | 
					
						
						|  |  | 
					
						
						|  | if data.edge_index.size(1) > 0: | 
					
						
						|  | edge_list = data.edge_index.t().cpu().numpy() | 
					
						
						|  | edge_list = edge_list[ | 
					
						
						|  | (edge_list[:, 0] < num_nodes) & (edge_list[:, 1] < num_nodes) | 
					
						
						|  | ] | 
					
						
						|  |  | 
					
						
						|  | if len(edge_list) > 0: | 
					
						
						|  | G.add_edges_from(edge_list) | 
					
						
						|  |  | 
					
						
						|  | G.add_nodes_from(range(num_nodes)) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | pos = nx.spring_layout(G, seed=42) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if hasattr(data, 'y') and data.y is not None: | 
					
						
						|  | node_colors = data.y.cpu().numpy()[:num_nodes] | 
					
						
						|  | else: | 
					
						
						|  | node_colors = [0] * num_nodes | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | edge_x, edge_y = [], [] | 
					
						
						|  | for edge in G.edges(): | 
					
						
						|  | if edge[0] in pos and edge[1] in pos: | 
					
						
						|  | x0, y0 = pos[edge[0]] | 
					
						
						|  | x1, y1 = pos[edge[1]] | 
					
						
						|  | edge_x.extend([x0, x1, None]) | 
					
						
						|  | edge_y.extend([y0, y1, None]) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | node_x = [pos[node][0] for node in G.nodes()] | 
					
						
						|  | node_y = [pos[node][1] for node in G.nodes()] | 
					
						
						|  |  | 
					
						
						|  | fig = go.Figure() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if edge_x: | 
					
						
						|  | fig.add_trace(go.Scatter( | 
					
						
						|  | x=edge_x, y=edge_y, | 
					
						
						|  | line=dict(width=0.8, color='rgba(125,125,125,0.5)'), | 
					
						
						|  | hoverinfo='none', | 
					
						
						|  | mode='lines', | 
					
						
						|  | showlegend=False | 
					
						
						|  | )) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | fig.add_trace(go.Scatter( | 
					
						
						|  | x=node_x, y=node_y, | 
					
						
						|  | mode='markers', | 
					
						
						|  | marker=dict( | 
					
						
						|  | size=8, | 
					
						
						|  | color=node_colors, | 
					
						
						|  | colorscale='Viridis', | 
					
						
						|  | line=dict(width=2, color='white'), | 
					
						
						|  | opacity=0.8 | 
					
						
						|  | ), | 
					
						
						|  | text=[f"Node {i}" for i in range(len(node_x))], | 
					
						
						|  | hoverinfo='text', | 
					
						
						|  | showlegend=False | 
					
						
						|  | )) | 
					
						
						|  |  | 
					
						
						|  | fig.update_layout( | 
					
						
						|  | title=f'Graph Visualization ({num_nodes} nodes)', | 
					
						
						|  | showlegend=False, | 
					
						
						|  | hovermode='closest', | 
					
						
						|  | margin=dict(b=20, l=5, r=5, t=40), | 
					
						
						|  | xaxis=dict(showgrid=False, zeroline=False, showticklabels=False), | 
					
						
						|  | yaxis=dict(showgrid=False, zeroline=False, showticklabels=False), | 
					
						
						|  | plot_bgcolor='white', | 
					
						
						|  | width=800, | 
					
						
						|  | height=600 | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | return fig | 
					
						
						|  |  | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Graph visualization error: {e}") | 
					
						
						|  | return GraphVisualizer._create_error_figure(f"Visualization error: {str(e)}") | 
					
						
						|  |  | 
					
						
						|  | @staticmethod | 
					
						
						|  | def create_metrics_plot(metrics): | 
					
						
						|  | """Create comprehensive metrics visualization""" | 
					
						
						|  | try: | 
					
						
						|  | metric_names = [] | 
					
						
						|  | metric_values = [] | 
					
						
						|  |  | 
					
						
						|  | for key, value in metrics.items(): | 
					
						
						|  | if isinstance(value, (int, float)) and key not in ['error', 'loss']: | 
					
						
						|  | if not (np.isnan(value) or np.isinf(value)) and 0 <= value <= 1: | 
					
						
						|  | metric_names.append(key.replace('_', ' ').title()) | 
					
						
						|  | metric_values.append(value) | 
					
						
						|  |  | 
					
						
						|  | if not metric_names: | 
					
						
						|  | return GraphVisualizer._create_error_figure("No valid metrics to display") | 
					
						
						|  |  | 
					
						
						|  | fig = make_subplots( | 
					
						
						|  | rows=1, cols=2, | 
					
						
						|  | subplot_titles=('Performance Metrics', 'Metric Radar Chart'), | 
					
						
						|  | specs=[[{"type": "bar"}, {"type": "polar"}]] | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | colors = px.colors.qualitative.Set3[:len(metric_names)] | 
					
						
						|  |  | 
					
						
						|  | fig.add_trace( | 
					
						
						|  | go.Bar( | 
					
						
						|  | x=metric_names, | 
					
						
						|  | y=metric_values, | 
					
						
						|  | marker_color=colors, | 
					
						
						|  | text=[f'{v:.3f}' for v in metric_values], | 
					
						
						|  | textposition='auto', | 
					
						
						|  | showlegend=False | 
					
						
						|  | ), | 
					
						
						|  | row=1, col=1 | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | fig.add_trace( | 
					
						
						|  | go.Scatterpolar( | 
					
						
						|  | r=metric_values + [metric_values[0]], | 
					
						
						|  | theta=metric_names + [metric_names[0]], | 
					
						
						|  | fill='toself', | 
					
						
						|  | line=dict(color='blue'), | 
					
						
						|  | marker=dict(size=8), | 
					
						
						|  | showlegend=False | 
					
						
						|  | ), | 
					
						
						|  | row=1, col=2 | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | fig.update_layout( | 
					
						
						|  | title='Model Performance Dashboard', | 
					
						
						|  | height=400, | 
					
						
						|  | showlegend=False | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | fig.update_xaxes(title_text="Metrics", tickangle=45, row=1, col=1) | 
					
						
						|  | fig.update_yaxes(title_text="Score", range=[0, 1], row=1, col=1) | 
					
						
						|  |  | 
					
						
						|  | fig.update_polars( | 
					
						
						|  | radialaxis=dict(range=[0, 1], showticklabels=True), | 
					
						
						|  | row=1, col=2 | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | return fig | 
					
						
						|  |  | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Metrics plot error: {e}") | 
					
						
						|  | return GraphVisualizer._create_error_figure(f"Metrics plot error: {str(e)}") | 
					
						
						|  |  | 
					
						
						|  | @staticmethod | 
					
						
						|  | def create_training_history_plot(history): | 
					
						
						|  | """Create comprehensive training history visualization""" | 
					
						
						|  | try: | 
					
						
						|  | if not isinstance(history, dict) or not history: | 
					
						
						|  | return GraphVisualizer._create_error_figure("No training history available") | 
					
						
						|  |  | 
					
						
						|  | required_keys = ['train_loss', 'train_acc'] | 
					
						
						|  | for key in required_keys: | 
					
						
						|  | if key not in history or not history[key]: | 
					
						
						|  | return GraphVisualizer._create_error_figure(f"Missing {key} in training history") | 
					
						
						|  |  | 
					
						
						|  | epochs = list(range(len(history['train_loss']))) | 
					
						
						|  |  | 
					
						
						|  | fig = make_subplots( | 
					
						
						|  | rows=2, cols=2, | 
					
						
						|  | subplot_titles=('Loss Over Time', 'Accuracy Over Time', 'Learning Rate', 'Training Progress'), | 
					
						
						|  | specs=[[{"secondary_y": False}, {"secondary_y": False}], | 
					
						
						|  | [{"secondary_y": False}, {"secondary_y": False}]] | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | fig.add_trace( | 
					
						
						|  | go.Scatter( | 
					
						
						|  | x=epochs, y=history['train_loss'], | 
					
						
						|  | mode='lines', name='Train Loss', | 
					
						
						|  | line=dict(color='blue', width=2), | 
					
						
						|  | showlegend=False | 
					
						
						|  | ), | 
					
						
						|  | row=1, col=1 | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | if 'val_loss' in history and history['val_loss']: | 
					
						
						|  | fig.add_trace( | 
					
						
						|  | go.Scatter( | 
					
						
						|  | x=epochs, y=history['val_loss'], | 
					
						
						|  | mode='lines', name='Val Loss', | 
					
						
						|  | line=dict(color='red', width=2), | 
					
						
						|  | showlegend=False | 
					
						
						|  | ), | 
					
						
						|  | row=1, col=1 | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | fig.add_trace( | 
					
						
						|  | go.Scatter( | 
					
						
						|  | x=epochs, y=history['train_acc'], | 
					
						
						|  | mode='lines', name='Train Acc', | 
					
						
						|  | line=dict(color='green', width=2), | 
					
						
						|  | showlegend=False | 
					
						
						|  | ), | 
					
						
						|  | row=1, col=2 | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | if 'val_acc' in history and history['val_acc']: | 
					
						
						|  | fig.add_trace( | 
					
						
						|  | go.Scatter( | 
					
						
						|  | x=epochs, y=history['val_acc'], | 
					
						
						|  | mode='lines', name='Val Acc', | 
					
						
						|  | line=dict(color='orange', width=2), | 
					
						
						|  | showlegend=False | 
					
						
						|  | ), | 
					
						
						|  | row=1, col=2 | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if 'lr' in history and history['lr']: | 
					
						
						|  | fig.add_trace( | 
					
						
						|  | go.Scatter( | 
					
						
						|  | x=epochs, y=history['lr'], | 
					
						
						|  | mode='lines', name='Learning Rate', | 
					
						
						|  | line=dict(color='purple', width=2), | 
					
						
						|  | showlegend=False | 
					
						
						|  | ), | 
					
						
						|  | row=2, col=1 | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | final_metrics = { | 
					
						
						|  | 'Final Train Acc': history['train_acc'][-1] if history['train_acc'] else 0, | 
					
						
						|  | 'Final Train Loss': history['train_loss'][-1] if history['train_loss'] else 0, | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | if 'val_acc' in history and history['val_acc']: | 
					
						
						|  | final_metrics['Final Val Acc'] = history['val_acc'][-1] | 
					
						
						|  | final_metrics['Best Val Acc'] = max(history['val_acc']) | 
					
						
						|  |  | 
					
						
						|  | metric_names = list(final_metrics.keys()) | 
					
						
						|  | metric_values = list(final_metrics.values()) | 
					
						
						|  |  | 
					
						
						|  | fig.add_trace( | 
					
						
						|  | go.Bar( | 
					
						
						|  | x=metric_names, | 
					
						
						|  | y=metric_values, | 
					
						
						|  | marker_color=['lightblue', 'lightcoral', 'lightgreen', 'gold'], | 
					
						
						|  | text=[f'{v:.3f}' for v in metric_values], | 
					
						
						|  | textposition='auto', | 
					
						
						|  | showlegend=False | 
					
						
						|  | ), | 
					
						
						|  | row=2, col=2 | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | fig.update_layout( | 
					
						
						|  | title='Training History Dashboard', | 
					
						
						|  | height=600, | 
					
						
						|  | showlegend=True | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | fig.update_xaxes(title_text="Epoch", row=1, col=1) | 
					
						
						|  | fig.update_xaxes(title_text="Epoch", row=1, col=2) | 
					
						
						|  | fig.update_xaxes(title_text="Epoch", row=2, col=1) | 
					
						
						|  | fig.update_xaxes(title_text="Metric", tickangle=45, row=2, col=2) | 
					
						
						|  |  | 
					
						
						|  | fig.update_yaxes(title_text="Loss", row=1, col=1) | 
					
						
						|  | fig.update_yaxes(title_text="Accuracy", range=[0, 1], row=1, col=2) | 
					
						
						|  | fig.update_yaxes(title_text="Learning Rate", type="log", row=2, col=1) | 
					
						
						|  | fig.update_yaxes(title_text="Value", row=2, col=2) | 
					
						
						|  |  | 
					
						
						|  | return fig | 
					
						
						|  |  | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Training history plot error: {e}") | 
					
						
						|  | return GraphVisualizer._create_error_figure(f"Training history plot error: {str(e)}") | 
					
						
						|  |  | 
					
						
						|  | @staticmethod | 
					
						
						|  | def _create_error_figure(error_message): | 
					
						
						|  | """Create an error figure with message""" | 
					
						
						|  | fig = go.Figure() | 
					
						
						|  | fig.add_annotation( | 
					
						
						|  | text=error_message, | 
					
						
						|  | x=0.5, y=0.5, | 
					
						
						|  | xref="paper", yref="paper", | 
					
						
						|  | showarrow=False, | 
					
						
						|  | font=dict(size=14, color="red"), | 
					
						
						|  | bgcolor="rgba(255,255,255,0.8)", | 
					
						
						|  | bordercolor="red", | 
					
						
						|  | borderwidth=1 | 
					
						
						|  | ) | 
					
						
						|  | fig.update_layout( | 
					
						
						|  | title="Visualization Error", | 
					
						
						|  | xaxis=dict(showgrid=False, zeroline=False, showticklabels=False), | 
					
						
						|  | yaxis=dict(showgrid=False, zeroline=False, showticklabels=False), | 
					
						
						|  | plot_bgcolor='white', | 
					
						
						|  | width=600, | 
					
						
						|  | height=400 | 
					
						
						|  | ) | 
					
						
						|  | return fig |