#!/usr/bin/env python3
"""
Production test script for Mamba Graph implementation
Fixed for overfitting with regularized configuration
"""

import os
os.environ['OMP_NUM_THREADS'] = '4'  # Limit OpenMP threads before importing torch to avoid the oversubscription warning

import sys
import torch
import time
import logging
from pathlib import Path

from core.graph_mamba import GraphMamba, create_regularized_config
from core.trainer import GraphMambaTrainer
from data.loader import GraphDataLoader
from utils.metrics import GraphMetrics
from utils.visualization import GraphVisualizer

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


def get_device():
    """Get the best available device - GPU preferred"""
    if torch.cuda.is_available():
        device = torch.device('cuda')
        logger.info(f"šŸš€ CUDA available - using GPU: {torch.cuda.get_device_name()}")
    else:
        device = torch.device('cpu')
        logger.info("šŸ’» Using CPU")
    return device


def run_comprehensive_test():
    """Run comprehensive test suite with overfitting fixes"""
    print("🧠 Mamba Graph Neural Network - Complete Test")
    print("=" * 60)

    # Use regularized configuration to prevent overfitting
    config = create_regularized_config()

    # Setup device
    device = get_device()
    start_time = time.time()

    # Test results
    test_results = {
        'data_loading': False,
        'model_initialization': False,
        'forward_pass': False,
        'ordering_strategies': {},
        'training': False,
        'evaluation': False,
        'visualization': False
    }

    try:
        # Test 1: Data Loading
        print("\nšŸ“Š Loading Cora dataset...")
        data_loader = GraphDataLoader()
        dataset = data_loader.load_node_classification_data('Cora')
        data = dataset[0].to(device)
        info = data_loader.get_dataset_info(dataset)

        print(f"āœ… Dataset loaded successfully!")
        print(f"   Nodes: {data.num_nodes:,}")
        print(f"   Edges: {data.num_edges:,}")
        print(f"   Features: {info['num_features']}")
        print(f"   Classes: {info['num_classes']}")
        print(f"   Train nodes: {data.train_mask.sum()}")
        print(f"   Val nodes: {data.val_mask.sum()}")
        print(f"   Test nodes: {data.test_mask.sum()}")

        test_results['data_loading'] = True
    except Exception as e:
        print(f"āŒ Data loading failed: {e}")
        return test_results

    try:
        # Test 2: Model Initialization with regularized config
        print("\nšŸ—ļø Initializing GraphMamba (Regularized)...")
        model = GraphMamba(config).to(device)
        total_params = sum(p.numel() for p in model.parameters())

        print(f"āœ… Model initialized!")
        print(f"   Parameters: {total_params:,}")
        print(f"   Memory usage: ~{total_params * 4 / 1024**2:.1f} MB")
        print(f"   Device: {device}")
        print(f"   Model type: Regularized (Anti-overfitting)")

        # Check if parameter count is reasonable for the small training set
        train_samples = data.train_mask.sum().item()
        params_per_sample = total_params / train_samples
        print(f"   Params per training sample: {params_per_sample:.1f}")

        if params_per_sample < 500:
            print("   āœ… Good parameter ratio - low overfitting risk")
        elif params_per_sample < 1000:
            print("   āš ļø Moderate parameter ratio - watch for overfitting")
        else:
            print("   🚨 High parameter ratio - high overfitting risk")

        test_results['model_initialization'] = True
    except Exception as e:
        print(f"āŒ Model initialization failed: {e}")
        return test_results

    try:
        # Test 3: Forward Pass
        print("\nšŸš€ Testing forward pass...")
        model.eval()
        with torch.no_grad():
            forward_start = time.time()
            h = model(data.x, data.edge_index)
            forward_time = time.time() - forward_start

        print(f"āœ… Forward pass successful!")
        print(f"   Input shape: {data.x.shape}")
        print(f"   Output shape: {h.shape}")
        print(f"   Forward time: {forward_time*1000:.2f}ms")
        print(f"   Output range: [{h.min():.3f}, {h.max():.3f}]")

        test_results['forward_pass'] = True
    except Exception as e:
        print(f"āŒ Forward pass failed: {e}")
        return test_results

    # Test 4: Ordering Strategies (simplified for regularized model)
    print("\nšŸ”„ Testing ordering strategies...")
    # Only test BFS for the regularized model to avoid complexity
    strategies = ['bfs']

    for strategy in strategies:
        try:
            config['ordering']['strategy'] = strategy
            test_model = GraphMamba(config).to(device)
            test_model.eval()

            strategy_start = time.time()
            with torch.no_grad():
                h = test_model(data.x, data.edge_index)
            strategy_time = time.time() - strategy_start

            print(f"āœ… {strategy:12} | Shape: {h.shape} | Time: {strategy_time*1000:.2f}ms")
            test_results['ordering_strategies'][strategy] = True
        except Exception as e:
            print(f"āŒ {strategy:12} | Failed: {str(e)}")
            test_results['ordering_strategies'][strategy] = False

    try:
        # Test 5: Regularized Training
        print("\nšŸ‹ļø Testing regularized training system...")

        # Reset to BFS for training
        config['ordering']['strategy'] = 'bfs'
        model = GraphMamba(config).to(device)
        trainer = GraphMambaTrainer(model, config, device)

        print(f"āœ… Trainer initialized!")
        print(f"   Optimizer: {type(trainer.optimizer).__name__}")
        print(f"   Learning rate: {trainer.lr}")
        print(f"   Epochs: {trainer.epochs}")
        print(f"   Weight decay: {config['training']['weight_decay']}")
        print(f"   Anti-overfitting: Enabled")

        # Run training
        print(f"\nšŸŽÆ Running regularized training...")
        training_start = time.time()
        history = trainer.train_node_classification(data, verbose=True)
        training_time = time.time() - training_start

        print(f"āœ… Training completed!")
        print(f"   Training time: {training_time:.2f}s")
        print(f"   Epochs trained: {len(history['train_loss'])}")
        print(f"   Best val accuracy: {trainer.best_val_acc:.4f}")
        print(f"   Final train accuracy: {history['train_acc'][-1]:.4f}")
        print(f"   Overfitting gap: {trainer.best_gap:.4f}")

        test_results['training'] = True
    except Exception as e:
        print(f"āŒ Training failed: {e}")
        return test_results

    try:
        # Test 6: Evaluation
        print("\nšŸ“Š Testing evaluation...")
        test_metrics = trainer.test(data)

        print(f"āœ… Evaluation completed!")
        print(f"   Test accuracy: {test_metrics['test_acc']:.4f} ({test_metrics['test_acc']*100:.2f}%)")
        print(f"   Test loss: {test_metrics['test_loss']:.4f}")
        print(f"   F1 macro: {test_metrics.get('f1_macro', 0):.4f}")
        print(f"   F1 micro: {test_metrics.get('f1_micro', 0):.4f}")
        print(f"   Precision: {test_metrics.get('precision', 0):.4f}")
        print(f"   Recall: {test_metrics.get('recall', 0):.4f}")

        test_results['evaluation'] = True
    except Exception as e:
        print(f"āŒ Evaluation failed: {e}")
        return test_results

    try:
        # Test 7: Visualization
        print("\nšŸŽØ Testing visualization...")

        # Create all visualizations
        graph_fig = GraphVisualizer.create_graph_plot(data, max_nodes=200)
        metrics_fig = GraphVisualizer.create_metrics_plot(test_metrics)
        training_fig = GraphVisualizer.create_training_history_plot(history)

        print(f"āœ… Visualizations created!")
        print(f"   Graph plot: {type(graph_fig).__name__}")
        print(f"   Metrics plot: {type(metrics_fig).__name__}")
        print(f"   Training plot: {type(training_fig).__name__}")

        test_results['visualization'] = True
    except Exception as e:
        print(f"āŒ Visualization failed: {e}")
        return test_results

    # Final Summary
    print("\n" + "=" * 60)
    print("šŸ† TEST SUMMARY")
    print("=" * 60)

    # Count passed tests (main tests and ordering strategies tracked separately)
    main_tests_passed = sum(1 for k, v in test_results.items()
                            if k != 'ordering_strategies' and v)
    ordering_tests_passed = sum(test_results['ordering_strategies'].values())
    total_passed = main_tests_passed + ordering_tests_passed

    main_tests_total = len(test_results) - 1
    ordering_tests_total = len(test_results['ordering_strategies'])
    total_tests = main_tests_total + ordering_tests_total

    print(f"šŸ“Š Overall: {total_passed}/{total_tests} tests passed")
    print(f"šŸ’¾ Device: {device}")
    print(f"ā±ļø Total time: {time.time() - start_time:.2f}s")

    # Detailed results
    for test_name, result in test_results.items():
        if test_name == 'ordering_strategies':
            print(f"šŸ”„ Ordering strategies:")
            for strategy, strategy_result in result.items():
                status = "āœ…" if strategy_result else "āŒ"
                print(f"   {status} {strategy}")
        else:
            status = "āœ…" if result else "āŒ"
            print(f"{status} {test_name.replace('_', ' ').title()}")

    # Performance summary
    if test_results['evaluation']:
        print(f"\nšŸŽÆ Final Performance:")
        print(f"   Test Accuracy: {test_metrics['test_acc']:.4f} ({test_metrics['test_acc']*100:.2f}%)")
        print(f"   Training Time: {training_time:.2f}s")
        print(f"   Model Size: {total_params:,} parameters")
        print(f"   Params per sample: {params_per_sample:.1f}")

        # Compare with baselines
        cora_baselines = {
            'Random': 0.143,
            'Simple': 0.300,
            'GCN': 0.815,
            'GAT': 0.830
        }

        print(f"\nšŸ“ˆ Baseline Comparison (Cora):")
        for model_name, baseline in cora_baselines.items():
            diff = test_metrics['test_acc'] - baseline
            if diff > 0:
                status = "🟢"
                desc = f"(+{diff:.3f} better)"
            elif diff > -0.1:
                status = "🟔"
                desc = f"({diff:.3f} competitive)"
            else:
                status = "šŸ”“"
                desc = f"({diff:.3f} gap)"
            print(f"   {status} {model_name:12}: {baseline:.3f} {desc}")

        # Overfitting analysis
        if trainer.best_gap < 0.1:
            print(f"\nšŸŽ‰ Excellent generalization! (gap: {trainer.best_gap:.3f})")
        elif trainer.best_gap < 0.2:
            print(f"\nšŸ‘ Good generalization (gap: {trainer.best_gap:.3f})")
        else:
            print(f"\nāš ļø Some overfitting detected (gap: {trainer.best_gap:.3f})")

    print(f"\n✨ All tests completed!")
    if total_passed == total_tests:
        print(f"šŸŽ‰ Perfect score! Regularized system working well!")
    elif total_passed >= total_tests * 0.8:
        print(f"šŸ‘ Great! System is mostly functional.")
    else:
        print(f"āš ļø Some issues detected.")

    return test_results


if __name__ == "__main__":
    results = run_comprehensive_test()

    # Exit with an appropriate status code for CI
    main_tests_passed = sum(1 for k, v in results.items()
                            if k != 'ordering_strategies' and v)
    ordering_tests_passed = sum(results['ordering_strategies'].values())
    total_passed = main_tests_passed + ordering_tests_passed

    main_tests_total = len(results) - 1
    ordering_tests_total = len(results['ordering_strategies'])
    total_tests = main_tests_total + ordering_tests_total

    sys.exit(0 if total_passed == total_tests else 1)