mbudisic committed on
Commit
b4c3986
·
1 Parent(s): be858e2

chore: Update .gitignore and refactor DatastoreManager for Qdrant integration

Browse files

- Added 'qdrant/' to .gitignore to prevent tracking of Qdrant-related files.
- Refactored `DatastoreManager` to utilize `QdrantClientSingleton` for thread-safe Qdrant client management.
- Enhanced collection creation logic to handle existing collections gracefully, improving robustness and logging.

Files changed (4) hide show
  1. .gitignore +1 -0
  2. app.py +10 -5
  3. docs/DEVELOPER.md +25 -0
  4. pstuts_rag/pstuts_rag/datastore.py +73 -41
.gitignore CHANGED
@@ -8,3 +8,4 @@ __pycache__/
8
  .embeddings_cache/
9
  notebooks/*/
10
  *.pckl
 
 
8
  .embeddings_cache/
9
  notebooks/*/
10
  *.pckl
11
+ qdrant/
app.py CHANGED
@@ -48,11 +48,16 @@ async def on_chat_start():
48
  thread_id = f"chat_{uuid4().hex[:8]}"
49
  configuration.thread_id = thread_id
50
 
51
- datastore = await asyncio.to_thread(
52
- lambda: DatastoreManager(config=configuration).add_completion_callback(
53
- lambda: cl.run_sync(
54
- cl.Message(content="Datastore loading completed.").send()
55
- )
 
 
 
 
 
56
  )
57
  )
58
 
 
48
  thread_id = f"chat_{uuid4().hex[:8]}"
49
  configuration.thread_id = thread_id
50
 
51
+ # datastore = await asyncio.to_thread(
52
+ # lambda: DatastoreManager(config=configuration).add_completion_callback(
53
+ # lambda: cl.run_sync(
54
+ # cl.Message(content="Datastore loading completed.").send()
55
+ # )
56
+ # )
57
+ datastore = DatastoreManager(config=configuration)
58
+ datastore.add_completion_callback(
59
+ lambda: cl.run_sync(
60
+ cl.Message(content="Datastore loading completed.").send()
61
  )
62
  )
63
 
docs/DEVELOPER.md CHANGED
@@ -89,6 +89,31 @@ pip install -e ".[dev,web]" # Core + dev + web server
89
  - **`RAG for Transcripts`** (`rag_for_transcripts.py`): Implements the RAG chain for searching video transcripts, including reference packing and post-processing for AIMessage responses. Used for context-rich, reference-annotated answers from video data. 🎬
90
  - **`Graph Assembly`** (`graph.py`): Handles agent node creation, LangGraph assembly, and integration of multi-agent workflows. Provides utilities for building, initializing, and running the agentic graph. 🕸️
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  #### 🕸️ Multi-Agent System
93
  - **`PsTutsTeamState`** (`state.py`): TypedDict managing multi-agent conversation state
94
  - **Agent creation functions** (`graph.py`): Factory functions for different agent types:
 
89
  - **`RAG for Transcripts`** (`rag_for_transcripts.py`): Implements the RAG chain for searching video transcripts, including reference packing and post-processing for AIMessage responses. Used for context-rich, reference-annotated answers from video data. 🎬
90
  - **`Graph Assembly`** (`graph.py`): Handles agent node creation, LangGraph assembly, and integration of multi-agent workflows. Provides utilities for building, initializing, and running the agentic graph. 🕸️
91
 
92
+ #### 🗄️ QdrantClientSingleton (datastore.py)
93
+ - **Purpose:** Ensures only one instance of QdrantClient exists per process, preventing accidental concurrent access to embedded Qdrant. Thread-safe and logs every access!
94
+ - **Usage:**
95
+ ```python
96
+ from pstuts_rag.datastore import QdrantClientSingleton
97
+ client = QdrantClientSingleton.get_client(path="/path/to/db") # or path=None for in-memory
98
+ ```
99
+ - **Behavior:**
100
+ - First call determines the storage location (persistent or in-memory)
101
+ - All subsequent calls return the same client; a different path passed on a later call is ignored and does not switch storage
102
+ - Thread-safe via a lock
103
+ - Every call logs the requested path for debugging 🪵
104
+
105
+ #### 🏪 DatastoreManager (datastore.py)
106
+ - **Collection Creation Logic:**
107
+ - On initialization, always tries to create the Qdrant collection for the vector store.
108
+ - If the collection already exists, catches the `ValueError` and simply fetches the existing collection instead (no crash, no duplicate creation!).
109
+ - This is the recommended robust pattern for Qdrant local mode. 🦺
110
+ - Example log output:
111
+ ```
112
+ Collection EVA_AI_transcripts created.
113
+ # or
114
+ Collection EVA_AI_transcripts already exists.
115
+ ```
116
+
117
  #### 🕸️ Multi-Agent System
