MohamedLamineBamba commited on
Commit
a3b1498
·
1 Parent(s): 0dfba83

Perf(Parent Document Retriever): persist docs and vectorstore using LocalFileStore, update prompt, and refactor code

Browse files
Files changed (5) hide show
  1. app.py +30 -25
  2. config.py +1 -1
  3. prompts.py +19 -3
  4. scrape_data.py +23 -16
  5. utils.py +7 -1
app.py CHANGED
@@ -1,8 +1,8 @@
1
  import chainlit as cl
2
  from langchain.retrievers import ParentDocumentRetriever
3
- from langchain.schema import StrOutputParser
4
- from langchain.schema.runnable import Runnable, RunnableConfig, RunnablePassthrough
5
- from langchain.storage import InMemoryStore
6
  from langchain.text_splitter import RecursiveCharacterTextSplitter
7
  from langchain.vectorstores.chroma import Chroma
8
  from langchain_google_genai import (
@@ -24,39 +24,36 @@ model = GoogleGenerativeAI(
24
  },
25
  ) # type: ignore
26
 
27
- # Load vector database that was persisted earlier
28
- embedding = embeddings_model = GoogleGenerativeAIEmbeddings(
29
- model="models/embedding-001", google_api_key=config.GOOGLE_API_KEY
30
  ) # type: ignore
31
 
32
- vectordb = Chroma(persist_directory=config.STORAGE_PATH, embedding_function=embedding)
33
 
34
  ## retriever
35
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, separators=["\n"])
 
 
 
 
 
 
 
36
 
37
  # The storage layer for the parent documents
38
- store = InMemoryStore()
 
 
39
  retriever = ParentDocumentRetriever(
40
- vectorstore=vectordb,
41
  docstore=store,
42
- child_splitter=text_splitter,
43
  )
44
 
45
 
46
  @cl.on_chat_start
47
  async def on_chat_start():
48
 
49
- rag_chain = (
50
- {
51
- "context": retriever | format_docs,
52
- "question": RunnablePassthrough(),
53
- }
54
- | prompt
55
- | model
56
- | StrOutputParser()
57
- )
58
-
59
- cl.user_session.set("rag_chain", rag_chain)
60
 
