from climateqa.engine.embeddings import get_embeddings_function

# Shared embeddings model used by the vectorstores below
embeddings_function = get_embeddings_function()

from climateqa.knowledge.openalex import OpenAlex

# Alternative local reranker, currently unused (see get_reranker below)
# from sentence_transformers import CrossEncoder
# reranker = CrossEncoder("mixedbread-ai/mxbai-rerank-xsmall-v1")
oa = OpenAlex()

import gradio as gr
import pandas as pd
import numpy as np
import os
import time
import re
import json

from gradio import ChatMessage

from io import BytesIO
import base64

from datetime import datetime
from azure.storage.fileshare import ShareServiceClient

from utils import create_user_id

from collections import defaultdict
from gradio_modal import Modal



# ClimateQ&A imports
from climateqa.engine.llm import get_llm
from climateqa.engine.vectorstore import get_pinecone_vectorstore
# from climateqa.knowledge.retriever import ClimateQARetriever
from climateqa.engine.reranker import get_reranker
from climateqa.engine.chains.prompts import audience_prompts
from climateqa.sample_questions import QUESTIONS
from climateqa.constants import POSSIBLE_REPORTS, OWID_CATEGORIES
from climateqa.utils import get_image_from_azure_blob_storage
from climateqa.engine.keywords import make_keywords_chain
# from climateqa.engine.chains.answer_rag import make_rag_papers_chain
from climateqa.engine.graph import make_graph_agent, display_graph

from front.utils import (
    make_html_source,
    make_html_figure_sources,
    parse_output_llm_with_sources,
    serialize_docs,
    make_toolbox,
    generate_html_graphs,
)

# Load environment variables in local mode
try:
    from dotenv import load_dotenv
    load_dotenv()
except Exception:
    pass


# Set up Gradio Theme
theme = gr.themes.Base(
    primary_hue="blue",
    secondary_hue="red",
    font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
)



# Placeholder system prompt; note that system_template captures this empty
# value, and init_prompt is reassigned with the welcome message further below
init_prompt = ""

system_template = {
    "role": "system",
    "content": init_prompt,
}

# Azure File Share credentials; restore the base64 padding if the trailing
# "==" was stripped from the account key environment variable
account_key = os.environ["BLOB_ACCOUNT_KEY"]
if len(account_key) == 86:
    account_key += "=="

credential = {
    "account_key": account_key,
    "account_name": os.environ["BLOB_ACCOUNT_NAME"],
}

account_url = os.environ["BLOB_ACCOUNT_URL"]
file_share_name = "climateqa"
service = ShareServiceClient(account_url=account_url, credential=credential)
share_client = service.get_share_client(file_share_name)

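# Anonymous id used to group this session's entries in the logs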
user_id = create_user_id()



# Create the vectorstores: one Pinecone index for the report chunks and one
# for the OWID graphs (where the indexed text is the graph title)
vectorstore = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX"))
vectorstore_graphs = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_OWID"), text_key="title")

llm = get_llm(provider="openai", max_tokens=1024, temperature=0.0)
reranker = get_reranker("nano")
agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, reranker=reranker)

async def chat(query, history, audience, sources, reports, current_graphs):
    """Run the agent pipeline (intent categorization, query transformation,
    retrieval, answering) on a query and message history, yielding tuples of
    (chat history, sources HTML, query, language, gallery images, figures HTML,
    graphs HTML) as the answer streams in."""

    date_now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f">> NEW QUESTION ({date_now}) : {query}")

    audience_prompt = {
        "Children": audience_prompts["children"],
        "General public": audience_prompts["general"],
        "Experts": audience_prompts["experts"],
    }.get(audience, audience_prompts["experts"])

    # Prepare default values
    if sources is None or len(sources) == 0:
        sources = ["IPCC", "IPBES", "IPOS"]

    if reports is None:
        reports = []
    
    inputs = {"user_input": query, "audience": audience_prompt, "sources": sources}
    # Stream structured events (node starts/ends, token chunks) from the agent graph
    result = agent.astream_events(inputs, version="v1")


    docs = []
    docs_html = ""
    output_query = ""
    output_language = ""
    gallery = []
    start_streaming = False
    graphs_html = ""
    figures = '<div class="figures-container"><p></p> </div>'

    # Agent nodes whose start should be surfaced as a status message in the
    # chat, mapped to (display label, whether to show the node output)
    steps_display = {
        "categorize_intent": ("🔄️ Analyzing user message", True),
        "transform_query": ("🔄️ Thinking step by step to answer the question", True),
        "retrieve_documents": ("🔄️ Searching in the knowledge base", False),
    }
    
    used_documents = []
    answer_message_content = ""
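
    # Walk the event stream: surface step updates in the chat, stream the
    # answer tokens, and collect sources, figures and graphs as nodes finish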
    try:
        async for event in result:
            if "langgraph_node" in event["metadata"]:
                node = event["metadata"]["langgraph_node"]

                if event["event"] == "on_chain_end" and event["name"] == "retrieve_documents" :# when documents are retrieved
                    try:
                        docs = event["data"]["output"]["documents"]
                        docs_html = []
                        textual_docs = [d for d in docs if d.metadata["chunk_type"] == "text"]
                        for i, d in enumerate(textual_docs, 1):
                            if d.metadata["chunk_type"] == "text":
                                docs_html.append(make_html_source(d, i))
                        
                        used_documents = used_documents + [f"{d.metadata['short_name']} - {d.metadata['name']}" for d in docs]
                        history[-1].content = "Adding sources :\n\n - " + "\n - ".join(np.unique(used_documents))
                            
                        docs_html = "".join(docs_html)
                        
                    except Exception as e:
                        print(f"Error getting documents: {e}")
                        print(event)
 
                elif event["name"] in steps_display.keys() and event["event"] == "on_chain_start": #display steps
                    event_description,display_output = steps_display[node]
                    if not hasattr(history[-1], 'metadata') or history[-1].metadata["title"] != event_description: # if a new step begins
                        history.append(ChatMessage(role="assistant", content = "", metadata={'title' :event_description}))
 
                elif event["name"] != "transform_query" and event["event"] == "on_chat_model_stream" and node in ["answer_rag", "answer_search","answer_chitchat"]:# if streaming answer
                    if start_streaming == False:
                        start_streaming = True
                        history.append(ChatMessage(role="assistant", content = ""))
                    answer_message_content +=  event["data"]["chunk"].content
                    answer_message_content = parse_output_llm_with_sources(answer_message_content)
                    history[-1] = ChatMessage(role="assistant", content = answer_message_content)
                    # history.append(ChatMessage(role="assistant", content = new_message_content))

                elif event["name"] in ["retrieve_graphs", "retrieve_graphs_ai"] and event["event"] == "on_chain_end":
                    try:
                        recommended_content = event["data"]["output"]["recommended_content"]
                        
                        # Deduplicate the recommended graphs; "embedding" holds the
                        # embeddable HTML returned for each graph
                        unique_graphs = []
                        seen_embeddings = set()

                        for x in recommended_content:
                            embedding = x.metadata["returned_content"]

                            # Only keep graphs whose content has not been seen yet
                            if embedding not in seen_embeddings:
                                unique_graphs.append({
                                    "embedding": embedding,
                                    "metadata": {
                                        "source": x.metadata["source"],
                                        "category": x.metadata["category"]
                                    }
                                })
                                seen_embeddings.add(embedding)


                        # Group the deduplicated graphs by category for display
                        categories = defaultdict(list)
                        for graph in unique_graphs:
                            categories[graph["metadata"]["category"]].append(graph["embedding"])

                        for category, embeddings in categories.items():
                            graphs_html += f"<h3>{category}</h3>"
                            for embedding in embeddings:
                                graphs_html += f"<div>{embedding}</div>"
                                
                                
                    except Exception as e:
                        print(f"Error getting graphs: {e}")



                if event["name"] == "transform_query" and event["event"] =="on_chain_end":
                    if hasattr(history[-1],"content"):
                        history[-1].content += "Decompose question into sub-questions: \n\n - " + "\n - ".join([q["question"] for q in event["data"]["output"]["remaining_questions"]])
                        
                if event["name"] == "categorize_intent" and event["event"] == "on_chain_start":
                    print("X")
            
            yield history, docs_html, output_query, output_language, gallery, figures, graphs_html
 
    except Exception as e:
        print(f"Chat streaming failed: {e}")
        raise gr.Error(f"{e}")


    try:
        # Log answer on Azure Blob Storage
        if os.getenv("GRADIO_ENV") != "local":
            timestamp = str(datetime.now().timestamp())
            file = timestamp + ".json"
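            # History entries that round-tripped through the frontend are plain
            # dicts, unlike the ChatMessage objects appended during streaming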
            prompt = history[1]["content"]
            logs = {
                "user_id": str(user_id),
                "prompt": prompt,
                "query": prompt,
                "question":output_query,
                "sources":sources,
                "docs":serialize_docs(docs),
                "answer": history[-1].content,
                "time": timestamp,
            }
            log_on_azure(file, logs, share_client)
    except Exception as e:
        print(f"Error logging on Azure Blob Storage: {e}")
        raise gr.Error(f"ClimateQ&A Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)")

        
    # Fetch the figure images referenced by the retrieved chunks from Azure
    # blob storage and embed them in the figures tab and the gallery
    docs_figures = [d for d in docs if d.metadata["chunk_type"] == "image"]
    for i, doc in enumerate(docs_figures):
        try:
            image_path = doc.metadata["image_path"].split("documents/")[1]
            img = get_image_from_azure_blob_storage(image_path)

            # Convert the image to a base64 string for inline HTML display
            buffered = BytesIO()
            img.save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode()

            figures = figures + make_html_figure_sources(doc, i, img_str)

            gallery.append(img)

        except Exception as e:
            print(f"Skipped adding image {i} because of {e}")

    yield history, docs_html, output_query, output_language, gallery, figures, graphs_html


def save_feedback(feed: str, user_id):
    """Log user feedback on the Azure file share (empty feedback is ignored)."""
    if len(feed) > 1:
        timestamp = str(datetime.now().timestamp())
        file = user_id + timestamp + ".json"
        logs = {
            "user_id": user_id,
            "feedback": feed,
            "time": timestamp,
        }
        log_on_azure(file, logs, share_client)
        return "Feedback submitted, thank you!"




def log_on_azure(file, logs, share_client):
    """Serialize logs to JSON and upload them under the given file name on the Azure file share."""
    logs = json.dumps(logs)
    file_client = share_client.get_file_client(file)
    file_client.upload_file(logs)


def generate_keywords(query):
    """Generate an AND-joined keyword string from a user question, for paper search."""
    chain = make_keywords_chain(llm)
    keywords = chain.invoke(query)
    keywords = " AND ".join(keywords["keywords"])
    return keywords
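
# Hypothetical example: generate_keywords("How does climate change affect biodiversity?")
# could return something like 'climate change AND biodiversity'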



# Column layout for the papers table (used by the commented-out "Papers (beta)" tab below)
papers_cols_widths = {
    "doc":50,
    "id":100,
    "title":300,
    "doi":100,
    "publication_year":100,
    "abstract":500,
    "rerank_score":100,
    "is_oa":50,
}

papers_cols = list(papers_cols_widths.keys())
papers_cols_widths = list(papers_cols_widths.values())


# --------------------------------------------------------------------
# Gradio
# --------------------------------------------------------------------


init_prompt = """
Hello, I am ClimateQ&A, a conversational assistant designed to help you understand climate change and biodiversity loss. I will answer your questions by **sifting through the IPCC and IPBES scientific reports**.

❓ How to use
- **Language**: You can ask me your questions in any language. 
- **Audience**: You can specify your audience (children, general public, experts) to get a more adapted answer.
- **Sources**: You can choose to search in the IPCC or IPBES reports, or both.

⚠️ Limitations
*Please note that the AI is not perfect and may sometimes give irrelevant answers. If you are not satisfied with the answer, please ask a more specific question or report your feedback to help us improve the system.*

🛈 Information
Please note that we log your questions for meta-analysis purposes, so avoid sharing any sensitive or personal information.


What do you want to learn?
"""


def vote(data: gr.LikeData):
    """Log like/dislike feedback from the chatbot."""
    if data.liked:
        print(data.value)
    else:
        print(data)

def save_graph(saved_graphs_state, embedding, category):
    """Add a graph to the saved-graphs state, keyed by category."""
    print(f"\nSaved graphs:\n{saved_graphs_state}\n")
    if category not in saved_graphs_state:
        saved_graphs_state[category] = []
    if embedding not in saved_graphs_state[category]:
        saved_graphs_state[category].append(embedding)
    return saved_graphs_state, gr.Button("Graph Saved")



with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd() + "/style.css", theme=theme, elem_id="main-component") as demo:
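    # Cross-callback state: chat completion flag, graphs for the
    # "Recommended content" tab, and the user's saved graphs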
    chat_completed_state = gr.State(0)
    current_graphs = gr.State([])
    saved_graphs = gr.State({})
    
    with gr.Tab("ClimateQ&A"):

        with gr.Row(elem_id="chatbot-row"):
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    value = [ChatMessage(role="assistant", content=init_prompt)],
                    type = "messages",
                    show_copy_button=True,
                    show_label = False,
                    elem_id="chatbot",
                    layout = "panel",
                    avatar_images = (None,"https://i.ibb.co/YNyd5W2/logo4.png"),
                    max_height="80vh",
                    height="100vh"
                )
                
                # chatbot.like(vote, None, None)



                with gr.Row(elem_id = "input-message"):
                    textbox=gr.Textbox(placeholder="Ask me anything here!",show_label=False,scale=7,lines = 1,interactive = True,elem_id="input-textbox")
                 

            with gr.Column(scale=2, variant="panel",elem_id = "right-panel"):


                with gr.Tabs() as tabs:
                    with gr.TabItem("Examples",elem_id = "tab-examples",id = 0):
                                        
                        examples_hidden = gr.Textbox(visible = False)
                        first_key = list(QUESTIONS.keys())[0]
                        dropdown_samples = gr.Dropdown(QUESTIONS.keys(),value = first_key,interactive = True,show_label = True,label = "Select a category of sample questions",elem_id = "dropdown-samples")

                        samples = []
                        for i,key in enumerate(QUESTIONS.keys()):

                            examples_visible = (i == 0)

                            with gr.Row(visible = examples_visible) as group_examples:

                                examples_questions = gr.Examples(
                                    QUESTIONS[key],
                                    [examples_hidden],
                                    examples_per_page=8,
                                    run_on_click=False,
                                    elem_id=f"examples{i}",
                                    api_name=f"examples{i}",
                                    # label = "Click on the example question or enter your own",
                                    # cache_examples=True,
                                )
                            
                            samples.append(group_examples)


                    with gr.Tab("Sources",elem_id = "tab-citations",id = 1) as tab_sources:
                        sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
                        docs_textbox = gr.State("")
                        

                        
                    # with Modal(visible = False) as config_modal:
                    with gr.Tab("Configuration",elem_id = "tab-config",id = 2) as tab_config:

                        gr.Markdown("Reminders: You can talk in any language, ClimateQ&A is multi-lingual!")


                        dropdown_sources = gr.CheckboxGroup(
                            ["IPCC", "IPBES","IPOS"],
                            label="Select source",
                            value=["IPCC"],
                            interactive=True,
                        )

                        dropdown_reports = gr.Dropdown(
                            POSSIBLE_REPORTS,
                            label="Or select specific reports",
                            multiselect=True,
                            value=None,
                            interactive=True,
                        )

                        dropdown_audience = gr.Dropdown(
                            ["Children","General public","Experts"],
                            label="Select audience",
                            value="Experts",
                            interactive=True,
                        )

                        output_query = gr.Textbox(label="Query used for retrieval",show_label = True,elem_id = "reformulated-query",lines = 2,interactive = False)
                        output_language = gr.Textbox(label="Language",show_label = True,elem_id = "language",lines = 1,interactive = False)
                    
                    with gr.Tab("Figures",elem_id = "tab-figures",id = 3) as tab_figures:
                        with Modal(visible=False, elem_id="modal_figure_galery") as modal:
                            gallery_component = gr.Gallery(object_fit='scale-down',elem_id="gallery-component", height="80vh")
                            
                        show_full_size_figures = gr.Button("Show figures in full size",elem_id="show-figures",interactive=True)    
                        show_full_size_figures.click(lambda : Modal(visible=True),None,modal)

                        figures_cards = gr.HTML(show_label=False, elem_id="sources-figures")
                        
                        
                    with gr.Tab("Recommended content", elem_id="tab-recommended_content", id=4) as tab_recommended_content:
                        graphs_container = gr.HTML("<h2>There are no graphs to be displayed at the moment. Try asking another question.</h2>")
                        current_graphs.change(lambda x: x, inputs=[current_graphs], outputs=[graphs_container])

                        # @gr.render(inputs=[current_graphs])
                        # def display_default_recommended(current_graphs):
                        #     if len(current_graphs)==0:
                        #         placeholder_message = gr.HTML("<h2>There are no graphs to be displayed at the moment. Try asking another question.</h2>")

                        # @gr.render(inputs=[current_graphs],triggers=[chat_completed_state.change])
                        # def render_graphs(current_graph_list):
                        #     global saved_graphs
                        #     with gr.Column():
                        #         print(f"\ncurrent_graph_list:\n{current_graph_list}")
                        #         for (embedding, category) in current_graph_list:
                        #             graphs_placeholder = gr.HTML(embedding, elem_id="graphs-placeholder")
                        #             save_btn = gr.Button("Save Graph")
                        #             save_btn.click(
                        #                 save_graph,
                        #                 [saved_graphs, gr.State(embedding), gr.State(category)],
                        #                 [saved_graphs, save_btn]
                        #             )
                                    
                                                            
                        # # Display current_graphs
                        # with gr.Row():
                        #     for embedding in current_graphs:
                        #         with gr.Column():
                        #             gr.HTML(embedding, elem_id="graphs-placeholder")
                        #             save_btn = gr.Button("Save Graph")
                        #             save_btn.click(
                        #                 save_graph,
                        #                 [saved_graphs, gr.State(embedding)],
                        #                 [saved_graphs, save_btn]
                        #             )

                            






#---------------------------------------------------------------------------------------
# OTHER TABS
#---------------------------------------------------------------------------------------

    # with gr.Tab("Recommended content", elem_id="tab-recommended_content2") as recommended_content_tab2:
        
    #     @gr.render(inputs=[current_graphs])
    #     def display_default_recommended_head(current_graphs_list):
    #         if len(current_graphs_list)==0:
    #             gr.HTML("<h2>There are no graphs to be displayed at the moment. Try asking another question.</h2>")

    #     @gr.render(inputs=[current_graphs],triggers=[chat_completed_state.change])
    #     def render_graphs_head(current_graph_list):
    #         global saved_graphs

    #         category_dict = defaultdict(list)
    #         for (embedding, category) in current_graph_list:
    #             category_dict[category].append(embedding)
            
    #         for category in category_dict:
    #             with gr.Tab(category):
    #                 splits = [category_dict[category][i:i+3] for i in range(0, len(category_dict[category]), 3)]
    #                 for row in splits:
    #                     with gr.Row():
    #                         for embedding in row:
    #                             with gr.Column():
    #                                 gr.HTML(embedding, elem_id="graphs-placeholder")
    #                                 save_btn = gr.Button("Save Graph")
    #                                 save_btn.click(
    #                                     save_graph,
    #                                     [saved_graphs, gr.State(embedding), gr.State(category)],
    #                                     [saved_graphs, save_btn]
    #                                 )



    # with gr.Tab("Saved Graphs", elem_id="tab-saved-graphs") as saved_graphs_tab:
        
    #     @gr.render(inputs=[saved_graphs])
    #     def display_default_save(saved):
    #         if len(saved)==0:
    #             gr.HTML("<h2>You have not saved any graphs yet</h2>")

    #     @gr.render(inputs=[saved_graphs], triggers=[saved_graphs.change])
    #     def view_saved_graphs(graphs_list):
    #         categories = [category for category in graphs_list] # graphs_list.keys()
    #         for category in categories:
    #             with gr.Tab(category):
    #                 splits = [graphs_list[category][i:i+3] for i in range(0, len(graphs_list[category]), 3)]
    #                 for row in splits:
    #                     with gr.Row():
    #                         for graph in row:
    #                             gr.HTML(graph, elem_id="graphs-placeholder")



    # with gr.Tab("Figures",elem_id = "tab-images",elem_classes = "max-height other-tabs"):
    #     gallery_component = gr.Gallery(object_fit='cover')

    # with gr.Tab("Papers (beta)",elem_id = "tab-papers",elem_classes = "max-height other-tabs"):

    #     with gr.Row():
    #         with gr.Column(scale=1):
    #             query_papers = gr.Textbox(placeholder="Question",show_label=False,lines = 1,interactive = True,elem_id="query-papers")
    #             keywords_papers = gr.Textbox(placeholder="Keywords",show_label=False,lines = 1,interactive = True,elem_id="keywords-papers")
    #             after = gr.Slider(minimum=1950,maximum=2023,step=1,value=1960,label="Publication date",show_label=True,interactive=True,elem_id="date-papers")
    #             search_papers = gr.Button("Search",elem_id="search-papers",interactive=True)

    #         with gr.Column(scale=7):

    #             with gr.Tab("Summary",elem_id="papers-summary-tab"):
    #                 papers_summary = gr.Markdown(visible=True,elem_id="papers-summary")

    #             with gr.Tab("Relevant papers",elem_id="papers-results-tab"):
    #                 papers_dataframe = gr.Dataframe(visible=True,elem_id="papers-table",headers = papers_cols)

    #             with gr.Tab("Citations network",elem_id="papers-network-tab"):
    #                 citations_network = gr.HTML(visible=True,elem_id="papers-citations-network")

    # with gr.Tab("Saved Graphs", elem_id="tab-saved-graphs", id=4) as saved_graphs_tab:
    #     @gr.render(inputs=[saved_graphs], triggers=[saved_graphs.change])
    #     def view_saved_graphs(graphs_list):
    #         for graph in graphs_list:
    #             gr.HTML(graph, elem_id="graphs-placeholder")
            
    with gr.Tab("About",elem_classes = "max-height other-tabs"):
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("See more info at [https://climateqa.com](https://climateqa.com/docs/intro/)")


    def start_chat(query, history):
        history = history + [ChatMessage(role="user", content=query)]
        return (gr.update(interactive=False), gr.update(selected=1), history)
    
    def finish_chat():
        return (gr.update(interactive = True,value = ""),gr.update(selected=3))


    def change_completion_status(current_state):
        """Toggle the chat-completed flag (0 <-> 1)."""
        current_state = 1 - current_state
        return current_state
    
    def update_sources_number_display(sources_textbox, figures_cards, current_graphs):
        """Update the tab labels with the number of sources, figures and graphs retrieved."""
        sources_number = sources_textbox.count("<h2>")
        figures_number = figures_cards.count("<h2>")
        graphs_number = current_graphs.count("<iframe")
        sources_notif_label = f"Sources ({sources_number})"
        figures_notif_label = f"Figures ({figures_number})"
        graphs_notif_label = f"Recommended content ({graphs_number})"
        # sources_notif_label = f"Sources (🆕)"
        # figures_notif_label = f"Figures (🆕)"
        # graphs_notif_label = f"Recommended content (🆕)"
         
        return gr.update(label = sources_notif_label), gr.update(label = figures_notif_label), gr.update(label = graphs_notif_label)
    
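    # Wire both entry points (textbox submit and example click) to the same
    # pipeline: lock the input, stream the chat, unlock, refresh tab labels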
    (textbox
        .submit(start_chat, [textbox, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_textbox")
        .then(chat, [textbox, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, current_graphs], [chatbot, sources_textbox, output_query, output_language, gallery_component, figures_cards, current_graphs], concurrency_limit=8, api_name="chat_textbox")
        .then(finish_chat, None, [textbox, tabs], api_name="finish_chat_textbox")
        .then(update_sources_number_display, [sources_textbox, figures_cards, current_graphs], [tab_sources, tab_figures, tab_recommended_content])
    )

    (examples_hidden
        .change(start_chat, [examples_hidden, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_examples")
        .then(chat, [examples_hidden, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, current_graphs], [chatbot, sources_textbox, output_query, output_language, gallery_component, figures_cards, current_graphs], concurrency_limit=8, api_name="chat_examples")
        .then(finish_chat, None, [textbox, tabs], api_name="finish_chat_examples")
        .then(update_sources_number_display, [sources_textbox, figures_cards, current_graphs], [tab_sources, tab_figures, tab_recommended_content])
    )


    def change_sample_questions(key):
        """Show only the group of example questions matching the selected category."""
        index = list(QUESTIONS.keys()).index(key)
        visible_bools = [False] * len(samples)
        visible_bools[index] = True
        return [gr.update(visible=visible_bools[i]) for i in range(len(samples))]



    dropdown_samples.change(change_sample_questions, dropdown_samples, samples)


    demo.queue()

demo.launch(ssr_mode=False)