Spaces:
Sleeping
Sleeping
import gradio as gr | |
from langchain.chains import create_retrieval_chain,create_history_aware_retriever | |
from langchain.chains.combine_documents import create_stuff_documents_chain | |
from langchain_community.vectorstores import FAISS | |
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
from langchain_core.messages import HumanMessage, AIMessage | |
from langchain_community.document_loaders import PyPDFDirectoryLoader | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain_core.prompts import ChatPromptTemplate | |
from langchain_core.runnables import RunnablePassthrough | |
from langchain_google_genai import GoogleGenerativeAIEmbeddings | |
from langchain_google_genai import GoogleGenerativeAI | |
from langchain_text_splitters import RecursiveCharacterTextSplitter | |
import os | |
os.environ["GOOGLE_API_KEY"] = os.getenv("HF_GEMINI_SECRET_KEY") | |
# 1. Chargement des embeddings et du modéle | |
gemini_embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001") | |
llm = GoogleGenerativeAI(model="gemini-1.5-flash", temperature=0.2) | |
# 2. Chargement des documents | |
loader_data = PyPDFDirectoryLoader('./data') | |
docs = loader_data.load() | |
# 3. Découpage des documents | |
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200) | |
texts = text_splitter.split_documents(docs) | |
# 4. Création de la base vectorielle et du retriever | |
vectorstore = FAISS.from_documents(documents=texts, embedding=gemini_embeddings) | |
retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 2}) | |
### Création du retriever conscient de l'historique | |
contextualize_q_system_prompt = ( | |
"Compte tenu de l'historique de la conversation et de la dernière question " | |
"de l'utilisateur, reformulez une question autonome compréhensible sans l'historique. " | |
"Ne répondez pas à la question, reformulez-la uniquement si nécessaire." | |
) | |
contextualize_q_prompt = ChatPromptTemplate.from_messages( | |
[ | |
("system", contextualize_q_system_prompt), | |
MessagesPlaceholder("chat_history"), | |
("human", "{input}"), | |
] | |
) | |
history_aware_retriever = create_history_aware_retriever( | |
llm, retriever, contextualize_q_prompt | |
) | |
### Chaîne de question-réponse avec historique | |
system_prompt = ( | |
"Tu es un assistant pour les tâches de question-réponse appliquée à la santé. " | |
"Aide toi du contexte si nécessaire : {context} " | |
"Si le contexte ne fourni pas assez d'élément mais que tu connais la réponse alors donne la" | |
"Si le contexte n'est pas assez fourni et que tu ne sais pas, dis-le explicitement " | |
"que tu ne sais pas." | |
) | |
qa_prompt = ChatPromptTemplate.from_messages( | |
[ | |
("system", system_prompt), | |
("human", "{input}"), | |
] | |
) | |
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt) | |
rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain) | |
# Fonction pour l'interface | |
def respond(message, history): | |
chat_history = [] | |
for msg in history: | |
if msg['role'] == "user": | |
chat_history.append(HumanMessage(content=msg['content'])) | |
elif msg['role'] == "assistant": | |
chat_history.append(AIMessage(content=msg['content'])) | |
chat_history.append(HumanMessage(content=message)) | |
# Invocation de la chaîne avec l'historique | |
print(f"historique : {chat_history}") | |
ai_message = rag_chain.invoke({ | |
"input": message, | |
"chat_history": chat_history | |
}) | |
return ai_message["answer"] | |
# Interface Gradio avec ChatInterface | |
demo = gr.ChatInterface( | |
respond, | |
type="messages", | |
title="Assistant Médical AI", | |
description="Posez vos questions médicales, je vous répondrai en me basant sur les documents fournis.", | |
examples=["C'est quoi le SAMU?", "Comment fonctionne un service d'urgence?"] | |
) | |
if __name__ == "__main__": | |
demo.launch(share=True) |