mbudisic commited on
Commit
8bc5e76
·
1 Parent(s): 7c5fafc

Async implementation of vector store population

Browse files
app.py CHANGED
@@ -1,17 +1,22 @@
 
1
  from typing import List
2
  import chainlit as cl
3
  import json
 
4
 
 
5
  from langchain_experimental.text_splitter import SemanticChunker
 
6
  from langchain_openai.embeddings import OpenAIEmbeddings
7
  from langchain_core.documents import Document
8
 
9
  from langchain_qdrant import QdrantVectorStore
 
10
  from qdrant_client import QdrantClient
11
  from qdrant_client.http.models import Distance, VectorParams
12
  from dataclasses import dataclass
13
 
14
- import pstuts_rag.datastore
15
 
16
 
17
  @dataclass
@@ -19,6 +24,14 @@ class ApplicationParameters:
19
  filename = "data/test.json"
20
  embedding_model = "text-embedding-3-small"
21
  n_context_docs = 2
 
 
 
 
 
 
 
 
22
 
23
 
24
  class ApplicationState:
@@ -26,48 +39,42 @@ class ApplicationState:
26
  docs: List[Document] = []
27
  qdrantclient: QdrantClient = None
28
  vectorstore: QdrantVectorStore = None
29
- retriever = None
 
 
 
 
 
30
 
31
 
32
  state = ApplicationState()
 
33
 
34
 
35
  @cl.on_chat_start
36
  async def on_chat_start():
37
- params = ApplicationParameters()
38
-
39
- await cl.Message(content=f"Loading file {params.filename}").send()
40
- data = json.load(open(params.filename, "rb"))
41
 
42
- state.embeddings = OpenAIEmbeddings(model=params.embedding_model)
43
- state.docs = pstuts_rag.datastore.transcripts_load(data, state.embeddings)
44
- await cl.Message(
45
- content=f"Loaded {len(state.docs)} chunks from file {params.filename}."
46
- ).send()
47
-
48
- state.qdrantclient = QdrantClient(":memory:")
49
-
50
- state.vectorstore = pstuts_rag.datastore.initialize_vectorstore(
51
- client=state.qdrantclient,
52
- collection_name=f"{params.filename}_qdrant",
53
- embeddings=state.embeddings,
54
  )
55
-
56
- _ = state.vectorstore.add_documents(documents=state.docs)
57
- state.retriever = state.vectorstore.as_retriever(
58
- search_kwargs={"k": params.n_context_docs}
 
59
  )
60
-
61
- await cl.Message(content=f"Populated vector database.").send()
62
 
63
 
64
  @cl.on_message
65
  async def main(message: cl.Message):
66
  # Send a response back to the user
67
 
68
- v = await state.retriever.ainvoke(message.content)
69
 
70
- await cl.Message(content=f"Hello! {len(v)}").send()
71
 
72
 
73
  if __name__ == "__main__":
 
1
+ import asyncio
2
  from typing import List
3
  import chainlit as cl
4
  import json
5
+ import os
6
 
7
+ from dotenv import load_dotenv
8
  from langchain_experimental.text_splitter import SemanticChunker
9
+ from langchain_openai import ChatOpenAI
10
  from langchain_openai.embeddings import OpenAIEmbeddings
11
  from langchain_core.documents import Document
12
 
13
  from langchain_qdrant import QdrantVectorStore
14
+ from pstuts_rag.rag import RAGChainFactory, RetrieverFactory
15
  from qdrant_client import QdrantClient
16
  from qdrant_client.http.models import Distance, VectorParams
17
  from dataclasses import dataclass
18
 
19
+ import pstuts_rag.rag
20
 
21
 
22
  @dataclass
 
24
  filename = "data/test.json"
25
  embedding_model = "text-embedding-3-small"
26
  n_context_docs = 2
27
+ llm_model = "gpt-4.1-mini"
28
+
29
+
30
+ def set_api_key_if_not_present(key_name, prompt_message=""):
31
+ if len(prompt_message) == 0:
32
+ prompt_message = key_name
33
+ if key_name not in os.environ or not os.environ[key_name]:
34
+ os.environ[key_name] = getpass.getpass(prompt_message)
35
 
36
 
37
  class ApplicationState:
 
