MohamedLamineBamba
commited on
Commit
·
a3b1498
1
Parent(s):
0dfba83
Perf(Parent Document Retriever): persist docs and vectorstore using LocalFileStore, update prompt, and refactor code
Browse files- app.py +30 -25
- config.py +1 -1
- prompts.py +19 -3
- scrape_data.py +23 -16
- utils.py +7 -1
app.py
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
import chainlit as cl
|
2 |
from langchain.retrievers import ParentDocumentRetriever
|
3 |
-
from langchain.schema import
|
4 |
-
from langchain.
|
5 |
-
from langchain.storage import
|
6 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
7 |
from langchain.vectorstores.chroma import Chroma
|
8 |
from langchain_google_genai import (
|
@@ -24,39 +24,36 @@ model = GoogleGenerativeAI(
|
|
24 |
},
|
25 |
) # type: ignore
|
26 |
|
27 |
-
|
28 |
-
|
29 |
-
model="models/embedding-001", google_api_key=config.GOOGLE_API_KEY
|
30 |
) # type: ignore
|
31 |
|
32 |
-
vectordb = Chroma(persist_directory=config.STORAGE_PATH, embedding_function=embedding)
|
33 |
|
34 |
## retriever
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
|
37 |
# The storage layer for the parent documents
|
38 |
-
|
|
|
|
|
39 |
retriever = ParentDocumentRetriever(
|
40 |
-
vectorstore=
|
41 |
docstore=store,
|
42 |
-
child_splitter=
|
43 |
)
|
44 |
|
45 |
|
46 |
@cl.on_chat_start
|
47 |
async def on_chat_start():
|
48 |
|
49 |
-
|
50 |
-
{
|
51 |
-
"context": retriever | format_docs,
|
52 |
-
"question": RunnablePassthrough(),
|
53 |
-
}
|
54 |
-
| prompt
|
55 |
-
| model
|
56 |
-
| StrOutputParser()
|
57 |
-
)
|
58 |
-
|
59 |
-
cl.user_session.set("rag_chain", rag_chain)
|
60 |
|
61 |
msg = cl.Message(
|
62 |
content=f"Vous pouvez poser vos questions sur les articles de SIKAFINANCE",
|
@@ -66,12 +63,20 @@ async def on_chat_start():
|
|
66 |
|
67 |
@cl.on_message
|
68 |
async def on_message(message: cl.Message):
|
69 |
-
|
|
|
|
|
|
|
|
|
70 |
msg = cl.Message(content="")
|
71 |
|
72 |
async with cl.Step(type="run", name="QA Assistant"):
|
73 |
-
|
74 |
-
|
|
|
|
|
|
|
|
|
75 |
config=RunnableConfig(
|
76 |
callbacks=[cl.LangchainCallbackHandler(), PostMessageHandler(msg)]
|
77 |
),
|
|
|
1 |
import chainlit as cl
|
2 |
from langchain.retrievers import ParentDocumentRetriever
|
3 |
+
from langchain.schema.runnable import RunnableConfig
|
4 |
+
from langchain.storage import LocalFileStore
|
5 |
+
from langchain.storage._lc_store import create_kv_docstore
|
6 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
7 |
from langchain.vectorstores.chroma import Chroma
|
8 |
from langchain_google_genai import (
|
|
|
24 |
},
|
25 |
) # type: ignore
|
26 |
|
27 |
+
embeddings_model = GoogleGenerativeAIEmbeddings(
|
28 |
+
model=config.GOOGLE_EMBEDDING_MODEL
|
|
|
29 |
) # type: ignore
|
30 |
|
|
|
31 |
|
32 |
## retriever
|
33 |
+
child_splitter = RecursiveCharacterTextSplitter(chunk_size=500, separators=["\n"])
|
34 |
+
|
35 |
+
# The vectorstore to use to index the child chunks
|
36 |
+
vectorstore = Chroma(
|
37 |
+
persist_directory=config.STORAGE_PATH + "vectorstore",
|
38 |
+
collection_name="full_documents",
|
39 |
+
embedding_function=embeddings_model,
|
40 |
+
)
|
41 |
|
42 |
# The storage layer for the parent documents
|
43 |
+
fs = LocalFileStore(config.STORAGE_PATH + "docstore")
|
44 |
+
store = create_kv_docstore(fs)
|
45 |
+
|
46 |
retriever = ParentDocumentRetriever(
|
47 |
+
vectorstore=vectorstore,
|
48 |
docstore=store,
|
49 |
+
child_splitter=child_splitter,
|
50 |
)
|
51 |
|
52 |
|
53 |
@cl.on_chat_start
|
54 |
async def on_chat_start():
|
55 |
|
56 |
+
cl.user_session.set("retriever", retriever)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
|
58 |
msg = cl.Message(
|
59 |
content=f"Vous pouvez poser vos questions sur les articles de SIKAFINANCE",
|
|
|
63 |
|
64 |
@cl.on_message
|
65 |
async def on_message(message: cl.Message):
|
66 |
+
|
67 |
+
# retriever = cl.user_session.get("retriever")
|
68 |
+
|
69 |
+
chain = prompt | model
|
70 |
+
|
71 |
msg = cl.Message(content="")
|
72 |
|
73 |
async with cl.Step(type="run", name="QA Assistant"):
|
74 |
+
|
75 |
+
question = message.content
|
76 |
+
context = format_docs(retriever.get_relevant_documents(question))
|
77 |
+
|
78 |
+
async for chunk in chain.astream(
|
79 |
+
input={"context": context, "question": question},
|
80 |
config=RunnableConfig(
|
81 |
callbacks=[cl.LangchainCallbackHandler(), PostMessageHandler(msg)]
|
82 |
),
|
config.py
CHANGED
@@ -3,7 +3,7 @@ import os
|
|
3 |
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
|
4 |
GOOGLE_CHAT_MODEL = "gemini-pro"
|
5 |
GOOGLE_EMBEDDING_MODEL = "models/embedding-001"
|
6 |
-
STORAGE_PATH = "data/
|
7 |
HIISTORY_FILE = "./data/qa_history.txt"
|
8 |
|
9 |
NUM_DAYS_PAST = 30
|
|
|
3 |
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
|
4 |
GOOGLE_CHAT_MODEL = "gemini-pro"
|
5 |
GOOGLE_EMBEDDING_MODEL = "models/embedding-001"
|
6 |
+
STORAGE_PATH = "./data/"
|
7 |
HIISTORY_FILE = "./data/qa_history.txt"
|
8 |
|
9 |
NUM_DAYS_PAST = 30
|
prompts.py
CHANGED
@@ -1,11 +1,27 @@
|
|
1 |
from langchain.prompts import ChatPromptTemplate
|
2 |
|
3 |
template = """
|
4 |
-
|
|
|
|
|
5 |
|
6 |
-
|
7 |
|
8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
"""
|
11 |
|
|
|
1 |
from langchain.prompts import ChatPromptTemplate
|
2 |
|
3 |
template = """
|
4 |
+
Vous êtes un assistant de recherche économique et financière, spécialement conçu pour répondre aux questions liées à l'économie et à la finance et pour aider à l'informations et la prise de décisions financières. Votre rôle consiste à analyser les articles et rapports d'actualité économique et financière qui vous sera fournis dans le contexte et à répondre de manière adequate aux questions spécifiques des utilisateurs. Lorsque vous répondez aux questions :
|
5 |
+
- Pour des questions d'ordre générales (ex: "Quelle est l'actualité du jour?") : Lisez attentivement tous les articles et résumez les points\évènements clés en mentionnant les dates de publications.
|
6 |
+
- Pour des questions spécifiques (ex: "Quelle est la tendance du marché boursier aujourd'hui?") : Recherchez les informations spécifiques à la question dans les articles.
|
7 |
|
8 |
+
-N'hésitez pas à utiliser vos connaissances et votre bon sens pour répondre aux questions.
|
9 |
|
10 |
+
- Basez vos réponses sur les articles d'actualité fournis. Citez directement les parties pertinentes de ces documents pour étayer vos réponses.
|
11 |
+
- Citez clairement les références, y compris les titres des articles, les dates de publication et tout autre détail pertinent, afin de vous assurer que les informations peuvent être facilement vérifiées et retracées jusqu'aux sources originales.
|
12 |
+
|
13 |
+
- Si la question sort du cadre des documents fournis ou si vous ne trouvez pas d'informations pertinentes, indiquez poliment que la réponse ne peut être déterminée sur la base des sources disponibles. Suggérez de consulter d'autres articles d'actualité financière ou des bases de données pour obtenir une réponse complète, le cas échéant.
|
14 |
+
- Insistez sur l'exactitude et la fiabilité de vos réponses, en comprenant la nature critique de votre aide dans les processus de prise de décision financière.
|
15 |
+
- Répondez aux utilisateurs dans la langue de leur question. Si la question est en français, votre réponse doit être en français. Si la question est en anglais, votre réponse doit être en anglais.
|
16 |
+
- Pour des question en relative à la date veuillez considerer qu'aujourd'hui est le Jeudi 11/04/2024. Par exemple pour repondre à une question sur l'actualité du jour, vous devez effectuer une comparaison entre les date de publications des articles et celle d'aujourdui pour filtrer sur les articles puis retourner les informations pertinantes.
|
17 |
+
|
18 |
+
<contexte>
|
19 |
+
``{context}``
|
20 |
+
</contexte>
|
21 |
+
|
22 |
+
<question>
|
23 |
+
{question}
|
24 |
+
</question>
|
25 |
|
26 |
"""
|
27 |
|
scrape_data.py
CHANGED
@@ -2,7 +2,9 @@ import os
|
|
2 |
from datetime import date, timedelta
|
3 |
|
4 |
import bs4
|
5 |
-
from langchain.
|
|
|
|
|
6 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
7 |
from langchain.vectorstores.chroma import Chroma
|
8 |
from langchain_community.document_loaders import WebBaseLoader
|
@@ -81,7 +83,7 @@ def set_metadata(documents, metadatas):
|
|
81 |
|
82 |
|
83 |
def process_docs(
|
84 |
-
articles, persist_directory, embeddings_model, chunk_size=
|
85 |
):
|
86 |
"""
|
87 |
#Scrap all articles urls content and save on a vector DB
|
@@ -105,28 +107,33 @@ def process_docs(
|
|
105 |
# Update metadata: add title,
|
106 |
set_metadata(documents=docs, metadatas=articles)
|
107 |
|
108 |
-
print("Successfully loaded to document")
|
109 |
|
110 |
-
|
111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
)
|
113 |
-
splits = text_splitter.split_documents(docs)
|
114 |
|
115 |
-
#
|
116 |
-
|
117 |
-
|
118 |
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
)
|
124 |
|
125 |
-
|
|
|
126 |
|
127 |
|
128 |
if __name__ == "__main__":
|
129 |
|
130 |
data = scrap_articles(DATA_URL, num_days_past=config.NUM_DAYS_PAST)
|
131 |
-
|
132 |
-
ret = vectordb.as_retriever()
|
|
|
2 |
from datetime import date, timedelta
|
3 |
|
4 |
import bs4
|
5 |
+
from langchain.retrievers import ParentDocumentRetriever
|
6 |
+
from langchain.storage import LocalFileStore
|
7 |
+
from langchain.storage._lc_store import create_kv_docstore
|
8 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
9 |
from langchain.vectorstores.chroma import Chroma
|
10 |
from langchain_community.document_loaders import WebBaseLoader
|
|
|
83 |
|
84 |
|
85 |
def process_docs(
|
86 |
+
articles, persist_directory, embeddings_model, chunk_size=500, chunk_overlap=0
|
87 |
):
|
88 |
"""
|
89 |
#Scrap all articles urls content and save on a vector DB
|
|
|
107 |
# Update metadata: add title,
|
108 |
set_metadata(documents=docs, metadatas=articles)
|
109 |
|
110 |
+
# print("Successfully loaded to document")
|
111 |
|
112 |
+
# This text splitter is used to create the child documents
|
113 |
+
child_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap, separators=["\n"])
|
114 |
+
|
115 |
+
# The vectorstore to use to index the child chunks
|
116 |
+
vectorstore = Chroma(
|
117 |
+
persist_directory=persist_directory + "vectorstore",
|
118 |
+
collection_name="full_documents",
|
119 |
+
embedding_function=embeddings_model,
|
120 |
)
|
|
|
121 |
|
122 |
+
# The storage layer for the parent documents
|
123 |
+
fs = LocalFileStore(persist_directory + "docstore")
|
124 |
+
store = create_kv_docstore(fs)
|
125 |
|
126 |
+
retriever = ParentDocumentRetriever(
|
127 |
+
vectorstore=vectorstore,
|
128 |
+
docstore=store,
|
129 |
+
child_splitter=child_splitter,
|
130 |
)
|
131 |
|
132 |
+
retriever.add_documents(docs, ids=None)
|
133 |
+
print(len(docs), " documents added")
|
134 |
|
135 |
|
136 |
if __name__ == "__main__":
|
137 |
|
138 |
data = scrap_articles(DATA_URL, num_days_past=config.NUM_DAYS_PAST)
|
139 |
+
process_docs(data, config.STORAGE_PATH, embeddings_model)
|
|
utils.py
CHANGED
@@ -11,8 +11,9 @@ def format_docs(documents, max_context_size=100000, separator="\n\n"):
|
|
11 |
i += 1
|
12 |
if len(encoder.encode(context)) < max_context_size:
|
13 |
source = doc.metadata["link"]
|
|
|
14 |
context += (
|
15 |
-
f"Article{
|
16 |
)
|
17 |
return context
|
18 |
|
@@ -43,3 +44,8 @@ class PostMessageHandler(BaseCallbackHandler):
|
|
43 |
source_names = [el.name for el in sources_element]
|
44 |
self.msg.elements += sources_element
|
45 |
self.msg.content += f"\nSources: {', '.join(source_names)}"
|
|
|
|
|
|
|
|
|
|
|
|
11 |
i += 1
|
12 |
if len(encoder.encode(context)) < max_context_size:
|
13 |
source = doc.metadata["link"]
|
14 |
+
title = doc.metadata["title"]
|
15 |
context += (
|
16 |
+
f"Article: {title}\n" + doc.page_content + f"\nSource: {source}" + separator
|
17 |
)
|
18 |
return context
|
19 |
|
|
|
44 |
source_names = [el.name for el in sources_element]
|
45 |
self.msg.elements += sources_element
|
46 |
self.msg.content += f"\nSources: {', '.join(source_names)}"
|
47 |
+
|
48 |
+
def clean_text(text):
|
49 |
+
tx = text.replace("Tweet","")
|
50 |
+
tx = tx.replace("\n\n\n\n\n\n\n\n\n","")
|
51 |
+
return tx
|