# Build the embeddings function once at import time; it is handed to
# `get_pinecone_vectorstore` further down in this file.
from climateqa.engine.embeddings import get_embeddings_function
embeddings_function = get_embeddings_function()

from climateqa.papers.openalex import OpenAlex
from sentence_transformers import CrossEncoder

# Cross-encoder model passed to `oa.rerank` in `find_papers`.
reranker = CrossEncoder("mixedbread-ai/mxbai-rerank-xsmall-v1")
# OpenAlex helper used by `find_papers` (search, rerank, citation network).
oa = OpenAlex()

import gradio as gr
import pandas as pd
import numpy as np
import os
import time
import re
import json

# from gradio_modal import Modal

from io import BytesIO
import base64

from datetime import datetime
from azure.storage.fileshare import ShareServiceClient

from utils import create_user_id



# ClimateQ&A imports
from climateqa.engine.llm import get_llm
from climateqa.engine.chains.answer_rag import make_rag_chain
from climateqa.engine.vectorstore import get_pinecone_vectorstore
from climateqa.engine.retriever import ClimateQARetriever
from climateqa.engine.embeddings import get_embeddings_function
from climateqa.engine.chains.prompts import audience_prompts
from climateqa.sample_questions import QUESTIONS
from climateqa.constants import POSSIBLE_REPORTS
from climateqa.utils import get_image_from_azure_blob_storage
from climateqa.engine.keywords import make_keywords_chain
from climateqa.engine.chains.answer_rag import make_rag_papers_chain

# Load environment variables in local mode.
# python-dotenv is treated as optional: any failure while importing or
# loading it is deliberately ignored.
try:
    from dotenv import load_dotenv
    load_dotenv()
except Exception as e:
    pass

# Set up Gradio Theme
theme = gr.themes.Base(
    primary_hue="blue",
    secondary_hue="red",
    font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
)



# Empty system prompt. NOTE: `init_prompt` is reassigned later in this file
# with the chatbot welcome message; `system_template` keeps the empty value.
init_prompt = ""

system_template = {
    "role": "system",
    "content": init_prompt,
}

# Azure file-share credentials come from the environment; a missing variable
# raises KeyError at import time.
account_key = os.environ["BLOB_ACCOUNT_KEY"]
# An 86-character key gets "==" appended -- presumably restoring stripped
# base64 padding (TODO confirm).
if len(account_key) == 86:
    account_key += "=="

credential = {
    "account_key": account_key,
    "account_name": os.environ["BLOB_ACCOUNT_NAME"],
}

account_url = os.environ["BLOB_ACCOUNT_URL"]
file_share_name = "climateqa"
# Client for the "climateqa" share; `log_on_azure` uploads log files through it.
service = ShareServiceClient(account_url=account_url, credential=credential)
share_client = service.get_share_client(file_share_name)

# Created once at import time and written as "user_id" in the logs that
# `chat` uploads.
user_id = create_user_id()



def parse_output_llm_with_sources(output):
    """Turn ``[Doc X]`` citations in an LLM answer into HTML anchor links.

    Every bracketed citation group such as ``[Doc 1]`` or ``[Doc 2, Doc 3]``
    is replaced by one ``<a href="#docN">`` superscript link per referenced
    document. All other text is returned untouched.

    Only text matched by the citation pattern is rewritten, so ordinary prose
    that merely starts with "Doc" (for example "Doctors agree ...") is left
    as-is.

    Args:
        output: Raw answer text produced by the LLM.

    Returns:
        The answer with each citation group replaced by its HTML links.
    """

    def _render_citation(match):
        # The captured group holds the bracket content, e.g. "Doc 2, Doc 3".
        references = match.group(1).split(",")
        numbers = [ref.lower().replace("doc", "").strip() for ref in references]
        return "".join(
            f"""<a href="#doc{number}" class="a-doc-ref" target="_self"><span class='doc-ref'><sup>{number}</sup></span></a>"""
            for number in numbers
        )

    return re.sub(r'\[(Doc\s?\d+(?:,\s?Doc\s?\d+)*)\]', _render_citation, output)


# Create vectorstore and retriever
# (only the vector store and the LLM are built here; `chat` creates a
# ClimateQARetriever on top of `vectorstore` for every request).
vectorstore = get_pinecone_vectorstore(embeddings_function)
# Shared LLM handed to the RAG chain, the keywords chain and the papers chain.
llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)


