# database.py
import chromadb
import ast
import hashlib
import textwrap
from parser import parse_python_code  # note: create_vector is redefined locally below for query-side use
import os
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from datasets import Dataset, load_dataset
from transformers import AutoTokenizer, AutoModel
import torch
from dotenv import load_dotenv
import logging

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv()

# User-configurable variables
DB_NAME = "python_programs"  # ChromaDB collection name
HF_DATASET_NAME = "python_program_vectors"  # Hugging Face Dataset name
PERSIST_DIR = "./chroma_data"  # Directory for persistent storage (optional)
USE_GPU = False  # Default to CPU, set to True for GPU if available
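CODEBERT_MODEL_NAME = "microsoft/codebert-base"  # Assumed Hugging Face model id for the CodeBERT encoder used below

# CodeBERT state is loaded lazily by generate_semantic_vector(); start uninitialized
tokenizer = None
model = None
device = None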

def init_chromadb(persist_dir=PERSIST_DIR):
    """Initialize ChromaDB client, optionally with persistent storage, with error handling and logging."""
    try:
        # Use persistent storage if directory exists, otherwise in-memory
        if os.path.exists(persist_dir):
            logger.info(f"Initializing ChromaDB with persistent storage at {persist_dir}")
            client = chromadb.PersistentClient(path=persist_dir)
        else:
            logger.info("Initializing ChromaDB with in-memory storage")
            client = chromadb.Client()
        return client
    except Exception as e:
        logger.error(f"Error initializing ChromaDB: {e}")
        raise

def create_collection(client, collection_name=DB_NAME):
    """Create or get a ChromaDB collection for Python programs, with error handling and logging."""
    try:
        collection = client.get_or_create_collection(name=collection_name)
        if collection is None or not hasattr(collection, 'add'):
            raise ValueError("ChromaDB collection creation or access failed")
        logger.info(f"Using ChromaDB collection: {collection_name}, contains {collection.count()} entries")
        return collection
    except Exception as e:
        logger.error(f"Error creating or getting collection {collection_name}: {e}")
        raise

def store_program(client, code, sequence, vectors, collection_name=DB_NAME):
    """Store a program in ChromaDB with its code, sequence, and vectors, with error handling and logging."""
    try:
        collection = create_collection(client, collection_name)
        
        # Flatten vectors to ensure they are a list of numbers (ChromaDB expects flat embeddings)
        # Use the first vector (semantic or program vector) for ChromaDB embedding, ensuring 6D
        flattened_vectors = vectors[0] if vectors and len(vectors) > 0 and len(vectors[0]) == 6 else [0] * 6
        
        # Store program data (ID, code, sequence, vectors)
        program_id = hashlib.sha256(code.encode("utf-8")).hexdigest()  # Deterministic hash of the code as ID, stable across runs
        collection.add(
            documents=[code],
            metadatas=[{"sequence": ",".join(sequence), "description_tokens": " ".join(generate_description_tokens(sequence, vectors)), "program_vectors": str(vectors)}],
            ids=[program_id],
            embeddings=[flattened_vectors]  # Pass as 6D vector
        )
        logger.info(f"Stored program in ChromaDB: {program_id}, total entries: {collection.count()}")
        return program_id
    except Exception as e:
        logger.error(f"Error storing program in ChromaDB: {e}")
        raise

def populate_sample_db(client):
    """Populate ChromaDB with sample Python programs, with logging."""
    try:
        # Dedent the sample snippets so they parse as top-level Python code
        samples = [
            textwrap.dedent("""
                import os
                def add_one(x):
                    y = x + 1
                    return y
            """).strip(),
            textwrap.dedent("""
                def multiply(a, b):
                    c = a * b
                    if c > 0:
                        return c
            """).strip()
        ]
        
        for code in samples:
            parts, sequence = parse_python_code(code)
            vectors = [part['vector'] for part in parts]
            store_program(client, code, sequence, vectors)
        collection = create_collection(client, DB_NAME)
        logger.info(f"Populated ChromaDB with sample programs, total entries: {collection.count()}")
    except Exception as e:
        logger.error(f"Error populating sample database: {e}")
        raise

def query_programs(client, operations, collection_name=DB_NAME, top_k=5, semantic_query=None):
    """Query ChromaDB for programs matching the operations sequence or semantic description, with error handling and logging."""
    try:
        collection = create_collection(client, collection_name)
        
        if semantic_query:
            # Semantic search using a 6D vector generated from the description
            query_vector = generate_semantic_vector(semantic_query)
            results = collection.query(
                query_embeddings=[query_vector],
                n_results=top_k,
                include=["documents", "metadatas"]
            )
        else:
            # Vector-based search: average the 6D vectors of the requested operations
            if operations:
                query_vector = np.mean(
                    [create_vector(op, 0, (1, 1), 100, []) for op in operations], axis=0
                ).tolist()
            else:
                query_vector = [0] * 6
            results = collection.query(
                query_embeddings=[query_vector],
                n_results=top_k,
                include=["documents", "metadatas"]
            )
        
        # Process results
        matching_programs = []
        for program_id, doc, meta in zip(results['ids'][0], results['documents'][0], results['metadatas'][0]):
            sequence = meta['sequence'].split(',')
            # For operation queries, require the requested operations to appear as a subsequence;
            # semantic queries are ranked purely by embedding similarity.
            if semantic_query or is_subsequence(operations, sequence):
                try:
                    # Reconstruct program vectors from the stringified metadata field
                    doc_vectors = ast.literal_eval(meta['program_vectors']) if isinstance(meta['program_vectors'], str) else meta['program_vectors']
                    if isinstance(doc_vectors, (list, np.ndarray)) and len(doc_vectors) == 6 and not isinstance(doc_vectors[0], (list, np.ndarray)):
                        program_vector = list(doc_vectors)  # Already a single flat 6D vector
                    else:
                        program_vector = np.mean([v for v in doc_vectors if isinstance(v, (list, np.ndarray))], axis=0).tolist()
                except (ValueError, SyntaxError, TypeError):
                    program_vector = [0] * 6  # Fallback for malformed vectors
                # Rank by cosine similarity between the query vector and the stored program vector
                similarity = cosine_similarity([query_vector], [program_vector])[0][0] if np.any(program_vector) and np.any(query_vector) else 0.0
                matching_programs.append({
                    'id': program_id,
                    'code': doc,
                    'similarity': similarity,
                    'description': meta.get('description_tokens', ''),
                    'program_vectors': meta.get('program_vectors', '[]')
                })
        
        logger.info(f"Queried {len(matching_programs)} programs from ChromaDB, total entries: {collection.count()}")
        return sorted(matching_programs, key=lambda x: x['similarity'], reverse=True)
    except Exception as e:
        logger.error(f"Error querying programs from ChromaDB: {e}")
        raise

def create_vector(category, level, location, total_lines, parent_path):
    """Helper to create a vector for query (matches parser's create_vector)."""
    category_map = {
        'import': 1, 'function': 2, 'async_function': 3, 'class': 4,
        'if': 5, 'while': 6, 'for': 7, 'try': 8, 'expression': 9, 'spacer': 10,
        'other': 11, 'elif': 12, 'else': 13, 'except': 14, 'finally': 15, 'return': 16,
        'assigned_variable': 17, 'input_variable': 18, 'returned_variable': 19
    }
    category_id = category_map.get(category, 0)
    start_line, end_line = location
    span = (end_line - start_line + 1) / total_lines
    center_pos = ((start_line + end_line) / 2) / total_lines
    parent_depth = len(parent_path)
    parent_weight = sum(category_map.get(parent.split('[')[0].lower(), 0) * (1 / (i + 1)) 
                        for i, parent in enumerate(parent_path)) / max(1, len(category_map))
    return [category_id, level, center_pos, span, parent_depth, parent_weight]

def is_subsequence(subseq, seq):
    """Check if subseq is a subsequence of seq."""
    it = iter(seq)
    return all(item in it for item in subseq)

def generate_description_tokens(sequence, vectors):
    """Generate semantic description tokens for a program based on its sequence and vectors."""
    tokens = []
    category_descriptions = {
        'import': 'imports module',
        'function': 'defines function',
        'assigned_variable': 'assigns variable',
        'input_variable': 'input parameter',
        'returned_variable': 'returns value',
        'if': 'conditional statement',
        'return': 'returns result',
        'try': 'try block',
        'except': 'exception handler',
        'expression': 'expression statement',
        'spacer': 'empty line or comment'
    }
    
    for cat, vec in zip(sequence, vectors):
        if cat in category_descriptions:
            tokens.append(f"{category_descriptions[cat]}:{cat}")
            # Add vector-derived features (e.g., level, span) as tokens
            tokens.append(f"level:{vec[1]}")
            tokens.append(f"span:{vec[3]:.2f}")
    return tokens
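
# Note: load_codebert_model is referenced by generate_semantic_vector below but was not defined in
# the original file; this is a minimal sketch assuming the standard transformers loading pattern
# and the CODEBERT_MODEL_NAME id declared above.
def load_codebert_model(use_gpu=USE_GPU):
    """Load the CodeBERT tokenizer and model onto CPU or GPU, with error handling and logging."""
    try:
        selected_device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")
        codebert_tokenizer = AutoTokenizer.from_pretrained(CODEBERT_MODEL_NAME)
        codebert_model = AutoModel.from_pretrained(CODEBERT_MODEL_NAME).to(selected_device)
        codebert_model.eval()  # Inference only; no gradient updates needed
        logger.info(f"Loaded CodeBERT model {CODEBERT_MODEL_NAME} on {selected_device}")
        return codebert_tokenizer, codebert_model, selected_device
    except Exception as e:
        logger.error(f"Error loading CodeBERT model: {e}")
        raise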

def generate_semantic_vector(description, total_lines=100, use_gpu=USE_GPU):
    """Generate a 6D semantic vector for a textual description using CodeBERT, projecting to 6D."""
    global tokenizer, model, device
    if tokenizer is None or model is None:
        tokenizer, model, device = load_codebert_model(use_gpu)
    
    # Tokenize and encode the description
    inputs = tokenizer(description, return_tensors="pt", padding=True, truncation=True, max_length=512)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    # Generate embeddings
    with torch.no_grad():
        outputs = model(**inputs)
        # Use mean pooling of the last hidden states
        vector = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy().tolist()
    
    # Truncate or project to 6D (simplified projection: take first 6 dimensions)
    if len(vector) < 6:
        vector.extend([0] * (6 - len(vector)))
    elif len(vector) > 6:
        vector = vector[:6]  # Truncate to 6D
    
    # Ensure vector isn’t all zeros or defaults
    if all(v == 0 for v in vector):
        logger.warning(f"Default vector detected for description: {description}")
        # Fallback: Use heuristic if CodeBERT fails to generate meaningful embeddings
        category_map = {
            'import': 1, 'function': 2, 'assign': 17, 'input': 18, 'return': 19, 'if': 5, 'try': 8, 'except': 14
        }
        tokens = description.lower().split()
        vector = [0] * 6
        for token in tokens:
            for cat, cat_id in category_map.items():
                if cat in token:
                    vector[0] = cat_id  # category_id
                    vector[1] = 1  # level
                    vector[2] = 0.5  # center_pos
                    vector[3] = 0.1  # span
                    vector[4] = 1  # parent_depth
                    vector[5] = cat_id / len(category_map)  # parent_weight
                    break
    
    logger.debug(f"Generated semantic vector for '{description}': {vector}")
    return vector

def save_chromadb_to_hf(dataset_name=HF_DATASET_NAME, token=os.getenv("HF_KEY")):
    """Save ChromaDB data to Hugging Face Dataset, with error handling and logging."""
    try:
        client = init_chromadb()
        collection = client.get_collection(DB_NAME)
        
        # Fetch all data from ChromaDB
        results = collection.get(include=["documents", "metadatas", "embeddings"])
        data = {
            "code": results["documents"],
            "sequence": [meta["sequence"] for meta in results["metadatas"]],
            "vectors": results["embeddings"],  # Semantic 6D vectors
            "description_tokens": [meta.get('description_tokens', '') for meta in results["metadatas"]],
            "program_vectors": [eval(meta.get('program_vectors', '[]')) for meta in results["metadatas"]]  # Store structural vectors
        }
        
        # Create a Hugging Face Dataset
        dataset = Dataset.from_dict(data)
        logger.info(f"Created Hugging Face Dataset with {len(data['code'])} entries")
        
        # Push to Hugging Face Hub (creates the dataset repo if needed, updates it otherwise)
        dataset.push_to_hub(dataset_name, token=token)
        logger.info(f"Pushed dataset to Hugging Face Hub as {dataset_name} with {len(dataset)} entries")
    except Exception as e:
        logger.error(f"Error pushing dataset to Hugging Face Hub: {e}")
        raise

def load_chromadb_from_hf(dataset_name=HF_DATASET_NAME, token=os.getenv("HF_KEY")):
    """Load ChromaDB data from Hugging Face Dataset, handle empty dataset, with error handling and logging."""
    try:
        dataset = load_dataset(dataset_name, split="train", token=token)
        client = init_chromadb()
        collection = create_collection(client)
        
        for item in dataset:
            store_program(client, item["code"], item["sequence"].split(','), item["program_vectors"])
        collection = create_collection(client, DB_NAME)
        logger.info(f"Loaded {len(dataset)} entries from Hugging Face Hub into ChromaDB, total entries: {collection.count()}")
        return client
    except Exception as e:
        logger.error(f"Error loading dataset from Hugging Face: {e}")
        # Fallback: Create empty collection
        client = init_chromadb()
        collection = create_collection(client)
        logger.info(f"Created empty ChromaDB collection: {DB_NAME}, contains {collection.count()} entries")
        return client

if __name__ == '__main__':
    client = load_chromadb_from_hf()
    collection = create_collection(client, DB_NAME)
    logger.info(f"Database initialized or loaded from Hugging Face Hub, contains {collection.count()} entries")