amadoujr's picture
Update app.py
6eda1fd verified
import gradio as gr
from langchain.chains import create_retrieval_chain,create_history_aware_retriever
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import HumanMessage, AIMessage
from langchain_community.document_loaders import PyPDFDirectoryLoader
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_google_genai import GoogleGenerativeAI
from langchain_text_splitters import RecursiveCharacterTextSplitter
import os
os.environ["GOOGLE_API_KEY"] = os.getenv("HF_GEMINI_SECRET_KEY")
# 1. Chargement des embeddings et du modéle
gemini_embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
llm = GoogleGenerativeAI(model="gemini-1.5-flash", temperature=0.2)
# 2. Chargement des documents
loader_data = PyPDFDirectoryLoader('./data')
docs = loader_data.load()
# 3. Découpage des documents
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
texts = text_splitter.split_documents(docs)
# 4. Création de la base vectorielle et du retriever
vectorstore = FAISS.from_documents(documents=texts, embedding=gemini_embeddings)
retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 2})
### Création du retriever conscient de l'historique
contextualize_q_system_prompt = (
"Compte tenu de l'historique de la conversation et de la dernière question "
"de l'utilisateur, reformulez une question autonome compréhensible sans l'historique. "
"Ne répondez pas à la question, reformulez-la uniquement si nécessaire."
)
contextualize_q_prompt = ChatPromptTemplate.from_messages(
[
("system", contextualize_q_system_prompt),
MessagesPlaceholder("chat_history"),
("human", "{input}"),
]
)
history_aware_retriever = create_history_aware_retriever(
llm, retriever, contextualize_q_prompt
)
### Chaîne de question-réponse avec historique
system_prompt = (
"Tu es un assistant pour les tâches de question-réponse appliquée à la santé. "
"Aide toi du contexte si nécessaire : {context} "
"Si le contexte ne fourni pas assez d'élément mais que tu connais la réponse alors donne la"
"Si le contexte n'est pas assez fourni et que tu ne sais pas, dis-le explicitement "
"que tu ne sais pas."
)
qa_prompt = ChatPromptTemplate.from_messages(
[
("system", system_prompt),
("human", "{input}"),
]
)
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
# Fonction pour l'interface
def respond(message, history):
chat_history = []
for msg in history:
if msg['role'] == "user":
chat_history.append(HumanMessage(content=msg['content']))
elif msg['role'] == "assistant":
chat_history.append(AIMessage(content=msg['content']))
chat_history.append(HumanMessage(content=message))
# Invocation de la chaîne avec l'historique
print(f"historique : {chat_history}")
ai_message = rag_chain.invoke({
"input": message,
"chat_history": chat_history
})
return ai_message["answer"]
# Interface Gradio avec ChatInterface
demo = gr.ChatInterface(
respond,
type="messages",
title="Assistant Médical AI",
description="Posez vos questions médicales, je vous répondrai en me basant sur les documents fournis.",
examples=["C'est quoi le SAMU?", "Comment fonctionne un service d'urgence?"]
)
if __name__ == "__main__":
demo.launch(share=True)