def make_pairs(lst):
    """Group consecutive items of an even-length sequence into 2-tuples.

    ``[a, b, c, d]`` becomes ``[(a, b), (c, d)]``. An odd-length input raises
    ``IndexError`` when the last, unpaired item is reached.
    """
    pairs = []
    for start in range(0, len(lst), 2):
        pairs.append((lst[start], lst[start + 1]))
    return pairs


def serialize_docs(docs):
    """Convert document objects into plain dictionaries.

    Each item of `docs` must expose `page_content` and `metadata` attributes;
    the result is a list of ``{"page_content": ..., "metadata": ...}`` dicts
    in the same order.
    """
    return [
        {"page_content": doc.page_content, "metadata": doc.metadata}
        for doc in docs
    ]



async def chat(query,history,audience,sources,reports):
    """Answer `query` with the RAG chain, streaming UI updates as they arrive.

    Async generator. Every yield is a 7-tuple in this order:
    ``(history, docs_html, output_query, output_language, gallery,
    output_query, output_keywords)``. `output_query` appears twice because
    the event wiring sends it to two different components.

    Args:
        query: The user question.
        history: Chat history as (user, assistant) pairs. The last pair is
            overwritten in place with the answer while it is streamed.
        audience: "Children", "General public" or "Experts"; any other value
            uses the experts prompt.
        sources: Selected sources; empty or None falls back to ``["IPCC"]``.
        reports: Selected reports; empty or None becomes ``[]``.

    Raises:
        gr.Error: If streaming the chain output fails, or if logging fails
            when GRADIO_ENV is not "local".
    """

    print(f">> NEW QUESTION : {query}")

    if audience == "Children":
        audience_prompt = audience_prompts["children"]
    elif audience == "General public":
        audience_prompt = audience_prompts["general"]
    else:
        # "Experts" and any unknown value share the expert prompt.
        audience_prompt = audience_prompts["experts"]

    # Prepare default values (also tolerates None instead of an empty list)
    if not sources:
        sources = ["IPCC"]

    if not reports:
        reports = []

    retriever = ClimateQARetriever(vectorstore=vectorstore,sources = sources,min_size = 200,reports = reports,k_summary = 3,k_total = 15,threshold=0.5)
    rag_chain = make_rag_chain(retriever,llm)

    inputs = {"query": query,"audience": audience_prompt}
    result = rag_chain.astream_log(inputs)

    # Log paths identifying the chain step that produced each streamed op.
    path_reformulation = "/logs/reformulation/final_output"
    path_keywords = "/logs/keywords/final_output"
    path_retriever = "/logs/find_documents/final_output"
    path_answer = "/logs/answer/streamed_output_str/-"

    docs = []  # stays empty when the retriever step never reports documents
    docs_html = ""
    output_query = ""
    output_language = ""
    output_keywords = ""
    gallery = []

    try:
        async for op in result:

            op = op.ops[0]

            if op['path'] == path_reformulation: # reformulated question
                try:
                    output_language = op['value']["language"] # str
                    output_query = op["value"]["question"]
                except Exception as e:
                    raise gr.Error(f"ClimateQ&A Error: {e} - The error has been noted, try another question and if the error remains, you can contact us :)")

            elif op["path"] == path_keywords:
                # Keywords are best-effort: on failure keep the previous value.
                try:
                    output_keywords = " AND ".join(op['value']["keywords"])
                except Exception as e:
                    print(f"Could not read keywords: {e}")

            elif op['path'] == path_retriever: # documents
                try:
                    retrieved_docs = op['value']['docs'] # List[Document]
                    cards = "".join(make_html_source(d, i) for i, d in enumerate(retrieved_docs, 1))
                except TypeError:
                    print("No documents found")
                    print("op: ",op)
                    continue
                # Commit only after every card rendered, so a failure never
                # leaves `docs` / `docs_html` half-updated.
                docs = retrieved_docs
                docs_html = cards

            elif op['path'] == path_answer: # final answer
                new_token = op['value'] # str
                previous_answer = history[-1][1]
                previous_answer = previous_answer if previous_answer is not None else ""
                answer_yet = previous_answer + new_token
                answer_yet = parse_output_llm_with_sources(answer_yet)
                history[-1] = (query,answer_yet)

            else:
                continue

            history = [tuple(x) for x in history]
            yield history,docs_html,output_query,output_language,gallery,output_query,output_keywords

    except Exception as e:
        raise gr.Error(f"{e}")


    try:
        # Log answer on Azure Blob Storage
        if os.getenv("GRADIO_ENV") != "local":
            timestamp = str(datetime.now().timestamp())
            file = timestamp + ".json"
            prompt = history[-1][0]
            logs = {
                "user_id": str(user_id),
                "prompt": prompt,
                "query": prompt,
                "question":output_query,
                "sources":sources,
                "docs":serialize_docs(docs),
                "answer": history[-1][1],
                "time": timestamp,
            }
            log_on_azure(file, logs, share_client)
    except Exception as e:
        print(f"Error logging on Azure Blob Storage: {e}")
        raise gr.Error(f"ClimateQ&A Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)")

    # Fetch the image behind every retrieved "image" chunk.
    image_dict = {}
    for i,doc in enumerate(docs):

        if doc.metadata["chunk_type"] == "image":
            try:
                key = f"Image {i+1}"
                image_path = doc.metadata["image_path"].split("documents/")[1]
                img = get_image_from_azure_blob_storage(image_path)

                # Convert the image to a byte buffer
                buffered = BytesIO()
                img.save(buffered, format="PNG")
                img_str = base64.b64encode(buffered.getvalue()).decode()

                # Embedding the base64 string in Markdown
                markdown_image = f"![Alt text](data:image/png;base64,{img_str})"
                image_dict[key] = {"img":img,"md":markdown_image,"caption":doc.page_content,"key":key,"figure_code":doc.metadata["figure_code"]}
            except Exception as e:
                print(f"Skipped adding image {i} because of {e}")

    if len(image_dict) > 0:

        gallery = [x["img"] for x in image_dict.values()]
        # Only the first image is appended to the chat answer itself.
        first_image = next(iter(image_dict.values()))
        img_md = first_image["md"]
        img_caption = first_image["caption"]
        img_code = first_image["figure_code"]
        if img_code != "N/A":
            img_name = f"{first_image['key']} - {first_image['figure_code']}"
        else:
            img_name = f"{first_image['key']}"

        # The answer is still None when no token was streamed.
        current_answer = history[-1][1] if history[-1][1] is not None else ""
        answer_yet = current_answer + f"\n\n{img_md}\n<p class='chatbot-caption'><b>{img_name}</b> - {img_caption}</p>"
        history[-1] = (history[-1][0],answer_yet)
        history = [tuple(x) for x in history]

    yield history,docs_html,output_query,output_language,gallery,output_query,output_keywords


