import asyncio
import json
import logging
import os
import pathlib
import time
import uuid
from aiohttp import web, WSMsgType
from typing import Dict, Any

from api_core import VideoGenerationAPI
from api_session import SessionManager
from api_metrics import MetricsTracker
from api_config import *

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Create global session and metrics managers
session_manager = SessionManager()
metrics_tracker = MetricsTracker()

# Dictionary to track connected anonymous clients by IP address
anon_connections = {}
anon_connection_lock = asyncio.Lock()

async def status_handler(request: web.Request) -> web.Response:
    """Handler for API status endpoint"""
    api = session_manager.shared_api
    
    # Get current busy status of all endpoints
    endpoint_statuses = []
    for ep in api.endpoint_manager.endpoints:
        endpoint_statuses.append({
            'id': ep.id,
            'url': ep.url,
            'busy': ep.busy,
            'last_used': ep.last_used,
            'error_count': ep.error_count,
            'error_until': ep.error_until
        })
    
    # Get session statistics
    session_stats = session_manager.get_session_stats()
    
    # Get metrics
    api_metrics = metrics_tracker.get_metrics()
    
    return web.json_response({
        'product': PRODUCT_NAME,
        'version': PRODUCT_VERSION,
        'maintenance_mode': MAINTENANCE_MODE,
        'available_endpoints': len(VIDEO_ROUND_ROBIN_ENDPOINT_URLS),
        'endpoint_status': endpoint_statuses,
        'active_endpoints': sum(1 for ep in endpoint_statuses if not ep['busy'] and ('error_until' not in ep or ep['error_until'] < time.time())),
        'active_sessions': session_stats,
        'metrics': api_metrics
    })

async def metrics_handler(request: web.Request) -> web.Response:
    """Handler for detailed metrics endpoint (protected)"""
    # Check for API key in header or query param
    auth_header = request.headers.get('Authorization', '')
    api_key = None
    
    if auth_header.startswith('Bearer '):
        api_key = auth_header[7:]
    else:
        api_key = request.query.get('key')
    
    # Validate API key (using SECRET_TOKEN as the API key)
    if not api_key or api_key != SECRET_TOKEN:
        return web.json_response({
            'error': 'Unauthorized'
        }, status=401)
    
    # Get detailed metrics
    detailed_metrics = metrics_tracker.get_detailed_metrics()
    
    return web.json_response(detailed_metrics)

async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
    # Check if maintenance mode is enabled
    if MAINTENANCE_MODE:
        # Return an error response indicating maintenance mode
        return web.json_response({
            'error': 'Server is in maintenance mode',
            'maintenance': True
        }, status=503)  # 503 Service Unavailable
    
    ws = web.WebSocketResponse(
        max_msg_size=1024*1024*20,  # 20MB max message size
        timeout=30.0  # we want to keep things tight and short
    )
    
    await ws.prepare(request)
    
    # Get the Hugging Face token from query parameters
    hf_token = request.query.get('hf_token', '')
    
    # Generate a unique user ID for this connection
    user_id = str(uuid.uuid4())
    
    # Validate the token and determine the user role
    user_role = await session_manager.shared_api.validate_user_token(hf_token)
    logger.info(f"User {user_id} connected with role: {user_role}")
    
    # Get client IP address
    peername = request.transport.get_extra_info('peername')
    if peername is not None:
        client_ip = peername[0]
    else:
        client_ip = request.headers.get('X-Forwarded-For', 'unknown').split(',')[0].strip()
    
    logger.info(f"Client {user_id} connecting from IP: {client_ip} with role: {user_role}")
    
    # Check for anonymous user connection limits
    if user_role == 'anon':
        async with anon_connection_lock:
            # Track this connection
            anon_connections[client_ip] = anon_connections.get(client_ip, 0) + 1
            # Store the IP so we can clean up later
            ws.client_ip = client_ip
            
            # Log multiple connections from same IP but don't restrict them
            if anon_connections[client_ip] > 1:
                logger.info(f"Multiple anonymous connections from IP {client_ip}: {anon_connections[client_ip]} connections")
    
    # Store the user role in the websocket for easy access
    ws.user_role = user_role
    ws.user_id = user_id
    
    # Register with metrics
    metrics_tracker.register_session(user_id, client_ip)
    
    # Create a new session for this user
    user_session = await session_manager.create_session(user_id, user_role, ws)

    try:
        async for msg in ws:
            if msg.type == WSMsgType.TEXT:
                try:
                    data = json.loads(msg.data)
                    action = data.get('action')
                    
                    # Check for rate limiting
                    request_type = 'other'
                    if action in ['join_chat', 'leave_chat', 'chat_message']:
                        request_type = 'chat'
                    elif action in ['generate_video']:
                        request_type = 'video'
                    elif action == 'search':
                        request_type = 'search'
                    elif action == 'simulate':
                        request_type = 'simulation'
                    
                    # Record the request for metrics
                    await metrics_tracker.record_request(user_id, client_ip, request_type, user_role)
                    
                    # Check rate limits (except for admins)
                    if user_role != 'admin' and await metrics_tracker.is_rate_limited(user_id, request_type, user_role):
                        await ws.send_json({
                            'action': action,
                            'requestId': data.get('requestId'),
                            'success': False,
                            'error': f'Rate limit exceeded for {request_type} requests. Please try again later.'
                        })
                        continue
                    
                    # Route requests to appropriate queues
                    if action in ['join_chat', 'leave_chat', 'chat_message']:
                        await user_session.chat_queue.put(data)
                    elif action in ['generate_video']:
                        await user_session.video_queue.put(data)
                    elif action == 'search':
                        await user_session.search_queue.put(data)
                    elif action == 'simulate':
                        await user_session.simulation_queue.put(data)
                    else:
                        await user_session.process_generic_request(data)
                        
                except Exception as e:
                    logger.error(f"Error processing WebSocket message for user {user_id}: {str(e)}")
                    await ws.send_json({
                        'action': data.get('action') if 'data' in locals() else 'unknown',
                        'success': False,
                        'error': f'Error processing message: {str(e)}'
                    })
                    
            elif msg.type in (WSMsgType.ERROR, WSMsgType.CLOSE):
                break
                
    finally:
        # Cleanup session
        await session_manager.delete_session(user_id)
        
        # Cleanup anonymous connection tracking
        if getattr(ws, 'user_role', None) == 'anon' and hasattr(ws, 'client_ip'):
            client_ip = ws.client_ip
            async with anon_connection_lock:
                if client_ip in anon_connections:
                    anon_connections[client_ip] = max(0, anon_connections[client_ip] - 1)
                    if anon_connections[client_ip] == 0:
                        del anon_connections[client_ip]
                    logger.info(f"Anonymous connection from {client_ip} closed. Remaining: {anon_connections.get(client_ip, 0)}")
        
        # Unregister from metrics
        metrics_tracker.unregister_session(user_id, client_ip)
        logger.info(f"Connection closed for user {user_id}")
    
    return ws