118
  - **`PsTutsTeamState`** (`state.py`): TypedDict managing multi-agent conversation state
119
  - **Agent creation functions** (`graph.py`): Factory functions for different agent types:
pstuts_rag/pstuts_rag/datastore.py CHANGED
@@ -7,6 +7,7 @@ from pathlib import Path
7
  from typing import List, Dict, Iterator, Any, Callable, Optional, Self
8
  import uuid
9
  import logging
 
10
 
11
  import chainlit as cl
12
  from langchain_core.document_loaders import BaseLoader
@@ -28,6 +29,37 @@ from pstuts_rag.utils import get_embeddings_api, flatten, batch
28
  from pathvalidate import sanitize_filename, sanitize_filepath
29
 
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  class DatastoreManager:
32
  """Factory class for creating and managing vector store retrievers.
33
 
@@ -47,6 +79,7 @@ class DatastoreManager:
47
  embeddings: Embeddings
48
  docs: List[Document]
49
  qdrant_client: QdrantClient
 
50
  name: str
51
  vector_store: QdrantVectorStore
52
  dimensions: int
@@ -59,7 +92,7 @@ class DatastoreManager:
59
  self,
60
  embeddings: Optional[Embeddings] = None,
61
  qdrant_client: QdrantClient | None = None,
62
- name: str = str(object=uuid.uuid4()),
63
  config: Configuration = Configuration(),
64
  ) -> None:
65
  """Initialize the RetrieverFactory.
@@ -71,7 +104,6 @@ class DatastoreManager:
71
  """
72
 
73
  if embeddings is None:
74
-
75
  cls = get_embeddings_api(config.embedding_api)
76
  self.embeddings = cls(model=config.embedding_model)
77
  else:
@@ -80,33 +112,22 @@ class DatastoreManager:
80
  self.name = name if name else config.eva_workflow_name
81
 
82
  if qdrant_client is None:
83
-
84
- try:
85
- if (
86
- config.db_persist
87
- and isinstance(config.db_persist, str)
88
- and len(config.db_persist) > 0
89
- ):
90
- qdrant_path = Path(
91
- sanitize_filepath(config.db_persist)
92
- ) / sanitize_filename(config.embedding_model)
93
- logging.info(
94
- "Persisting the datastore to: %s",
95
- str(qdrant_path),
96
- )
97
-
98
- qdrant_path.mkdir(parents=True, exist_ok=True)
99
-
100
- qdrant_client = QdrantClient(path=str(qdrant_path))
101
- except (OSError, ValueError) as e:
102
- logging.error(
103
- "Persistence aborted, exception occurred: %s: %s",
104
- type(e).__name__,
105
- str(e),
106
  )
107
- finally:
108
- if qdrant_client is None:
109
- qdrant_client = QdrantClient(location=":memory:")
110
 
111
  self.qdrant_client = qdrant_client
112
  atexit.register(qdrant_client.close)
@@ -116,18 +137,24 @@ class DatastoreManager:
116
 
117
  # determine embedding dimension
118
  self.dimensions = len(self.embeddings.embed_query("test"))
119
-
120
- self.qdrant_client.recreate_collection(
121
- collection_name=self.name,
122
- vectors_config=VectorParams(
123
- size=self.dimensions, distance=Distance.COSINE
124
- ),
125
- )
 
 
 
 
 
 
126
 
127
  # wrapper around the client
128
  self.vector_store = QdrantVectorStore(
129
  client=self.qdrant_client,
130
- collection_name=self.name,
131
  embedding=self.embeddings,
132
  )
133
 
@@ -219,10 +246,13 @@ class DatastoreManager:
219
  ]
220
 
221
  # upload qdrant payload
222
- self.qdrant_client.upload_points(
223
- collection_name=self.name,
224
- points=points,
225
- )
 
 
 
226
 
227
  return len(points)
228
 
@@ -238,7 +268,9 @@ class DatastoreManager:
238
  This method is safe to call even if the collection doesn't exist
239
  """
240
  try:
241
- count = self.qdrant_client.get_collection(self.name).points_count
 
 
242
  return count if count else 0
243
  except ValueError:
244
  return 0
 
7
  from typing import List, Dict, Iterator, Any, Callable, Optional, Self
8
  import uuid
9
  import logging
10
+ import threading
11
 
12
  import chainlit as cl
13
  from langchain_core.document_loaders import BaseLoader
 
29
  from pathvalidate import sanitize_filename, sanitize_filepath
30
 
31
 