39
  docs: List[Document] = []
40
  qdrantclient: QdrantClient = None
41
  vectorstore: QdrantVectorStore = None
42
+ retriever_factory: pstuts_rag.rag.RetrieverFactory
43
+ rag_factory: pstuts_rag.rag.RAGChainFactory
44
+
45
+ def __init__(self) -> None:
46
+ load_dotenv()
47
+ set_api_key_if_not_present("OPENAI_API_KEY")
48
 
49
 
50
  state = ApplicationState()
51
+ params = ApplicationParameters()
52
 
53
 
54
  @cl.on_chat_start
55
  async def on_chat_start():
56
+ state.client = QdrantClient(":memory:")
 
 
 
57
 
58
+ state.retriever_factory = pstuts_rag.rag.RetrieverFactory(
59
+ qdrant_client=state.client, name="local_test"
 
 
 
 
 
 
 
 
 
 
60
  )
61
+ if state.retriever_factory.count_docs() == 0:
62
+ data: List[Dict[str, Any]] = json.load(open(params.filename, "rb"))
63
+ asyncio.run(main=state.retriever_factory.aadd_docs(raw_docs=data))
64
+ state.rag_factory = pstuts_rag.rag.RAGChainFactory(
65
+ retriever=state.retriever_factory.get_retriever()
66
  )
67
+ state.llm = ChatOpenAI(model=params.llm_model, temperature=0)
68
+ state.rag_chain = state.rag_factory.get_rag_chain(state.llm)
69
 
70
 
71
  @cl.on_message
72
  async def main(message: cl.Message):
73
  # Send a response back to the user
74
 
75
+ v = await state.rag_chain.ainvoke(message.content)
76
 
77
+ await cl.Message(content=v.content).send()
78
 
79
 
80
  if __name__ == "__main__":
notebooks/transcript_rag.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
pstuts_rag/pstuts_rag/datastore.py CHANGED
@@ -1,4 +1,6 @@
 
1
  from typing import List, Dict, Iterator, Any
 
2
 
3
 
4
  from langchain_experimental.text_splitter import SemanticChunker
@@ -7,14 +9,22 @@ from langchain_core.documents import Document
7
 
8
  from .loader import VideoTranscriptBulkLoader, VideoTranscriptLoader
9
 
 
 
10
  from langchain_qdrant import QdrantVectorStore
11
  from qdrant_client import QdrantClient
12
  from qdrant_client.http.models import Distance, VectorParams
 
 
13
 
 
 
 
14
 
15
- def transcripts_load(
 
16
  json_transcripts: List[Dict[str, Any]],
17
- embeddings: OpenAIEmbeddings = OpenAIEmbeddings(
18
  model="text-embedding-3-small"
19
  ),
20
  ) -> List[Document]:
@@ -40,12 +50,21 @@ def transcripts_load(
40
  json_payload=json_transcripts
41
  ).load()
42
 
43
- text_splitter = SemanticChunker(embeddings)
44
-
45
- docs_chunks_semantic: List[Document] = text_splitter.split_documents(
46
- docs_full_transcript
 
 
 
47
  )
 
 
 
 
48
 
 
 
49
  def is_subchunk(a: Document, ofb: Document) -> bool:
