import gradio as gr
import os
from concurrent.futures import ThreadPoolExecutor
from langchain_community.vectorstores import Chroma
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain.memory import ConversationBufferMemory
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever

# Environment variable for the Hugging Face API token
api_token = os.getenv("API_TOKEN")
if not api_token:
    raise ValueError("Environment variable 'API_TOKEN' not set.")
print(f"API Token loaded: {api_token[:5]}...")  # Debug: prefix only, never the full token

# Available LLM models
list_llm = [
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "deepseek-ai/deepseek-llm-7b-chat"
]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]
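
# Display names shown in the UI map back to full repo IDs by position:
# the model Radio below is declared with type="index" into list_llm.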

# -----------------------------------------------------------------------------
# Document Loading and Splitting (Optimized with Threading)
# -----------------------------------------------------------------------------
def load_single_pdf(file_path):
    """Load a single PDF file."""
    loader = PyPDFLoader(file_path)
    return loader.load()

def load_doc(list_file_path, progress=gr.Progress()):
    """Load and split PDF documents into chunks with multi-threading."""
    if not list_file_path:
        raise ValueError("No files provided for processing.")
    
    # Parallelize PDF loading; threads mostly overlap file I/O, since the
    # parsing itself remains GIL-bound.
    with ThreadPoolExecutor() as executor:
        pages = list(executor.map(load_single_pdf, list_file_path))
        pages = [page for sublist in pages for page in sublist]  # Flatten list
    
    progress(0.5, "Splitting documents...")
    # 2048-char chunks with 128-char overlap: larger chunks mean fewer embeddings,
    # while the overlap preserves context across chunk boundaries.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=2048, chunk_overlap=128)
    doc_splits = text_splitter.split_documents(pages)
    return doc_splits
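
# Example (hypothetical file names): load_doc(["report_a.pdf", "report_b.pdf"])
# returns a flat list of ~2048-character Document chunks ready for indexing.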

# -----------------------------------------------------------------------------
# Vector Database Creation (Optimized with Lightweight Embeddings)
# -----------------------------------------------------------------------------
def create_chromadb(splits, persist_directory="chroma_db", progress=gr.Progress()):
    """Create ChromaDB vector database with optimized embeddings."""
    # all-MiniLM-L6-v2 is a small, fast sentence-transformers model: a good
    # speed/quality trade-off for retrieval embeddings.
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    progress(0.7, "Creating vector database...")
    chromadb = Chroma.from_documents(
        documents=splits,
        embedding=embeddings,
        persist_directory=persist_directory
    )
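    # Chroma writes the index to persist_directory as it is built; a later run
    # could reload it via Chroma(persist_directory=..., embedding_function=embeddings)
    # instead of re-embedding (an option, not exercised in this demo).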
    return chromadb

# -----------------------------------------------------------------------------
# Retrievers
# -----------------------------------------------------------------------------
def create_bm25_retriever(splits):
    """Create BM25 retriever from document splits."""
    retriever = BM25Retriever.from_documents(splits)
    retriever.k = 2  # return only the top 2 lexical matches to keep prompts small
    return retriever

def create_ensemble_retriever(vector_db, bm25_retriever):
    """Create an ensemble retriever."""
    return EnsembleRetriever(
        retrievers=[vector_db.as_retriever(search_kwargs={"k": 2}), bm25_retriever],  # top 2 semantic matches
        weights=[0.7, 0.3]
    )
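
# EnsembleRetriever merges the two ranked lists with weighted Reciprocal Rank
# Fusion: 0.7 favors semantic (vector) hits and 0.3 lexical BM25 hits, so
# exact-term matches can still surface alongside semantically similar chunks.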

# -----------------------------------------------------------------------------
# Initialize Database
# -----------------------------------------------------------------------------
def initialize_database(list_file_obj, progress=gr.Progress()):
    """Initialize the document database with error handling."""
    try:
        if not list_file_obj:
            return None, "Error: no documents uploaded."
        list_file_path = [x.name for x in list_file_obj if x is not None]
        progress(0.1, "Loading documents...")
        doc_splits = load_doc(list_file_path, progress)
        chromadb = create_chromadb(doc_splits, progress=progress)
        bm25_retriever = create_bm25_retriever(doc_splits)
        ensemble_retriever = create_ensemble_retriever(chromadb, bm25_retriever)
        progress(1.0, "Database creation complete!")
        return ensemble_retriever, "Database created successfully!"
    except Exception as e:
        return None, f"Error initializing database: {str(e)}"

# -----------------------------------------------------------------------------
# Initialize LLM Chain
# -----------------------------------------------------------------------------
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, retriever):
    """Initialize the language model chain."""
    if retriever is None:
        raise ValueError("Retriever is None. Please process documents first.")
    
    try:
        print(f"Initializing LLM: {llm_model} with token: {api_token[:5]}...")
        llm = HuggingFaceEndpoint(
            repo_id=llm_model,
            huggingfacehub_api_token=api_token,
            temperature=temperature,
            max_new_tokens=max_tokens,
            top_k=top_k,
            task="text-generation"
        )
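        # memory_key must match the "chat_history" variable the chain expects;
        # output_key="answer" tells the memory which output to store, since the
        # chain also returns "source_documents".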
        memory = ConversationBufferMemory(
            memory_key="chat_history",
            output_key="answer",
            return_messages=True
        )
        qa_chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever,
            chain_type="stuff",
            memory=memory,
            return_source_documents=True,
            verbose=False
        )
        return qa_chain
    except Exception as e:
        raise RuntimeError(f"Failed to initialize LLM chain: {str(e)}")

# -----------------------------------------------------------------------------
# Initialize LLM
# -----------------------------------------------------------------------------
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, retriever, progress=gr.Progress()):
    """Initialize the Language Model."""
    if retriever is None:
        return None, "Error: No database initialized. Please process documents first."
    
    try:
        llm_name = list_llm[llm_option]
        print(f"Selected LLM model: {llm_name}")
        qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, retriever)
        return qa_chain, "Analysis Assistant initialized and ready!"
    except Exception as e:
        return None, f"Error initializing LLM: {str(e)}"

# -----------------------------------------------------------------------------
# Chat History Formatting
# -----------------------------------------------------------------------------
def format_chat_history(chat_history):
    """Format (user, assistant) turns into plain-text lines for the model."""
    return [f"User: {user_msg}\nAssistant: {bot_msg}" for user_msg, bot_msg in chat_history]

# -----------------------------------------------------------------------------
# Conversation Function
# -----------------------------------------------------------------------------
def conversation(qa_chain, message, history, lang):
    """Handle conversation and document analysis."""
    if not qa_chain:
        return None, gr.update(value="Assistant not initialized"), history, "", 0, "", 0, "", 0

    lang_instruction = " (Responda em Português)" if lang == "pt" else " (Respond in English)"
    query = message + lang_instruction
    
    try:
        # Note: the attached ConversationBufferMemory supersedes this explicit history.
        formatted_chat_history = format_chat_history(history)
        response = qa_chain.invoke({"question": query, "chat_history": formatted_chat_history})
        # Strip the default prompt's "Helpful Answer:" marker if the model echoes it.
        answer = response["answer"]
        if "Helpful Answer:" in answer:
            answer = answer.split("Helpful Answer:")[-1].strip()

        sources = response["source_documents"]
        source_data = [("Unknown", 0)] * 3
        for i, doc in enumerate(sources[:3]):
            # PyPDFLoader stores zero-based page numbers in metadata.
            source_data[i] = (doc.page_content.strip(), doc.metadata.get("page", 0) + 1)

        new_history = history + [(message, answer)]
        return (
            qa_chain, gr.update(value=""), new_history,
            source_data[0][0], source_data[0][1],
            source_data[1][0], source_data[1][1],
            source_data[2][0], source_data[2][1]
        )
    except Exception as e:
        return qa_chain, gr.update(value=f"Error: {str(e)}"), history, "", 0, "", 0, "", 0

# -----------------------------------------------------------------------------
# Gradio Demo
# -----------------------------------------------------------------------------
def demo():
    """Main demo application with enhanced layout."""
    theme = gr.themes.Default(primary_hue="indigo", secondary_hue="blue", neutral_hue="slate")
    custom_css = """
        .container {background: #ffffff; padding: 1rem; border-radius: 8px; box-shadow: 0 1px 3px rgba(0,0,0,0.1);}
        .header {text-align: center; margin-bottom: 2rem;}
        .header h1 {color: #1a365d; font-size: 2.5rem; margin-bottom: 0.5rem;}
        .section {margin-bottom: 1.5rem; padding: 1rem; background: #f8fafc; border-radius: 8px;}
    """

    with gr.Blocks(theme=theme, css=custom_css) as demo:
        retriever = gr.State()
        qa_chain = gr.State()
        language = gr.State(value="en")
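        # These States keep per-session Python objects (retriever, QA chain,
        # language code) alive between event callbacks.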

        gr.HTML(
            '<div class="header"><h1>MetroAssist AI</h1><p>Expert System for Metrology Report Analysis</p></div>'
        )

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("## Document Processing")
                with gr.Column(elem_classes="section"):
                    document = gr.Files(label="Metrology Reports (PDF)", file_count="multiple", file_types=["pdf"])
                    db_btn = gr.Button("Process Documents")
                    db_progress = gr.Textbox(value="Ready for documents", label="Processing Status")

                gr.Markdown("## Model Configuration")
                with gr.Column(elem_classes="section"):
                    llm_btn = gr.Radio(choices=list_llm_simple, label="Select AI Model", value=list_llm_simple[0], type="index")
                    language_btn = gr.Radio(choices=["English", "Português"], label="Response Language", value="English")
                    with gr.Accordion("Advanced Settings", open=False):
                        slider_temperature = gr.Slider(0.01, 1.0, value=0.5, step=0.1, label="Analysis Precision")
                        slider_maxtokens = gr.Slider(128, 2048, value=1024, step=128, label="Response Length")  # max new tokens generated per reply
                        slider_topk = gr.Slider(1, 5, value=3, step=1, label="Analysis Diversity")  # top-k sampling candidates
                    qachain_btn = gr.Button("Initialize Assistant", interactive=False)
                    llm_progress = gr.Textbox(value="Not initialized", label="Assistant Status")

            with gr.Column(scale=2):
                gr.Markdown("## Interactive Analysis")
                chatbot = gr.Chatbot(height=400, label="Analysis Conversation")
                with gr.Row():
                    msg = gr.Textbox(placeholder="Ask about your metrology report...", label="Query")
                    submit_btn = gr.Button("Send")
                    clear_btn = gr.ClearButton([msg, chatbot], value="Clear")
                with gr.Accordion("Document References", open=False):
                    with gr.Row():
                        doc_source1 = gr.Textbox(label="Reference 1", lines=2)
                        source1_page = gr.Number(label="Page")
                        doc_source2 = gr.Textbox(label="Reference 2", lines=2)
                        source2_page = gr.Number(label="Page")
                        doc_source3 = gr.Textbox(label="Reference 3", lines=2)
                        source3_page = gr.Number(label="Page")

        # Event Handlers
        language_btn.change(lambda x: "en" if x == "English" else "pt", inputs=language_btn, outputs=language)

        def enable_qachain_btn(retriever, status):
            # The status string is the simplest success signal available here.
            return gr.update(interactive=retriever is not None and "successfully" in status)

        db_btn.click(
            initialize_database,
            inputs=[document],
            outputs=[retriever, db_progress]
        ).then(
            enable_qachain_btn,
            inputs=[retriever, db_progress],
            outputs=[qachain_btn]
        )
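        # .then() above fires only after initialize_database completes, so the
        # button unlocks once the retriever State and status text are updated.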

        qachain_btn.click(
            initialize_LLM,
            inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, retriever],
            outputs=[qa_chain, llm_progress]
        )

        submit_btn.click(
            conversation,
            inputs=[qa_chain, msg, chatbot, language],
            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page]
        )
        msg.submit(
            conversation,
            inputs=[qa_chain, msg, chatbot, language],
            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page]
        )

    demo.launch(debug=True)

if __name__ == "__main__":
    demo()