def make_html_source(source,i):
    """Render one retrieved document as an HTML card.

    `source` must expose `page_content` and a `metadata` mapping. Chunks whose
    `chunk_type` is "text" get a "Doc {i}" card carrying the `doc{i}` anchor
    id; every other chunk type gets an "Image {i}" card flagged as an
    AI-generated description. `i` is the number shown in the card title.
    """
    meta = source.metadata
    body_text = source.page_content.strip()

    # Table-of-contents breadcrumb: stop at the first level marked "N/A".
    breadcrumb = []
    depth = 0
    while depth < 2:
        entry = meta[f"toc_level{depth}"]
        if entry == "N/A":
            break
        breadcrumb.append(entry)
        depth += 1
    breadcrumb = " > ".join(breadcrumb)

    if len(breadcrumb) == 0:
        footer_label = meta['name']
    else:
        footer_label = f"<b>{breadcrumb}</b><br/>{meta['name']}"

    if meta["chunk_type"] != "text":
        # Image chunk: the figure code, when known, goes in front of the title.
        if meta["figure_code"] == "N/A":
            heading = f"{meta['short_name']}"
        else:
            heading = f"{meta['figure_code']} - {meta['short_name']}"

        return f"""
    <div class="card card-image">
        <div class="card-content">
            <h2>Image {i} - {heading} - Page {int(meta['page_number'])}</h2>
            <p>{body_text}</p>
            <p class='ai-generated'>AI-generated description</p>
        </div>
        <div class="card-footer">
            <span>{footer_label}</span>
            <a href="{meta['url']}#page={int(meta['page_number'])}" target="_blank" class="pdf-link">
                <span role="img" aria-label="Open PDF">🔗</span>
            </a>
        </div>
    </div>
    """

    return f"""
    <div class="card" id="doc{i}">
        <div class="card-content">
            <h2>Doc {i} - {meta['short_name']} - Page {int(meta['page_number'])}</h2>
            <p>{body_text}</p>
        </div>
        <div class="card-footer">
            <span>{footer_label}</span>
            <a href="{meta['url']}#page={int(meta['page_number'])}" target="_blank" class="pdf-link">
                <span role="img" aria-label="Open PDF">🔗</span>
            </a>
        </div>
    </div>
    """