async def init_app() -> web.Application:
    app = web.Application(
        client_max_size=1024**2*20  # 20MB max size
    )
    
    # Add cleanup logic
    async def cleanup(app):
        logger.info("Shutting down server, closing all sessions...")
        await session_manager.close_all_sessions()
    
    app.on_shutdown.append(cleanup)
    
    # Add routes
    app.router.add_get('/ws', websocket_handler)
    app.router.add_get('/api/status', status_handler)
    app.router.add_get('/api/metrics', metrics_handler)
    
    # Set up static file serving
    # Define the path to the public directory
    public_path = pathlib.Path(__file__).parent / 'build' / 'web'
    if not public_path.exists():
        public_path.mkdir(parents=True, exist_ok=True)
    
    # Set up static file serving with proper security considerations
    async def static_file_handler(request):
        # Get the path from the request (removing leading /)
        path_parts = request.path.lstrip('/').split('/')
        
        # Convert to safe path to prevent path traversal attacks
        safe_path = public_path.joinpath(*path_parts)
        
        # Make sure the path is within the public directory (prevent directory traversal)
        try:
            safe_path = safe_path.resolve()
            if not str(safe_path).startswith(str(public_path.resolve())):
                return web.HTTPForbidden(text="Access denied")
        except (ValueError, FileNotFoundError):
            return web.HTTPNotFound()
        
        # If path is a directory, look for index.html
        if safe_path.is_dir():
            safe_path = safe_path / 'index.html'
        
        # Check if the file exists
        if not safe_path.exists() or not safe_path.is_file():
            # If not found, serve index.html (for SPA routing)
            safe_path = public_path / 'index.html'
            if not safe_path.exists():
                return web.HTTPNotFound()
        
        # Determine content type based on file extension
        content_type = 'text/plain'
        ext = safe_path.suffix.lower()
        if ext == '.html':
            content_type = 'text/html'
        elif ext == '.js':
            content_type = 'application/javascript'
        elif ext == '.css':
            content_type = 'text/css'
        elif ext in ('.jpg', '.jpeg'):
            content_type = 'image/jpeg'
        elif ext == '.png':
            content_type = 'image/png'
        elif ext == '.gif':
            content_type = 'image/gif'
        elif ext == '.svg':
            content_type = 'image/svg+xml'
        elif ext == '.json':
            content_type = 'application/json'
        
        # Return the file with appropriate headers
        return web.FileResponse(safe_path, headers={'Content-Type': content_type})
    
    # Add catch-all route for static files (lower priority than API routes)
    app.router.add_get('/{path:.*}', static_file_handler)
    
    return app

if __name__ == '__main__':
    app = asyncio.run(init_app())
    web.run_app(app, host='0.0.0.0', port=8080)