61
  msg = cl.Message(
62
  content=f"Vous pouvez poser vos questions sur les articles de SIKAFINANCE",
@@ -66,12 +63,20 @@ async def on_chat_start():
66
 
67
  @cl.on_message
68
  async def on_message(message: cl.Message):
69
- runnable = cl.user_session.get("rag_chain") # type: Runnable # type: ignore
 
 
 
 
70
  msg = cl.Message(content="")
71
 
72
  async with cl.Step(type="run", name="QA Assistant"):
73
- async for chunk in runnable.astream(
74
- message.content,
 
 
 
 
75
  config=RunnableConfig(
76
  callbacks=[cl.LangchainCallbackHandler(), PostMessageHandler(msg)]
77
  ),
 
1
  import chainlit as cl
2
  from langchain.retrievers import ParentDocumentRetriever
3
+ from langchain.schema.runnable import RunnableConfig
4
+ from langchain.storage import LocalFileStore
5
+ from langchain.storage._lc_store import create_kv_docstore
6
  from langchain.text_splitter import RecursiveCharacterTextSplitter
7
  from langchain.vectorstores.chroma import Chroma
8
  from langchain_google_genai import (
 
24
  },
25
  ) # type: ignore
26
 
27
+ embeddings_model = GoogleGenerativeAIEmbeddings(
28
+ model=config.GOOGLE_EMBEDDING_MODEL
 
29
  ) # type: ignore
30
 
 
31
 
32
  ## retriever
33
+ child_splitter = RecursiveCharacterTextSplitter(chunk_size=500, separators=["\n"])
34
+
35
+ # The vectorstore to use to index the child chunks
36
+ vectorstore = Chroma(
37
+ persist_directory=config.STORAGE_PATH + "vectorstore",
38
+ collection_name="full_documents",
39
+ embedding_function=embeddings_model,
40
+ )
41
 
42
  # The storage layer for the parent documents
43
+ fs = LocalFileStore(config.STORAGE_PATH + "docstore")
44
+ store = create_kv_docstore(fs)
45
+
46
  retriever = ParentDocumentRetriever(
47
+ vectorstore=vectorstore,
48
  docstore=store,
49
+ child_splitter=child_splitter,
50
  )
51
 
52
 
53
  @cl.on_chat_start
54
  async def on_chat_start():
55
 
56
+ cl.user_session.set("retriever", retriever)
 
 
 
 
 
 
 
 
 
 
57
 
58
  msg = cl.Message(
59
  content=f"Vous pouvez poser vos questions sur les articles de SIKAFINANCE",
 
63
 
64
  @cl.on_message
65
  async def on_message(message: cl.Message):
66
+
67
+ # retriever = cl.user_session.get("retriever")
68
+
69
+ chain = prompt | model
70
+
71
  msg = cl.Message(content="")
72
 
73
  async with cl.Step(type="run", name="QA Assistant"):
74
+
75
+ question = message.content
76
+ context = format_docs(retriever.get_relevant_documents(question))
77
+
78
+ async for chunk in chain.astream(
79
+ input={"context": context, "question": question},
80
  config=RunnableConfig(
81
  callbacks=[cl.LangchainCallbackHandler(), PostMessageHandler(msg)]
82
  ),
config.py CHANGED
@@ -3,7 +3,7 @@ import os
3
  GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
4
  GOOGLE_CHAT_MODEL = "gemini-pro"
5
  GOOGLE_EMBEDDING_MODEL = "models/embedding-001"
6
- STORAGE_PATH = "data/chroma/"
7
  HIISTORY_FILE = "./data/qa_history.txt"
8
 
9
  NUM_DAYS_PAST = 30
 
3
  GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
4
  GOOGLE_CHAT_MODEL = "gemini-pro"
5
  GOOGLE_EMBEDDING_MODEL = "models/embedding-001"
6
+ STORAGE_PATH = "./data/"
7
  HIISTORY_FILE = "./data/qa_history.txt"
8
 
9
  NUM_DAYS_PAST = 30
prompts.py CHANGED
@@ -1,11 +1,27 @@
1
  from langchain.prompts import ChatPromptTemplate
2
 
3
  template = """
4
- Répondez à la question en vous basant uniquement sur le contexte suivant:
 
 
5
 
6
- {context}
7
 
8
- Question : {question}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  """
11
 
 
1
  from langchain.prompts import ChatPromptTemplate
2
 
3
  template = """
4
+ Vous êtes un assistant de recherche économique et financière, spécialement conçu pour répondre aux questions liées à l'économie et à la finance et pour aider à l'informations et la prise de décisions financières. Votre rôle consiste à analyser les articles et rapports d'actualité économique et financière qui vous sera fournis dans le contexte et à répondre de manière adequate aux questions spécifiques des utilisateurs. Lorsque vous répondez aux questions :
5
+ - Pour des questions d'ordre générales (ex: "Quelle est l'actualité du jour?") : Lisez attentivement tous les articles et résumez les points\évènements clés en mentionnant les dates de publications.
6
+ - Pour des questions spécifiques (ex: "Quelle est la tendance du marché boursier aujourd'hui?") : Recherchez les informations spécifiques à la question dans les articles.
7
 
8
+ -N'hésitez pas à utiliser vos connaissances et votre bon sens pour répondre aux questions.
9
 
10
+ - Basez vos réponses sur les articles d'actualité fournis. Citez directement les parties pertinentes de ces documents pour étayer vos réponses.
11
+ - Citez clairement les références, y compris les titres des articles, les dates de publication et tout autre détail pertinent, afin de vous assurer que les informations peuvent être facilement vérifiées et retracées jusqu'aux sources originales.
12
+
13
+ - Si la question sort du cadre des documents fournis ou si vous ne trouvez pas d'informations pertinentes, indiquez poliment que la réponse ne peut être déterminée sur la base des sources disponibles. Suggérez de consulter d'autres articles d'actualité financière ou des bases de données pour obtenir une réponse complète, le cas échéant.
14
+ - Insistez sur l'exactitude et la fiabilité de vos réponses, en comprenant la nature critique de votre aide dans les processus de prise de décision financière.
15
+ - Répondez aux utilisateurs dans la langue de leur question. Si la question est en français, votre réponse doit être en français. Si la question est en anglais, votre réponse doit être en anglais.
16
+ - Pour des question en relative à la date veuillez considerer qu'aujourd'hui est le Jeudi 11/04/2024. Par exemple pour repondre à une question sur l'actualité du jour, vous devez effectuer une comparaison entre les date de publications des articles et celle d'aujourdui pour filtrer sur les articles puis retourner les informations pertinantes.
17
+
18
+ <contexte>
19
+ ``{context}``
20
+ </contexte>
21
+
22
+ <question>
23
+ {question}
24
+ </question>
25
 
26
  """
27
 
scrape_data.py CHANGED
@@ -2,7 +2,9 @@ import os
2
  from datetime import date, timedelta
3
 
4
  import bs4
5
- from langchain.indexes import SQLRecordManager, index
 
 
6
  from langchain.text_splitter import RecursiveCharacterTextSplitter
7
  from langchain.vectorstores.chroma import Chroma
8
  from langchain_community.document_loaders import WebBaseLoader
@@ -81,7 +83,7 @@ def set_metadata(documents, metadatas):
81
 
82
 
83
  def process_docs(
84
- articles, persist_directory, embeddings_model, chunk_size=1000, chunk_overlap=100
85
  ):
86
  """
87
  #Scrap all articles urls content and save on a vector DB
@@ -105,28 +107,33 @@ def process_docs(
105
  # Update metadata: add title,
106
  set_metadata(documents=docs, metadatas=articles)
107
 
108
- print("Successfully loaded to document")
109
 
110
- text_splitter = RecursiveCharacterTextSplitter(
111
- chunk_size=chunk_size, chunk_overlap=chunk_overlap, separators=["\n"]
 
 
 
 
 
 
112
  )
113
- splits = text_splitter.split_documents(docs)
114
 
115
- # Create the storage path if it doesn't exist
116
- if not os.path.exists(persist_directory):
117
- os.makedirs(persist_directory)
118
 
119
- doc_search = Chroma.from_documents(
120
- documents=splits,
121
- embedding=embeddings_model,
122
- persist_directory=persist_directory,
123
  )
124
 
125
- return doc_search
 
126
 
127
 
128
  if __name__ == "__main__":
129
 
130
  data = scrap_articles(DATA_URL, num_days_past=config.NUM_DAYS_PAST)
131
- vectordb = process_docs(data, config.STORAGE_PATH, embeddings_model)
132
- ret = vectordb.as_retriever()
 
2
  from datetime import date, timedelta
3
 
4
  import bs4
5
+ from langchain.retrievers import ParentDocumentRetriever
6
+ from langchain.storage import LocalFileStore
7
+ from langchain.storage._lc_store import create_kv_docstore
8
  from langchain.text_splitter import RecursiveCharacterTextSplitter
9
  from langchain.vectorstores.chroma import Chroma
10
  from langchain_community.document_loaders import WebBaseLoader
 
83
 
84
 
85
  def process_docs(
86
+ articles, persist_directory, embeddings_model, chunk_size=500, chunk_overlap=0
87
  ):
88
  """
89
  #Scrap all articles urls content and save on a vector DB
 
107
  # Update metadata: add title,
108
  set_metadata(documents=docs, metadatas=articles)
109
 
110
+ # print("Successfully loaded to document")
111
 
112
+ # This text splitter is used to create the child documents
113
+ child_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap, separators=["\n"])
114
+
115
+ # The vectorstore to use to index the child chunks
116
+ vectorstore = Chroma(
117
+ persist_directory=persist_directory + "vectorstore",
118
+ collection_name="full_documents",
119
+ embedding_function=embeddings_model,
120
  )
 
121
 
122
+ # The storage layer for the parent documents
123
+ fs = LocalFileStore(persist_directory + "docstore")
124
+ store = create_kv_docstore(fs)
125
 
126
+ retriever = ParentDocumentRetriever(
127
+ vectorstore=vectorstore,
128
+ docstore=store,
129
+ child_splitter=child_splitter,
130
  )
131
 
132
+ retriever.add_documents(docs, ids=None)
133
+ print(len(docs), " documents added")
134
 
135
 
136
  if __name__ == "__main__":
137
 
138
  data = scrap_articles(DATA_URL, num_days_past=config.NUM_DAYS_PAST)
139
+ process_docs(data, config.STORAGE_PATH, embeddings_model)
 
utils.py CHANGED
@@ -11,8 +11,9 @@ def format_docs(documents, max_context_size=100000, separator="\n\n"):
11
  i += 1
12
  if len(encoder.encode(context)) < max_context_size:
13
  source = doc.metadata["link"]
 
14
  context += (
15
- f"Article{i}:\n" + doc.page_content + f"\nSource: {source}" + separator
16
  )
17
  return context
18
 
@@ -43,3 +44,8 @@ class PostMessageHandler(BaseCallbackHandler):
43
  source_names = [el.name for el in sources_element]
44
  self.msg.elements += sources_element
45
  self.msg.content += f"\nSources: {', '.join(source_names)}"
 
 
 
 
 
 
11
  i += 1
12
  if len(encoder.encode(context)) < max_context_size:
13
  source = doc.metadata["link"]
14
+ title = doc.metadata["title"]
15
  context += (
16
+ f"Article: {title}\n" + doc.page_content + f"\nSource: {source}" + separator
17
  )
18
  return context
19
 
 
44
  source_names = [el.name for el in sources_element]
45
  self.msg.elements += sources_element
46
  self.msg.content += f"\nSources: {', '.join(source_names)}"
47
+
48
+ def clean_text(text):
49
+ tx = text.replace("Tweet","")
50
+ tx = tx.replace("\n\n\n\n\n\n\n\n\n","")
51
+ return tx