VocRT / qdrent.py
Anurag
update readme
ecaf3da
from qdrant_client import QdrantClient
from qdrant_client.http import models as rest
import numpy as np
import uuid
import asyncio
# Module-level Qdrant connection used by every function in this module.
client = QdrantClient(
    host="localhost",
    port=6333,
)
# The single collection that holds the points of every session.
collection_name = "vocRT_collection"
async def initialize_collection():
    """
    Make sure the Qdrant collection named by ``collection_name`` exists.

    Lists the collections on the server. If ``collection_name`` is missing,
    creates it with 768-dimensional vectors compared by cosine distance.
    A status message is printed in both cases.

    NOTE(review): declared ``async`` but every client call here is
    synchronous, so it blocks the event loop while it runs.
    """
    existing_names = {col.name for col in client.get_collections().collections}
    if collection_name in existing_names:
        print(f"Collection '{collection_name}' already exists.")
        return

    # The vector size must equal the dimension of the embeddings that are
    # later passed to store_embeddings.
    client.create_collection(
        collection_name=collection_name,
        vectors_config=rest.VectorParams(
            size=768,
            distance=rest.Distance.COSINE,
        ),
    )
    print(f"Collection '{collection_name}' created.")
# Ensure the collection exists as soon as the module is imported.
# asyncio.run() raises RuntimeError when an event loop is already running
# (e.g. importing this module from a notebook or from inside a coroutine).
# In that case, schedule the initialization on the running loop instead.
try:
    _running_loop = asyncio.get_running_loop()
except RuntimeError:
    _running_loop = None

if _running_loop is None:
    asyncio.run(initialize_collection())
else:
    # Keep a reference so the task is not garbage-collected before it runs.
    _init_task = _running_loop.create_task(initialize_collection())
async def store_embeddings(session_id, embeddings, texts=None, name="", title="", summary="", categories=""):
    """
    Store embeddings for a specific session_id.

    Ensures the collection exists, then writes one point per embedding in a
    single ``upsert`` call (``wait=True``). Every point gets a freshly
    generated random UUID. Its payload holds ``session_id`` plus the optional
    metadata below.

    Parameters:
    - session_id (str): Unique identifier for the session/user.
    - embeddings (list of numpy arrays or lists): The embeddings to store.
      numpy arrays are converted with ``tolist()``.
    - texts (list of str, optional): Corresponding text passages, stored as
      ``text``. Must have the same length as ``embeddings``.
    - name (str, optional): Stored as ``filename`` unless None.
    - title (str, optional): Stored as ``title`` unless None.
    - summary (str, optional): Stored as ``summary`` unless None.
    - categories (str or iterable of str, optional): Either a comma-separated
      string (entries are stripped, empty entries dropped) or any iterable.
      Stored as the list ``categories`` unless None.

    The metadata defaults are "" rather than None. So ``filename``, ``title``,
    ``summary`` and ``categories`` are always written to the payload unless
    the caller passes None explicitly.

    Raises:
    - ValueError: if ``texts`` is given and its length differs from
      ``embeddings``.
    """
    await initialize_collection()
    if texts is not None and len(embeddings) != len(texts):
        raise ValueError(
            "The number of embeddings and texts must be the same.")

    # Metadata that is identical for every point is computed once. Besides
    # avoiding repeated work, this fixes a bug: a one-shot iterable of
    # categories (e.g. a generator) was consumed while building the first
    # point, so every later point was stored with an empty category list.
    shared_metadata = {}
    if name is not None:
        shared_metadata['filename'] = name
    if title is not None:
        shared_metadata['title'] = title
    if summary is not None:
        shared_metadata['summary'] = summary
    if categories is not None:
        if isinstance(categories, str):
            categories_list = [cat.strip()
                               for cat in categories.split(',') if cat.strip()]
        else:
            categories_list = list(categories)
        shared_metadata['categories'] = categories_list

    points = []
    for idx, embedding in enumerate(embeddings):
        # Payload keys: session_id, then the optional text, then shared metadata.
        payload = {'session_id': session_id}
        if texts is not None:
            payload['text'] = texts[idx]
        payload.update(shared_metadata)
        points.append(rest.PointStruct(
            id=str(uuid.uuid4()),
            vector=embedding.tolist() if isinstance(embedding, np.ndarray) else embedding,
            payload=payload,
        ))

    client.upsert(
        collection_name=collection_name,
        wait=True,
        points=points
    )
    print(f"Embeddings stored for session_id: {session_id}")
def search_embeddings(session_id, query_embedding, limit=10):
    """
    Return the points most similar to ``query_embedding`` within one session.

    Only points whose ``session_id`` payload field equals ``session_id`` are
    searched, and payloads are included in the hits.

    Parameters:
    - session_id (str): Unique identifier for the session/user.
    - query_embedding (numpy array or list): The query vector. A numpy array
      is converted to a plain list first.
    - limit (int): Maximum number of hits to return.

    Returns:
    - The result of ``client.search``: the scored points (id, score, payload)
      for this session.

    NOTE(review): newer qdrant-client releases deprecate ``search`` in favour
    of ``query_points``; confirm against the installed version.
    """
    vector = (query_embedding.tolist()
              if isinstance(query_embedding, np.ndarray)
              else query_embedding)

    # Restrict the nearest-neighbour search to this session's points.
    same_session = rest.Filter(
        must=[
            rest.FieldCondition(
                key='session_id',
                match=rest.MatchValue(value=session_id),
            ),
        ],
    )

    return client.search(
        collection_name=collection_name,
        query_vector=vector,
        query_filter=same_session,
        limit=limit,
        with_payload=True,
    )
def delete_embeddings(session_id):
    """
    Delete all embeddings for a specific session_id.

    Parameters:
    - session_id (str): Unique identifier for the session/user.

    Returns:
    - bool: True if the matching points were deleted or none existed.
      False if any Qdrant call failed; the error is printed, not raised.
    """
    session_filter = rest.Filter(
        must=[
            rest.FieldCondition(
                key='session_id',
                match=rest.MatchValue(value=session_id)
            )
        ]
    )
    # Every Qdrant call is inside the try, so lookup failures are reported
    # the same way as delete failures (previously a failing scroll raised
    # while a failing delete returned False).
    try:
        # Counting preserves the distinct "nothing to delete" message/result.
        matching = client.count(
            collection_name=collection_name,
            count_filter=session_filter,
            exact=True,
        ).count
        if not matching:
            print(f"No embeddings found for session_id: {session_id}")
            return True
        # Delete by filter in one server-side request instead of paging every
        # point id to the client and sending them all back. This avoids
        # holding all ids in memory and building an oversized request. It also
        # avoids missing points added between the scan and the delete.
        client.delete(
            collection_name=collection_name,
            points_selector=rest.FilterSelector(filter=session_filter),
        )
    except Exception as e:
        print("Error in deleting embeddings : ", e)
        return False
    print(f"Deleted embeddings for session_id: {session_id}")
    return True