MohamedLamineBamba commited on
Commit
0c69aa1
·
1 Parent(s): a7aa9c3

feat: Parent Document Retriever

Browse files
Files changed (3) hide show
  1. app.py +29 -25
  2. requirements.txt +2 -1
  3. scrape_data.py +1 -18
app.py CHANGED
@@ -11,24 +11,12 @@ from langchain_google_genai import (
11
  HarmBlockThreshold,
12
  HarmCategory,
13
  )
14
-
 
 
15
  import config
16
  from prompts import prompt
17
-
18
- metadata_field_info = [
19
- AttributeInfo(
20
- name="title",
21
- description="Le titre de l'article",
22
- type="string",
23
- ),
24
- AttributeInfo(
25
- name="date",
26
- description="Date de publication",
27
- type="string",
28
- ),
29
- AttributeInfo(name="link", description="Source de l'article", type="string"),
30
- ]
31
- document_content_description = "Articles sur l'actualité."
32
 
33
  model = GoogleGenerativeAI(
34
  model=config.GOOGLE_CHAT_MODEL,
@@ -45,29 +33,45 @@ embedding = embeddings_model = GoogleGenerativeAIEmbeddings(
45
 
46
  vectordb = Chroma(persist_directory=config.STORAGE_PATH, embedding_function=embedding)
47
 
48
- retriever = SelfQueryRetriever.from_llm(
49
- model,
50
- vectordb,
51
- document_content_description,
52
- metadata_field_info,
53
- )
 
 
 
 
 
 
 
54
 
55
 
56
  @cl.on_chat_start
57
  async def on_chat_start():
58
 
59
- def format_docs(docs):
60
- return "\n\n".join(doc.page_content for doc in docs)
 
 
 
 
 
 
 
 
61
 
62
  rag_chain = (
63
  {
64
- "context": vectordb.as_retriever() | format_docs,
65
  "question": RunnablePassthrough(),
66
  }
67
  | prompt
68
  | model
69
  | StrOutputParser()
70
  )
 
71
 
72
  cl.user_session.set("rag_chain", rag_chain)
73
 
 
11
  HarmBlockThreshold,
12
  HarmCategory,
13
  )
14
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
15
+ from langchain.retrievers import ParentDocumentRetriever
16
+ from langchain.storage import InMemoryStore
17
  import config
18
  from prompts import prompt
19
+ import tiktoken
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  model = GoogleGenerativeAI(
22
  model=config.GOOGLE_CHAT_MODEL,
 
33
 
34
  vectordb = Chroma(persist_directory=config.STORAGE_PATH, embedding_function=embedding)
35
 
36
+ ## retriever
37
+
38
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, separators=["\n"])
39
+
40
+
41
+ # The storage layer for the parent documents
42
+ store = InMemoryStore()
43
+ retriever = ParentDocumentRetriever(
44
+ vectorstore=vectordb,
45
+ docstore=store,
46
+ child_splitter=text_splitter,
47
+ )
48
+
49
 
50
 
51
  @cl.on_chat_start
52
  async def on_chat_start():
53
 
54
+ def format_docs(documents, max_context_size= 100000, separator= "\n\n"):
55
+ context = ""
56
+ encoder = tiktoken.get_encoding("cl100k_base")
57
+ i=0
58
+ for doc in documents:
59
+ i+=1
60
+ if len(encoder.encode(context)) < max_context_size:
61
+ source = doc.metadata['link']
62
+ context += f"Article{i}:\n"+doc.page_content + f"\nSource: {source}" + separator
63
+ return context
64
 
65
  rag_chain = (
66
  {
67
+ "context": retriever | format_docs,
68
  "question": RunnablePassthrough(),
69
  }
70
  | prompt
71
  | model
72
  | StrOutputParser()
73
  )
74
+
75
 
76
  cl.user_session.set("rag_chain", rag_chain)
77
 
requirements.txt CHANGED
@@ -4,4 +4,5 @@ chainlit==1.0.500
4
  chromadb==0.4.24
5
  lark==1.1.9
6
  bs4==0.0.2
7
- selenium==4.19.0
 
 
4
  chromadb==0.4.24
5
  lark==1.1.9
6
  bs4==0.0.2
7
+ selenium==4.19.0
8
+ tiktoken==0.1.1
scrape_data.py CHANGED
@@ -120,24 +120,7 @@ def process_docs(
120
  documents=splits,
121
  embedding=embeddings_model,
122
  persist_directory=persist_directory,
123
- )
124
-
125
- # Indexing data
126
- namespace = "chromadb/my_documents"
127
- record_manager = SQLRecordManager(
128
- namespace, db_url="sqlite:///record_manager_cache.sql"
129
- )
130
- record_manager.create_schema()
131
-
132
- index_result = index(
133
- docs,
134
- record_manager,
135
- doc_search,
136
- cleanup="incremental",
137
- source_id_key="link",
138
- )
139
-
140
- print(f"Indexing stats: {index_result}")
141
 
142
  return doc_search
143
 
 
120
  documents=splits,
121
  embedding=embeddings_model,
122
  persist_directory=persist_directory,
123
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
  return doc_search
126