#     else:
#         docs_string = "No relevant passages found in the climate science reports (IPCC and IPBES)"
#         complete_response = "**No relevant passages found in the climate science reports (IPCC and IPBES), you may want to ask a more specific question (specifying your question on climate issues).**"
#         messages.append({"role": "assistant", "content": complete_response})
#         gradio_format = make_pairs([a["content"] for a in messages[1:]])
#         yield gradio_format, messages, docs_string


def save_feedback(feed: str, user_id):
    """Upload a user feedback message to the Azure file share.

    Feedback of one character or less is ignored and None is returned.
    Otherwise a JSON record (user id, feedback text, timestamp) is uploaded
    under the name ``<user_id><timestamp>.json`` and a confirmation message
    is returned.
    """
    if len(feed) <= 1:
        return None

    now = str(datetime.now().timestamp())
    target_name = user_id + now + ".json"
    record = {"user_id": user_id, "feedback": feed, "time": now}
    log_on_azure(target_name, record, share_client)
    return "Feedback submitted, thank you!"




def log_on_azure(file, logs, share_client):
    """Serialize `logs` to JSON and upload it as `file` through `share_client`.

    `share_client` must provide ``get_file_client(name)`` returning an object
    with an ``upload_file(data)`` method.
    """
    payload = json.dumps(logs)
    share_client.get_file_client(file).upload_file(payload)


def generate_keywords(query):
    """Run the keywords chain on `query` and join its keywords with " AND "."""
    keywords_chain = make_keywords_chain(llm)
    extracted = keywords_chain.invoke(query)["keywords"]
    return " AND ".join(extracted)



# Columns of the papers table, in display order, each paired with a width.
# `papers_cols` is used to select dataframe columns and as table headers;
# the widths are presumably pixel sizes (not referenced in the visible code).
_papers_cols_spec = [
    ("doc", 50),
    ("id", 100),
    ("title", 300),
    ("doi", 100),
    ("publication_year", 100),
    ("abstract", 500),
    ("rerank_score", 100),
    ("is_oa", 50),
]

papers_cols = [column for column, _ in _papers_cols_spec]
papers_cols_widths = [width for _, width in _papers_cols_spec]