50
  return (a.metadata["video_id"] == ofb.metadata["video_id"]) and (
51
  a.page_content in ofb.page_content
@@ -83,35 +102,129 @@ def transcripts_load(
83
  else:
84
  chunk.metadata["start"], chunk.metadata["stop"] = None, None
85
 
 
86
  return docs_chunks_semantic
87
 
88
 
89
- def initialize_vectorstore(
90
- client: QdrantClient, collection_name: str, embeddings: OpenAIEmbeddings
91
- ) -> QdrantVectorStore:
 
 
 
 
 
 
 
 
 
92
  """
93
- Initialize a Qdrant vector store with the given client and collection name.
94
 
95
- This function creates a new collection in Qdrant and initializes a vector
96
- store with the specified embeddings model. The collection is configured
97
- with appropriate vector parameters for the embedding model.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
- Args:
100
- client: QdrantClient instance to use for connecting to the database
101
- collection_name: Name to use for the new collection
102
- embeddings: OpenAI embeddings model to use for vector encoding
 
 
103
 
104
- Returns:
105
- Initialized QdrantVectorStore instance ready for document storage
106
- """
107
- client.create_collection(
108
- collection_name=collection_name,
109
- vectors_config=VectorParams(size=1536, distance=Distance.COSINE),
110
- )
111
 
112
- vector_store = QdrantVectorStore(
113
- client=client,
114
- collection_name=collection_name,
115
- embedding=embeddings,
116
- )
117
- return vector_store
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
  from typing import List, Dict, Iterator, Any
3
+ import uuid
4
 
5
 
6
  from langchain_experimental.text_splitter import SemanticChunker
 
9
 
10
  from .loader import VideoTranscriptBulkLoader, VideoTranscriptLoader
11
 
12
+ from langchain_core.vectorstores import VectorStoreRetriever
13
+
14
  from langchain_qdrant import QdrantVectorStore
15
  from qdrant_client import QdrantClient
16
  from qdrant_client.http.models import Distance, VectorParams
17
+ from qdrant_client.models import VectorParams, Distance, PointStruct
18
+
19
 
20
+ def batch(iterable: List[Any], size: int = 16) -> Iterator[List[Any]]:
21
+ for i in range(0, len(iterable), size):
22
+ yield iterable[i : i + size]
23
 
24
+
25
+ async def chunk_transcripts(
26
  json_transcripts: List[Dict[str, Any]],
27
+ semantic_chunker_embedding_model: OpenAIEmbeddings = OpenAIEmbeddings(
28
  model="text-embedding-3-small"
29
  ),
30
  ) -> List[Document]:
 
50
  json_payload=json_transcripts
51
  ).load()
52
 
53
+ # semantically split the combined transcript
54
+ text_splitter = SemanticChunker(semantic_chunker_embedding_model)
55
+ docs_group = await asyncio.gather(
56
+ *[
57
+ text_splitter.atransform_documents(d)
58
+ for d in batch(docs_full_transcript, size=2)
59
+ ]
60
  )
61
+ # Flatten the nested list of documents
62
+ docs_chunks_semantic: List[Document] = []
63
+ for group in docs_group:
64
+ docs_chunks_semantic.extend(group)
65
 
66
+ # locate individual sections of the original transcript
67
+ # with the semantic chunks
68
  def is_subchunk(a: Document, ofb: Document) -> bool:
69
  return (a.metadata["video_id"] == ofb.metadata["video_id"]) and (
70
  a.page_content in ofb.page_content
 
102
  else:
103
  chunk.metadata["start"], chunk.metadata["stop"] = None, None
104
 
105
+ docs_chunks_semantic[0].metadata.keys()
106
  return docs_chunks_semantic
107
 
108
 
109
+ class DatastoreManager:
110
+ """Factory class for creating and managing vector store retrievers.
111
+
112
+ This class simplifies the process of creating, populating, and managing
113
+ Qdrant vector stores for document retrieval.
114
+
115
+ Attributes:
116
+ embeddings: OpenAI embeddings model for document vectorization
117
+ docs: List of documents stored in the vector store
118
+ qdrant_client: Client for Qdrant vector database
119
+ name: Unique identifier for this retriever instance
120
+ vector_store: The Qdrant vector store instance
121
  """
 
122
 
123
+ embeddings: OpenAIEmbeddings
124
+ docs: List[Document]
125
+ qdrant_client: QdrantClient
126
+ name: str
127
+ vector_store: QdrantVectorStore
128
+
129
+ def __init__(
130
+ self,
131
+ embeddings: OpenAIEmbeddings = OpenAIEmbeddings(
132
+ model="text-embedding-3-small"
133
+ ),
134
+ qdrant_client: QdrantClient = QdrantClient(location=":memory:"),
135
+ name: str = str(object=uuid.uuid4()),
136
+ ) -> None:
137
+ """Initialize the RetrieverFactory.
138
+
139
+ Args:
140
+ embeddings: OpenAI embeddings model to use
141
+ qdrant_client: Qdrant client for vector database operations
142
+ name: Unique identifier for this retriever instance
143
+ """
144
+ self.embeddings = embeddings
145
+ self.name = name
146
+ self.qdrant_client = qdrant_client
147
+
148
+ self.qdrant_client.recreate_collection(
149
+ collection_name=self.name,
150
+ vectors_config=VectorParams(size=1536, distance=Distance.COSINE),
151
+ )
152
 
153
+ # wrapper around the client
154
+ self.vector_store = QdrantVectorStore(
155
+ client=self.qdrant_client,
156
+ collection_name=self.name,
157
+ embedding=embeddings,
158
+ )
159
 
160
+ self.docs = []
 
 
 
 
 
 
161
 
162
+ async def populate_database(self, raw_docs: List[Dict[str, Any]]):
163
+
164
+ # perform chunking
165
+ self.docs: List[Document] = await chunk_transcripts(
166
+ json_transcripts=raw_docs,
167
+ semantic_chunker_embedding_model=self.embeddings,
168
+ )
169
+
170
+ # perform embedding
171
+
172
+ vector_batches = await asyncio.gather(
173
+ *[
174
+ self.embeddings.aembed_documents(
175
+ [c.page_content for c in chunk_batch]
176
+ )
177
+ for chunk_batch in batch(self.docs, 8)
178
+ ]
179
+ )
180
+ vectors = []
181
+ for vb in vector_batches:
182
+ vectors.extend(vb)
183
+ ids = list(range(len(vectors)))
184
+
185
+ points = [
186
+ PointStruct(
187
+ id=id,
188
+ vector=vector,
189
+ payload={
190
+ "page_content": doc.page_content,
191
+ "metadata": doc.metadata,
192
+ },
193
+ )
194
+ for id, vector, doc in zip(ids, vectors, self.docs)
195
+ ]
196
+
197
+ # upload qdrant payload
198
+ self.qdrant_client.upload_points(
199
+ collection_name=self.name,
200
+ points=points,
201
+ )
202
+
203
+ def count_docs(self) -> int:
204
+ try:
205
+ count = self.qdrant_client.get_collection(self.name).points_count
206
+ return count if count else 0
207
+ except ValueError:
208
+ return 0
209
+
210
+ def clear(self) -> bool:
211
+ """Clear all documents from the vector store.
212
+
213
+ Returns:
214
+ bool: True if deletion was successful, False otherwise
215
+ """
216
+ self.docs = []
217
+ return True if self.vector_store.delete() else False
218
+
219
+ def get_retriever(self, n_context_docs: int = 2) -> VectorStoreRetriever:
220
+ """Get a retriever for the vector store.
221
+
222
+ Args:
223
+ n_context_docs: Number of documents to retrieve for each query
224
+
225
+ Returns:
226
+ VectorStoreRetriever: The configured retriever
227
+ """
228
+ return self.vector_store.as_retriever(
229
+ search_kwargs={"k": n_context_docs}
230
+ )
pstuts_rag/pstuts_rag/rag.py CHANGED
@@ -6,6 +6,7 @@ This module provides the core RAG functionality, including:
6
  """
7
 
8
  import json
 
9
  import uuid
10
  from operator import itemgetter
11
  from typing import Dict, List, Any
@@ -25,96 +26,12 @@ from langchain.prompts import ChatPromptTemplate
25
  from langchain_core.vectorstores import VectorStoreRetriever
26
  from langchain_openai import ChatOpenAI
27
 
28
- from .datastore import initialize_vectorstore, transcripts_load
29
  from .prompt_templates import RAG_PROMPT_TEMPLATES
30
 
31
  from langchain_core.language_models.base import BaseLanguageModel
32
  from langchain_core.messages import AIMessage
33
 
34
 
35
- class RetrieverFactory:
36
- """Factory class for creating and managing vector store retrievers.
37
-
38
- This class simplifies the process of creating, populating, and managing
39
- Qdrant vector stores for document retrieval.
40
-
41
- Attributes:
42
- embeddings: OpenAI embeddings model for document vectorization
43
- docs: List of documents stored in the vector store
44
- qdrant_client: Client for Qdrant vector database
45
- name: Unique identifier for this retriever instance
46
- vector_store: The Qdrant vector store instance
47
- """
48
-
49
- embeddings: OpenAIEmbeddings
50
- docs: List[Document]
51
- qdrant_client: QdrantClient
52
- name: str
53
- vector_store: QdrantVectorStore
54
-
55
- def __init__(
56
- self,
57
- embeddings: OpenAIEmbeddings = OpenAIEmbeddings(
58
- model="text-embedding-3-small"
59
- ),
60
- qdrant_client: QdrantClient = QdrantClient(location=":memory:"),
61
- name: str = str(object=uuid.uuid4()),
62
- ) -> None:
63
- """Initialize the RetrieverFactory.
64
-
65
- Args:
66
- embeddings: OpenAI embeddings model to use
67
- qdrant_client: Qdrant client for vector database operations
68
- name: Unique identifier for this retriever instance
69
- """
70
- self.embeddings = embeddings
71
- self.name = name
72
- self.qdrant_client = qdrant_client
73
- self.vector_store = initialize_vectorstore(
74
- client=self.qdrant_client,
75
- collection_name=f"{self.name}_qdrant",
76
- embeddings=self.embeddings,
77
- )
78
- self.docs = []
79
-
80
- def add_docs(self, raw_docs: List[Dict[str, Any]]) -> None:
81
- """Add documents to the vector store.
82
-
83
- Takes raw document data, converts it to Document objects,
84
- and adds them to the vector store.
85
-
86
- Args:
87
- raw_docs: List of raw document dictionaries
88
- """
89
- docs: List[Document] = transcripts_load(
90
- json_transcripts=raw_docs, embeddings=self.embeddings
91
- )
92
- self.docs.extend(docs)
93
- _ = self.vector_store.add_documents(documents=docs)
94
-
95
- def clear(self) -> bool:
96
- """Clear all documents from the vector store.
97
-
98
- Returns:
99
- bool: True if deletion was successful, False otherwise
100
- """
101
- self.docs = []
102
- return True if self.vector_store.delete() else False
103
-
104
- def get_retriever(self, n_context_docs: int = 2) -> VectorStoreRetriever:
105
- """Get a retriever for the vector store.
106
-
107
- Args:
108
- n_context_docs: Number of documents to retrieve for each query
109
-
110
- Returns:
111
- VectorStoreRetriever: The configured retriever
112
- """
113
- return self.vector_store.as_retriever(
114
- search_kwargs={"k": n_context_docs}
115
- )
116
-
117
-
118
  class RAGChainFactory:
119
  """Factory class for creating RAG (Retrieval Augmented Generation) chains.
120
 
 
6
  """
7
 
8
  import json
9
+ from multiprocessing import Value
10
  import uuid
11
  from operator import itemgetter
12
  from typing import Dict, List, Any
 
26
  from langchain_core.vectorstores import VectorStoreRetriever
27
  from langchain_openai import ChatOpenAI
28
 
 
29
  from .prompt_templates import RAG_PROMPT_TEMPLATES
30
 
31
  from langchain_core.language_models.base import BaseLanguageModel
32
  from langchain_core.messages import AIMessage
33
 
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  class RAGChainFactory:
36
  """Factory class for creating RAG (Retrieval Augmented Generation) chains.
37
 
pyproject.toml CHANGED
@@ -68,14 +68,16 @@ known-first-party = ["src"]
68
  line-length = 79
69
  target-version = ["py313"]
70
 
 
 
 
 
 
 
 
 
71
  [tool.mypy]
72
- python_version = "3.13"
73
- warn_return_any = true
74
- warn_unused_configs = true
75
- disallow_untyped_defs = true
76
- mypy_path = ["pstuts_rag/pstuts_rag"]
77
- namespace_packages = true
78
- explicit_package_bases = true
79
 
80
  [tool.flake8]
81
  application-import-names = "pstuts_rag"
 
68
  line-length = 79
69
  target-version = ["py313"]
70
 
71
+ # [tool.mypy]
72
+ # python_version = "3.13"
73
+ # warn_return_any = true
74
+ # warn_unused_configs = true
75
+ # disallow_untyped_defs = true
76
+ # mypy_path = ["pstuts_rag/pstuts_rag"]
77
+ # namespace_packages = true
78
+ # explicit_package_bases = true
79
  [tool.mypy]
80
+ ignore_errors = true
 
 
 
 
 
 
81
 
82
  [tool.flake8]
83
  application-import-names = "pstuts_rag"