# https://github.com/langchain-ai/langchain/issues/8623

import pandas as pd

from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_core.documents.base import Document
from langchain_core.vectorstores import VectorStore
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun

from typing import List
from pydantic import Field

class ClimateQARetriever(BaseRetriever):
    """Retriever that combines two similarity searches on a vector store.

    A first search fetches up to ``k_summary`` hits and keeps those whose
    score is strictly above ``threshold``. A second search then fetches
    ``k_total`` minus the number of kept hits. Both result lists are
    concatenated, documents whose ``page_content`` is not longer than
    ``min_size`` characters are dropped, and each remaining document gets
    extra metadata before being returned.

    NOTE(review): metadata filters (``sources`` / ``reports`` /
    ``report_type``) are built but currently NOT passed to the vector store;
    the filtered calls are commented out. Both searches therefore run the
    same unfiltered query and their results may overlap — confirm whether
    duplicates are acceptable for callers.
    """

    # Vector store queried through ``similarity_search_with_score``.
    vectorstore: VectorStore
    # Allowed values are "IPCC", "IPBES" and "IPOS" (checked at query time).
    sources: list = ["IPCC", "IPBES", "IPOS"]
    # When non-empty, the filter is built on "short_name" instead of "source".
    reports: list = []
    # Score a first-search hit must strictly exceed to be kept.
    # NOTE(review): assumes a higher score means more similar — confirm for
    # the vector store in use.
    threshold: float = 0.6
    # Number of hits requested by the first search.
    k_summary: int = 3
    # Upper bound on the number of hits fetched across both searches.
    k_total: int = 10
    # Not read inside this class. The original default carried a trailing
    # comma, which made it the tuple ("vectors",) instead of a string.
    namespace: str = "vectors"
    # Minimum page_content length (exclusive) for a document to be returned.
    # The original default carried a trailing comma, which made it the tuple
    # (200,) and caused the length comparison below to raise TypeError.
    min_size: int = 200

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return the documents retrieved for ``query``.

        Args:
            query: Text passed to the vector store similarity searches.
            run_manager: Callback manager supplied by the retriever base
                class; not used by this implementation.

        Returns:
            The retrieved documents, first-search hits first. Each document's
            metadata is updated in place with ``similarity_score``,
            ``content`` and an incremented ``page_number``.

        Raises:
            AssertionError: If ``sources`` is not a list, contains a value
                other than "IPCC", "IPBES" or "IPOS", or if ``k_total`` is
                not greater than ``k_summary``.
            KeyError: If a returned document has no ``page_number`` metadata.
        """
        # Check that every requested source is one of IPCC, IPBES or IPOS
        assert isinstance(self.sources, list)
        assert all(x in ["IPCC", "IPBES", "IPOS"] for x in self.sources)
        assert self.k_total > self.k_summary, "k_total should be greater than k_summary"

        # Prepare base search kwargs: explicit reports take precedence over
        # the source-level filter.
        filters = {}
        if self.reports:
            filters["short_name"] = {"$in": self.reports}
        else:
            filters["source"] = {"$in": self.sources}

        # Filter intended for the summaries ("SPM") search.
        # NOTE: built but not passed to the search below; the filtered call
        # is kept commented out for reference.
        filters_summaries = {
            **filters,
            "report_type": {"$in": ["SPM"]},
        }

        # build with pinecone
        # docs_summaries = self.vectorstore.similarity_search_with_score(query=query,filter = filters_summaries,k = self.k_summary)
        docs_summaries = self.vectorstore.similarity_search_with_score(
            query=query, k=self.k_summary
        )
        # Keep only first-search hits scoring strictly above the threshold
        docs_summaries = [x for x in docs_summaries if x[1] > self.threshold]

        # Filter intended for the full-reports (non-"SPM") search.
        # NOTE: also built but not passed to the search below.
        filters_full = {
            **filters,
            "report_type": {"$nin": ["SPM"]},
        }

        # Fill the remaining slots up to k_total with a second search
        k_full = self.k_total - len(docs_summaries)
        # docs_full = self.vectorstore.similarity_search_with_score(query=query,filter = filters_full,k = k_full)
        docs_full = self.vectorstore.similarity_search_with_score(query=query, k=k_full)

        # Concatenate documents
        docs = docs_summaries + docs_full

        # Drop documents whose content is too short (score thresholding is
        # only applied to the first search above)
        docs = [x for x in docs if len(x[0].page_content) > self.min_size]
        # docs = [x for x in docs if x[1] > self.threshold]

        # Add score and content to metadata
        results = []
        for doc, score in docs:
            doc.metadata["similarity_score"] = score
            doc.metadata["content"] = doc.page_content
            # presumably converts a zero-based page index to a one-based page
            # number — TODO confirm. NOTE(review): if the store returns the
            # same Document object from both searches, this runs twice on it.
            doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
            # doc.page_content = f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
            results.append(doc)

        # Sort by score
        # results = sorted(results,key = lambda x : x.metadata["similarity_score"],reverse = True)

        return results




# def filter_summaries(df,k_summary = 3,k_total = 10):
#     # assert source in ["IPCC","IPBES","ALL"], "source arg should be in (IPCC,IPBES,ALL)"

#     # # Filter by source
#     # if source == "IPCC":
#     #     df = df.loc[df["source"]=="IPCC"]
#     # elif source == "IPBES":
#     #     df = df.loc[df["source"]=="IPBES"]
#     # else:
#     #     pass

#     # Separate summaries and full reports
#     df_summaries = df.loc[df["report_type"].isin(["SPM","TS"])]
#     df_full = df.loc[~df["report_type"].isin(["SPM","TS"])]

#     # Find passages from summaries dataset
#     passages_summaries = df_summaries.head(k_summary)

#     # Find passages from full reports dataset
#     passages_fullreports = df_full.head(k_total - len(passages_summaries))

#     # Concatenate passages
#     passages = pd.concat([passages_summaries,passages_fullreports],axis = 0,ignore_index = True)
#     return passages




# def retrieve_with_summaries(query,retriever,k_summary = 3,k_total = 10,sources = ["IPCC","IPBES"],max_k = 100,threshold = 0.555,as_dict = True,min_length = 300):
#     assert max_k > k_total

#     validated_sources = ["IPCC","IPBES"]
#     sources = [x for x in sources if x in validated_sources]
#     filters = {
#         "source": { "$in": sources },
#     }
#     print(filters)

#     # Retrieve documents
#     docs = retriever.retrieve(query,top_k = max_k,filters = filters)

#     # Filter by score
#     docs = [{**x.meta,"score":x.score,"content":x.content} for x in docs if x.score > threshold]

#     if len(docs) == 0:
#         return []
#     res = pd.DataFrame(docs)
#     passages_df = filter_summaries(res,k_summary,k_total)
#     if as_dict:
#         contents = passages_df["content"].tolist()
#         meta = passages_df.drop(columns = ["content"]).to_dict(orient = "records")
#         passages = []
#         for i in range(len(contents)):
#             passages.append({"content":contents[i],"meta":meta[i]})
#         return passages
#     else:
#         return passages_df



# def retrieve(query,sources = ["IPCC"],threshold = 0.555,k = 10):


#     print("hellooooo")

#     # Reformulate queries
#     reformulated_query,language = reformulate(query)

#     print(reformulated_query)

#     # Retrieve documents
#     passages = retrieve_with_summaries(reformulated_query,retriever,k_total = k,k_summary = 3,as_dict = True,sources = sources,threshold = threshold)
#     response = {
#       "query":query,
#       "reformulated_query":reformulated_query,
#       "language":language,
#       "sources":passages,
#       "prompts":{"init_prompt":init_prompt,"sources_prompt":sources_prompt},
#     }
#     return response