async def find_papers(query, keywords,after):
    """Search OpenAlex for papers and stream an LLM summary of the best ones.

    Async generator yielding ``(df_works, network_html, summary)``:
    the first yield carries the papers table and citation-network HTML with
    an empty summary; then one yield follows per streamed summary token,
    with `summary` growing each time.

    Args:
        query: Question used for reranking and for the summary.
        keywords: Search string passed to `oa.search`.
        after: Passed to `oa.search` as `after` -- presumably the earliest
            publication year (the UI slider is labelled "Publication date").
    """

    summary = ""
    
    # Search, keep only works with an abstract, rerank against the query and
    # sort by decreasing rerank score.
    df_works = oa.search(keywords,after = after)
    df_works = df_works.dropna(subset=["abstract"])
    df_works = oa.rerank(query,df_works,reranker)
    df_works = df_works.sort_values("rerank_score",ascending=False)
    G = oa.make_network(df_works)

    height = "750px"
    network = oa.show_network(G,color_by = "rerank_score",notebook=False,height = height)
    network_html = network.generate_html()

    # The HTML is embedded below inside a single-quoted `srcdoc` attribute,
    # so every single quote is turned into a double quote first.
    network_html = network_html.replace("'", "\"")
    css_to_inject = "<style>#mynetwork { border: none !important; } .card { border: none !important; }</style>"
    network_html = network_html + css_to_inject

    
    network_html = f"""<iframe style="width: 100%; height: {height};margin:0 auto" name="result" allow="midi; geolocation; microphone; camera; 
    display-capture; encrypted-media;" sandbox="allow-modals allow-forms 
    allow-scripts allow-same-origin allow-popups 
    allow-top-navigation-by-user-activation allow-downloads" allowfullscreen="" 
    allowpaymentrequest="" frameborder="0" srcdoc='{network_html}'></iframe>"""


    # Only the 15 best-ranked works are given to the summary chain.
    docs = df_works["content"].head(15).tolist()

    # Add a 1-based "doc" column and keep only the displayed columns.
    df_works = df_works.reset_index(drop = True).reset_index().rename(columns = {"index":"doc"})
    df_works["doc"] = df_works["doc"] + 1
    df_works = df_works[papers_cols]

    # First result: table and network are ready, the summary is still empty.
    yield df_works,network_html,summary

    chain = make_rag_papers_chain(llm)
    result = chain.astream_log({"question": query,"docs": docs,"language":"English"})
    path_answer = "/logs/StrOutputParser/streamed_output/-"

    async for op in result:

        op = op.ops[0]

        if op['path'] == path_answer: # streamed summary token
            new_token = op['value'] # str
            summary += new_token
        else:
            continue
        yield df_works,network_html,summary
    


# --------------------------------------------------------------------
# Gradio
# --------------------------------------------------------------------


# Welcome message shown as the first chatbot message. This rebinds the
# `init_prompt` name that was set to an empty string earlier in the file.
init_prompt = """
Hello, I am ClimateQ&A, a conversational assistant designed to help you understand climate change and biodiversity loss. I will answer your questions by **sifting through the IPCC and IPBES scientific reports**.

❓ How to use
- **Language**: You can ask me your questions in any language. 
- **Audience**: You can specify your audience (children, general public, experts) to get a more adapted answer.
- **Sources**: You can choose to search in the IPCC or IPBES reports, or both.

⚠️ Limitations
*Please note that the AI is not perfect and may sometimes give irrelevant answers. If you are not satisfied with the answer, please ask a more specific question or report your feedback to help us improve the system.*

What do you want to learn ?
"""


def vote(data: gr.LikeData):
    """Print the liked message value, or the whole event when it is not liked."""
    print(data.value if data.liked else data)



