#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Sep 22 15:43:16 2024

@author: Raphaël d'Assignies (rdassignies@protonmail.ch)
"""
import json
from typing import Literal, Optional, List, Union, Any
from langchain_openai import ChatOpenAI
import pandas as pd
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import END, StateGraph, START
from langchain_core.output_parsers import StrOutputParser
from pydantic import BaseModel, Field
from models import NatureJugement
from nodes import (GradeResults, GraphState, generate_query_node, 
                   generate_results_node, query_feedback_node, 
                   evaluate_query_node, evaluate_results_node)
import streamlit as st




# Instanciate pipeline
pipeline = StateGraph(GraphState)

pipeline.add_node('generate_query', generate_query_node)
pipeline.add_node('generate_results', generate_results_node)
pipeline.add_node('query_feedback', query_feedback_node)

# Only query
#pipeline.add_edge(START,'generate_query')
#pipeline.add_edge('generate_query', generate_query_node)
#pipeline.add_edge('generate_query', END)

# Full scenario
pipeline.add_edge(START,'generate_query')
pipeline.add_conditional_edges(
    'generate_query', 
    evaluate_query_node, 
    {'error_query' : 'generate_query',
     'ok' : 'generate_results'
     })

pipeline.add_conditional_edges(
    'generate_results', 
    evaluate_results_node,
    {
        "yes": END,
        "no": 'query_feedback',
        "max_generation_reached": END

    }  
)


# Création du graph
graph = pipeline.compile()

# Load le dataframe
df = pd.read_json('bodacc.json', orient='table')

# Initialise le dictionnaire
inputs = {
        'df_head': df.head().to_csv(), 
        'df': df
    }

# Créé un dictionnaire des sorties vide
outputs = {}


# Titre de l'application
st.title("Chat with BODACC !")

# Message d'avertissement
warning_message = (f"Cet outil, purement pédagogique, est basé sur des données réelles allant de {df['dateparution'].min()} "
                   f"à {df['dateparution'].max()}, et permet d'interroger le BODACC en langage naturel. Compte tenu de la variabilité des modèles, nous ne pouvons pas garantir la fiabilité des réponses.")

st.warning(warning_message)
# Interface utilisateur pour entrer la requête
user_query = st.text_input("Entrez votre requête:", "Trouve moi les restaurants à reprendre en Bretagne dans les 30 derniers jours")


# Afficher les résultats avec Streamlit
inputs["instructions"] = user_query


# Afficher un bouton pour démarrer la recherche
if st.button("Lancer la recherche"):
    config = {"configurable": {"thread_id": "2"}}
    
    # Étape 1 : Afficher le message "Je réfléchis..."
    st.write("Je réfléchis...")

    # Stream des résultats au fur et à mesure
    with st.spinner('Recherche en cours...'):
        for output in graph.stream(inputs, stream_mode='values', debug=False):
            # Ajouter les résultats au dictionnaire outputs
            for k, v in output.items():
                if k not in outputs:
                    outputs[k] = []
                outputs[k].append(v)

            if "results" in output and len(output["results"]) > 0:
                records = json.loads(output['results'])
                st.write(f"Résultats intermédiaires trouvés : {len(records)} résultats jusqu'à présent.")

    # Après la fin du traitement
    if "results" in outputs and len(outputs["results"]) > 0:
        # Agréger tous les résultats accumulés
        all_results = []
        for res in outputs["results"]:
            json_data = json.loads(res)  # Convertir chaque ensemble de résultats en JSON
            all_results.extend(json_data)  # Accumuler tous les résultats

        results_df = pd.DataFrame(all_results)  # Créer un DataFrame avec tous les résultats accumulés
        # Afficher un aperçu des résultats (jusqu'à 5 premiers)
        num_results = len(results_df)
        st.write(f"J'ai trouvé {num_results} résultats.")
        if num_results > 0:
            preview_count = min(5, num_results)  # Gérer le cas où il y a moins de 5 résultats
            st.write(f"Voici un aperçu des {preview_count} premiers résultats :")
            st.write(results_df.head(preview_count))
            
            trunc = outputs.get('truncated', 'pas de traunc')
            
            if trunc[0] == True:
                st.warning("Les résultats de votre recherche ont été tronqués car celle-ci était trop large ! ")
    
        # Convertir tous les résultats en CSV
        csv = results_df.to_csv(index=False)

        # Ajouter un bouton pour télécharger tous les résultats
        st.download_button(
            label="Télécharger le résultat complet au format CSV",
            data=csv,
            file_name="results.csv",
            mime="text/csv"
        )

    else:
        # Si aucun résultat n'est trouvé
        st.write("Aucun résultat trouvé.")