32
class QdrantClientSingleton:
    """Process-wide, lock-guarded holder for a single QdrantClient.

    The first call to :meth:`get_client` decides the storage location
    (a local path, or in-memory when ``path`` is None). Every later call
    returns that same client; a different ``path`` on a later call is
    ignored, and a warning is logged so the mismatch is visible.
    Every invocation of :meth:`get_client` is logged.
    """

    # The shared client; stays None until the first get_client call.
    _instance = None
    # Guards both creation and lookup of the shared client.
    _lock = threading.Lock()
    # Location chosen by the first call: the path, or ":memory:".
    _config = None

    @classmethod
    def get_client(cls, path=None):
        """Return the shared client, creating it on the first call.

        Args:
            path: Storage path handed to ``QdrantClient(path=...)``, or None
                to create an in-memory client. Only the call that actually
                creates the client honours this value.

        Returns:
            The single client instance held by this class.
        """
        logging.info(
            "QdrantClientSingleton.get_client called with path=%r", path
        )
        requested = ":memory:" if path is None else path

        with cls._lock:
            if cls._instance is None:
                # Imported here so the dependency is only touched when a
                # client really has to be constructed.
                from qdrant_client import QdrantClient

                if path is None:
                    cls._instance = QdrantClient(location=":memory:")
                else:
                    cls._instance = QdrantClient(path=path)
                cls._config = requested
            elif str(requested) != str(cls._config):
                # The first-initialized client always wins; surface the
                # mismatch instead of silently serving another store.
                logging.warning(
                    "QdrantClientSingleton already initialized with %r; "
                    "ignoring requested path=%r",
                    cls._config,
                    path,
                )
            return cls._instance
61
+
62
+
63
  class DatastoreManager:
64
  """Factory class for creating and managing vector store retrievers.
65
 
 
79
  embeddings: Embeddings
80
  docs: List[Document]
81
  qdrant_client: QdrantClient
82
+ collection_name: str
83
  name: str
84
  vector_store: QdrantVectorStore
85
  dimensions: int
 
92
  self,
93
  embeddings: Optional[Embeddings] = None,
94
  qdrant_client: QdrantClient | None = None,
95
+ name: str = "EVA_AI",
96
  config: Configuration = Configuration(),
97
  ) -> None:
98
  """Initialize the RetrieverFactory.
 
104
  """
105
 
106
  if embeddings is None:
 
107
  cls = get_embeddings_api(config.embedding_api)
108
  self.embeddings = cls(model=config.embedding_model)
109
  else:
 
112
  self.name = name if name else config.eva_workflow_name
113
 
114
  if qdrant_client is None:
115
+ # Use the singleton for QdrantClient
116
+ path = None
117
+ if (
118
+ config.db_persist
119
+ and isinstance(config.db_persist, str)
120
+ and len(config.db_persist) > 0
121
+ ):
122
+ qdrant_path = Path(
123
+ sanitize_filepath(config.db_persist)
124
+ ) / sanitize_filename(config.embedding_model)
125
+ logging.info(
126
+ "Persisting the datastore to: %s", str(qdrant_path)
 
 
 
 
 
 
 
 
 
 
 
127
  )
128
+ qdrant_path.mkdir(parents=True, exist_ok=True)
129
+ path = str(qdrant_path)
130
+ qdrant_client = QdrantClientSingleton.get_client(path=path)
131
 
132
  self.qdrant_client = qdrant_client
133
  atexit.register(qdrant_client.close)
 
137
 
138
  # determine embedding dimension
139
  self.dimensions = len(self.embeddings.embed_query("test"))
140
+ self.collection_name = self.name + "_transcripts"
141
+ # Try to create the collection, fall back to get_collection if it already exists
142
+ try:
143
+ self.qdrant_client.create_collection(
144
+ collection_name=self.collection_name,
145
+ vectors_config=VectorParams(
146
+ size=self.dimensions, distance=Distance.COSINE
147
+ ),
148
+ )
149
+ logging.info(f"Collection {self.collection_name} created.")
150
+ except ValueError:
151
+ self.qdrant_client.get_collection(self.collection_name)
152
+ logging.info(f"Collection {self.collection_name} already exists.")
153
 
154
  # wrapper around the client
155
  self.vector_store = QdrantVectorStore(
156
  client=self.qdrant_client,
157
+ collection_name=self.collection_name,
158
  embedding=self.embeddings,
159
  )
160
 
 
246
  ]
247
 
248
  # upload qdrant payload
249
+ if self.count_docs() == len(points):
250
+ logging.info("Qdrant database populated. Skipping upload")
251
+ else:
252
+ self.qdrant_client.upload_points(
253
+ collection_name=self.collection_name,
254
+ points=points,
255
+ )
256
 
257
  return len(points)
258
 
 
268
  This method is safe to call even if the collection doesn't exist
269
  """
270
  try:
271
+ count = self.qdrant_client.get_collection(
272
+ self.collection_name
273
+ ).points_count
274
  return count if count else 0
275
  except ValueError:
276
  return 0