import pandas as pd
import nltk
nltk.download('punkt')
from nltk.tokenize import sent_tokenize
import chromadb
from chromadb.utils import embedding_functions
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

import gradio as gr

import re 

#######################################################

# Load the email dataset
EMAILS_CSV_PATH = "./cleaned_data.csv"
CHROMA_PATH = "./content"
COLLECTION_NAME = "enron_emails"
MODEL_NAME = "varl42/modello42"
# Only the first MAX_EMAILS rows of the CSV are indexed.
MAX_EMAILS = 10000
# Chroma enforces a maximum number of records per add() call (the exact limit
# depends on the installation), so documents are inserted in modest chunks.
ADD_BATCH_SIZE = 1000

emails = pd.read_csv(EMAILS_CSV_PATH)

######################################################
# Use ONE persistent client for the whole app. The store lives on disk under
# CHROMA_PATH, so emails embedded by a previous run are reused instead of
# being re-embedded on every start-up.
client = chromadb.PersistentClient(path=CHROMA_PATH)
# get_or_create_collection does not fail when the collection already exists.
collection = client.get_or_create_collection(COLLECTION_NAME)

# Rows without a body or id cannot be indexed, and ids must be unique.
_subset = (
    emails.head(MAX_EMAILS)
    .dropna(subset=["body", "file"])
    .drop_duplicates(subset="file")
)
_candidate_ids = _subset["file"].astype(str).tolist()
_candidate_docs = _subset["body"].astype(str).tolist()

# Only insert what is not stored yet; include=[] asks Chroma for ids only.
_stored_ids = set(collection.get(include=[])["ids"])
_pending = [
    (doc_id, doc)
    for doc_id, doc in zip(_candidate_ids, _candidate_docs)
    if doc_id not in _stored_ids
]

# Add documents and IDs to the collection, using ChromaDB's built-in text encoding
for _start in range(0, len(_pending), ADD_BATCH_SIZE):
    _batch = _pending[_start:_start + ADD_BATCH_SIZE]
    collection.add(
        documents=[doc for _, doc in _batch],
        ids=[doc_id for doc_id, _ in _batch],
        metadatas=[{"source": "enron_emails"} for _ in _batch],  # Optional metadata
    )


####################################################
# Load the trained summarization model and its matching tokenizer
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

##################################################

def query_collection(query_text, n_results=2):
    """Return the stored documents most similar to ``query_text``.

    Args:
        query_text: Free-text query typed by the user.
        n_results: How many documents to request from the collection
            (defaults to 2).

    Returns:
        The matching documents joined by blank lines, or a human-readable
        message when the query is blank, nothing matches, or the lookup
        fails. This function never raises; it feeds a UI text box directly.
    """
    # A blank query has nothing meaningful to search for, so skip the lookup.
    if not isinstance(query_text, str) or not query_text.strip():
        return "Please enter a query."

    try:
        response = collection.query(query_texts=[query_text], n_results=n_results)

        # The response holds one list of documents per query text; exactly one
        # query text was sent, so only the first list is relevant.
        result_sets = response.get("documents") or []
        documents = [doc for doc in (result_sets[0] if result_sets else []) if doc]
        if not documents:
            # Covers a missing key, an empty outer list and an empty match list.
            return "No documents found or the response structure is not as expected."
        return "\n\n".join(documents)
    except Exception as e:
        # UI boundary: report the failure as text instead of crashing the app.
        return f"An error occurred while querying: {e}"


def _ensure_sentence_punctuation(summary):
    """Append a period where a sentence appears to end without punctuation.

    Heuristic: a period is inserted after a character that is followed by
    whitespace and a capital letter, or that ends the text. Whitespace and
    existing punctuation are never followed by an extra period. Because any
    capitalised word counts as a sentence start, proper nouns in the middle
    of a sentence also get a period before them.
    """
    return re.sub(r"([^\s.?!,;:])(?=\s+[A-Z]|$)", r"\1.", summary.strip())


def summarize_documents(text_input):
    """Summarize ``text_input`` with the loaded seq2seq model.

    Args:
        text_input: Text to condense. The tokenizer truncates it to 512
            tokens.

    Returns:
        The generated summary with sentence-ending punctuation restored, or
        a human-readable message when the input is blank or generation
        fails. This function never raises; it feeds a UI text box directly.
    """
    # With min_length=125 the model would invent text for an empty input.
    if not isinstance(text_input, str) or not text_input.strip():
        return "Nothing to summarize."

    try:
        # Tokenize input text for the model
        inputs = tokenizer(text_input, return_tensors="pt", truncation=True, max_length=512)
        # Generate a summary with beam search; the attention mask is passed
        # along so generate() does not have to infer it from the input ids.
        summary_ids = model.generate(
            inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
            max_length=512,
            min_length=125,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True,
        )
        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        return _ensure_sentence_punctuation(summary)
    except Exception as e:
        # UI boundary: report the failure as text instead of crashing the app.
        return f"An error occurred while summarizing: {e}"

def query_then_summarize(query_text, _):
    """Handle a click on the Query button.

    Looks up ``query_text`` via ``query_collection`` and returns a
    ``(results_text, summary_text)`` pair. The summary slot is always the
    empty string, so any previous summary is cleared. The second parameter
    is accepted but ignored.
    """
    try:
        retrieved = query_collection(query_text)
    except Exception as e:
        # Surface the failure in the results box; the summary stays blank.
        retrieved = f"An error occurred: {e}"
    return retrieved, ""

def summarize_from_query(_, query_results):
    """Handle a click on the Summarize button.

    Passes ``query_results`` to ``summarize_documents`` and returns a
    ``(query_results, summary_text)`` pair; ``query_results`` is handed back
    unchanged. If summarization raises, the summary slot carries the error
    message instead. The first parameter is accepted but ignored.
    """
    try:
        condensed = summarize_documents(query_results)
    except Exception as e:
        condensed = f"An error occurred while summarizing: {e}"
    return query_results, condensed


###################################################
        
# Setup the Gradio interface
with gr.Blocks() as app:
    with gr.Row():
        query_input = gr.Textbox(label="Enter your query")
        query_button = gr.Button("Query")
    # interactive=True makes this box editable; whatever text it holds when
    # "Summarize" is clicked is what gets summarized.
    query_results = gr.Text(label="Query Results", placeholder="Query results will appear here...", interactive=True)
    summarize_button = gr.Button("Summarize")
    summary_output = gr.Textbox(label="Summary", placeholder="Summary will appear here...")

    # Step 1: run the retrieval and clear any previous summary.
    query_button.click(
        query_then_summarize,
        inputs=[query_input, query_results],
        outputs=[query_results, summary_output],
    )
    # Step 2: summarize the current contents of the results box. The handler
    # ignores its first argument; the query textbox is passed (rather than
    # the Query button component) so the inputs are real data fields.
    summarize_button.click(
        summarize_from_query,
        inputs=[query_input, query_results],
        outputs=[query_results, summary_output],
    )

app.launch()