# Layout of the Gradio application: a chat tab (chatbot, input box and a
# right-hand panel with examples / sources / configuration) followed by the
# "Figures", "Papers (beta)" and "About" tabs.
with gr.Blocks(title="Climate Q&A", css="style.css", theme=theme,elem_id = "main-component") as demo:
    # user_id_state = gr.State([user_id])

    with gr.Tab("ClimateQ&A"):

        with gr.Row(elem_id="chatbot-row"):
            with gr.Column(scale=2):
                # state = gr.State([system_template])
                # The chatbot starts with the welcome message as a bot-only turn.
                chatbot = gr.Chatbot(
                    value=[(None,init_prompt)],
                    show_copy_button=True,show_label = False,elem_id="chatbot",layout = "panel",
                    avatar_images = (None,"https://i.ibb.co/YNyd5W2/logo4.png"),
                )#,avatar_images = ("assets/logo4.png",None))
                
                # bot.like(vote,None,None)



                with gr.Row(elem_id = "input-message"):
                    textbox=gr.Textbox(placeholder="Ask me anything here!",show_label=False,scale=7,lines = 1,interactive = True,elem_id="input-textbox")
                    # submit = gr.Button("",elem_id = "submit-button",scale = 1,interactive = True,icon = "https://static-00.iconduck.com/assets.00/settings-icon-2048x2046-cw28eevx.png")


            with gr.Column(scale=1, variant="panel",elem_id = "right-panel"):


                with gr.Tabs() as tabs:
                    with gr.TabItem("Examples",elem_id = "tab-examples",id = 0):
                                        
                        # Hidden textbox filled by clicking an example question.
                        examples_hidden = gr.Textbox(visible = False)
                        first_key = list(QUESTIONS.keys())[0]
                        dropdown_samples = gr.Dropdown(QUESTIONS.keys(),value = first_key,interactive = True,show_label = True,label = "Select a category of sample questions",elem_id = "dropdown-samples")

                        # One Row of examples per question category; only the
                        # first category is visible initially.
                        samples = []
                        for i,key in enumerate(QUESTIONS.keys()):

                            examples_visible = True if i == 0 else False

                            with gr.Row(visible = examples_visible) as group_examples:

                                examples_questions = gr.Examples(
                                    QUESTIONS[key],
                                    [examples_hidden],
                                    examples_per_page=8,
                                    run_on_click=False,
                                    elem_id=f"examples{i}",
                                    api_name=f"examples{i}",
                                    # label = "Click on the example question or enter your own",
                                    # cache_examples=True,
                                )
                            
                            samples.append(group_examples)


                    with gr.Tab("Sources",elem_id = "tab-citations",id = 1):
                        sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
                        docs_textbox = gr.State("")

                    # with Modal(visible = False) as config_modal:
                    with gr.Tab("Configuration",elem_id = "tab-config",id = 2):

                        gr.Markdown("Reminder: You can talk in any language, ClimateQ&A is multi-lingual!")


                        dropdown_sources = gr.CheckboxGroup(
                            ["IPCC", "IPBES","IPOS"],
                            label="Select source",
                            value=["IPCC"],
                            interactive=True,
                        )

                        dropdown_reports = gr.Dropdown(
                            POSSIBLE_REPORTS,
                            label="Or select specific reports",
                            multiselect=True,
                            value=None,
                            interactive=True,
                        )

                        dropdown_audience = gr.Dropdown(
                            ["Children","General public","Experts"],
                            label="Select audience",
                            value="Experts",
                            interactive=True,
                        )

                        # Read-only fields listed among the outputs of `chat`.
                        output_query = gr.Textbox(label="Query used for retrieval",show_label = True,elem_id = "reformulated-query",lines = 2,interactive = False)
                        output_language = gr.Textbox(label="Language",show_label = True,elem_id = "language",lines = 1,interactive = False)






#---------------------------------------------------------------------------------------
# OTHER TABS
#---------------------------------------------------------------------------------------


    with gr.Tab("Figures",elem_id = "tab-images",elem_classes = "max-height other-tabs"):
        gallery_component = gr.Gallery()

    with gr.Tab("Papers (beta)",elem_id = "tab-papers",elem_classes = "max-height other-tabs"):

        with gr.Row():
            with gr.Column(scale=1):
                query_papers = gr.Textbox(placeholder="Question",show_label=False,lines = 1,interactive = True,elem_id="query-papers")
                keywords_papers = gr.Textbox(placeholder="Keywords",show_label=False,lines = 1,interactive = True,elem_id="keywords-papers")
                after = gr.Slider(minimum=1950,maximum=2023,step=1,value=1960,label="Publication date",show_label=True,interactive=True,elem_id="date-papers")
                search_papers = gr.Button("Search",elem_id="search-papers",interactive=True)

            with gr.Column(scale=7):

                with gr.Tab("Summary",elem_id="papers-summary-tab"):
                    papers_summary = gr.Markdown(visible=True,elem_id="papers-summary")

                with gr.Tab("Relevant papers",elem_id="papers-results-tab"):
                    papers_dataframe = gr.Dataframe(visible=True,elem_id="papers-table",headers = papers_cols)

                with gr.Tab("Citations network",elem_id="papers-network-tab"):
                    citations_network = gr.HTML(visible=True,elem_id="papers-citations-network")


            
    with gr.Tab("About",elem_classes = "max-height other-tabs"):
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("See more info at [https://climateqa.com](https://climateqa.com/docs/intro/)")


    def start_chat(query,history):
        """Append the pending question to the history before `chat` runs.

        Returns three updates: the first makes a component non-interactive,
        the second selects tab id 1, and the third is the history extended
        with ``(query, None)``.
        """
        pending_history = [tuple(turn) for turn in history + [(query,None)]]
        return (gr.update(interactive = False),gr.update(selected=1),pending_history)
    
    def finish_chat():
        """Return an update that makes a component interactive and empties its value."""
        textbox_reset = gr.update(interactive = True,value = "")
        return textbox_reset

    # Submitting the textbox runs three chained steps:
    #   1. start_chat (unqueued): locks the textbox, selects tab id 1
    #      ("Sources") and appends the question to the chatbot history;
    #   2. chat: streams the answer, sources, reformulated query, language and
    #      gallery, and also fills the query/keywords boxes of the papers tab;
    #   3. finish_chat: re-enables and clears the textbox.
    (textbox
        .submit(start_chat, [textbox,chatbot], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_textbox")
        .then(chat, [textbox,chatbot,dropdown_audience, dropdown_sources,dropdown_reports], [chatbot,sources_textbox,output_query,output_language,gallery_component,query_papers,keywords_papers],concurrency_limit = 8,api_name = "chat_textbox")
        .then(finish_chat, None, [textbox],api_name = "finish_chat_textbox")
    )

    # Same three steps, triggered when the hidden examples textbox changes;
    # here the question is read from `examples_hidden` instead of `textbox`.
    (examples_hidden
        .change(start_chat, [examples_hidden,chatbot], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_examples")
        .then(chat, [examples_hidden,chatbot,dropdown_audience, dropdown_sources,dropdown_reports], [chatbot,sources_textbox,output_query,output_language,gallery_component,query_papers,keywords_papers],concurrency_limit = 8,api_name = "chat_examples")
        .then(finish_chat, None, [textbox],api_name = "finish_chat_examples")
    )


    def change_sample_questions(key):
        """Show only the examples group that matches the selected category.

        Returns one visibility update per entry of `samples`; the entry at the
        position of `key` in QUESTIONS is visible, all others are hidden.
        """
        selected = list(QUESTIONS.keys()).index(key)
        flags = [False] * len(samples)
        flags[selected] = True
        return [gr.update(visible=flag) for flag in flags]



    # Switching the category dropdown toggles which examples group is visible.
    dropdown_samples.change(change_sample_questions,dropdown_samples,samples)

    # Papers tab: submitting the question fills the keywords box; the Search
    # button runs `find_papers` and fills the table, network and summary.
    query_papers.submit(generate_keywords,[query_papers], [keywords_papers])
    search_papers.click(find_papers,[query_papers,keywords_papers,after], [papers_dataframe,citations_network,papers_summary])

    # # textbox.submit(predict_climateqa,[textbox,bot],[None,bot,sources_textbox])
    # (textbox
    #     .submit(answer_user, [textbox,examples_hidden, bot], [textbox, bot],queue = False)
    #     .success(change_tab,None,tabs)
    #     .success(fetch_sources,[textbox,dropdown_sources], [textbox,sources_textbox,docs_textbox,output_query,output_language])
    #     .success(answer_bot, [textbox,bot,docs_textbox,output_query,output_language,dropdown_audience], [textbox,bot],queue = True)
    #     .success(lambda x : textbox,[textbox],[textbox])
    # )

    # (examples_hidden
    #     .change(answer_user_example, [textbox,examples_hidden, bot], [textbox, bot],queue = False)
    #     .success(change_tab,None,tabs)
    #     .success(fetch_sources,[textbox,dropdown_sources], [textbox,sources_textbox,docs_textbox,output_query,output_language])
    #     .success(answer_bot, [textbox,bot,docs_textbox,output_query,output_language,dropdown_audience], [textbox,bot],queue=True)
    #     .success(lambda x : textbox,[textbox],[textbox])
    # )
    # submit_button.click(answer_user, [textbox, bot], [textbox, bot], queue=True).then(
    #         answer_bot, [textbox,bot,dropdown_audience,dropdown_sources], [textbox,bot,sources_textbox]
    #     )


    # with Modal(visible=True) as first_modal:
    #     gr.Markdown("# Welcome to ClimateQ&A !")

    #     gr.Markdown("### Examples")

    #     examples = gr.Examples(
    #         ["Yo ça roule","ça boume"],
    #         [examples_hidden],
    #         examples_per_page=8,
    #         run_on_click=False,
    #         elem_id="examples",
    #         api_name="examples",
    #     )


    # submit.click(lambda: Modal(visible=True), None, config_modal)
    

    # Enable Gradio's request queue for this app.
    demo.queue()

# Start the Gradio server. This runs at import time: there is no
# `if __name__ == "__main__"` guard.
demo.launch()