|
from qdrant_client import QdrantClient |
|
from qdrant_client.http import models as rest |
|
import numpy as np |
|
import uuid |
|
import asyncio |
|
|
|
|
|
# Name of the Qdrant collection used by the functions in this module.
collection_name = "vocRT_collection"

# Shared Qdrant client, pointed at a server on localhost port 6333.
client = QdrantClient(port=6333, host="localhost")
|
|
|
|
|
async def initialize_collection():
    """
    Create the Qdrant collection used by this module when it is missing.

    The names of the existing collections are fetched through ``client``. If
    ``collection_name`` is not among them, the collection is created with
    768-dimensional vectors compared by cosine distance. A status line is
    printed in either case.
    """
    existing_names = {
        entry.name for entry in client.get_collections().collections
    }

    # Guard clause: nothing to create when the collection is already there.
    if collection_name in existing_names:
        print(f"Collection '{collection_name}' already exists.")
        return

    vector_settings = rest.VectorParams(
        size=768,
        distance=rest.Distance.COSINE,
    )
    client.create_collection(
        collection_name=collection_name,
        vectors_config=vector_settings,
    )
    print(f"Collection '{collection_name}' created.")
|
|
|
|
|
# Make sure the collection exists as soon as this module is imported.
#
# asyncio.run() raises RuntimeError when an event loop is already running in
# the current thread (for example when the module is imported from inside an
# async server). In that case the coroutine is scheduled on the running loop
# instead of crashing the import.
try:
    _running_loop = asyncio.get_running_loop()
except RuntimeError:
    _running_loop = None

if _running_loop is None:
    asyncio.run(initialize_collection())
else:
    # Keep a reference so the task is not garbage-collected before it runs.
    _init_collection_task = _running_loop.create_task(initialize_collection())
|
|
|
async def store_embeddings(session_id, embeddings, texts=None, name="", title="", summary="", categories=""):
    """
    Store embeddings for a specific session_id.

    Every embedding becomes one Qdrant point with a random UUID as its id and
    a payload holding ``session_id`` plus the optional metadata below.

    Parameters:
    - session_id (str): Unique identifier for the session/user.
    - embeddings (list of numpy arrays or lists): The embeddings to store.
    - texts (list of str, optional): Corresponding text passages for the
      embeddings; stored under the ``text`` payload key.
    - name (str, optional): Stored under the ``filename`` payload key.
    - title (str, optional): Stored under the ``title`` payload key.
    - summary (str, optional): Stored under the ``summary`` payload key.
    - categories (str or iterable, optional): Either a comma-separated string
      (split on commas, stripped, empty entries dropped) or an iterable of
      categories; stored as a list under the ``categories`` payload key.

    Passing ``None`` for name, title, summary or categories leaves the
    corresponding key out of the payload.

    Raises:
    - ValueError: If texts is given and its length differs from embeddings.
    """
    # Validate before any network round trip so bad input fails fast.
    if texts is not None and len(embeddings) != len(texts):
        raise ValueError(
            "The number of embeddings and texts must be the same.")

    # The metadata is the same for every point, so it is built only once.
    # This also keeps a one-shot iterable of categories from being exhausted
    # after the first point.
    metadata = {}
    if name is not None:
        metadata['filename'] = name
    if title is not None:
        metadata['title'] = title
    if summary is not None:
        metadata['summary'] = summary
    if categories is not None:
        if isinstance(categories, str):
            metadata['categories'] = [
                cat.strip() for cat in categories.split(',') if cat.strip()
            ]
        else:
            metadata['categories'] = list(categories)

    points = []
    for idx, embedding in enumerate(embeddings):
        payload = {'session_id': session_id}
        if texts is not None:
            payload['text'] = texts[idx]
        payload.update(metadata)

        # numpy arrays are converted to plain lists; anything else is passed
        # through unchanged.
        vector = embedding.tolist() if isinstance(embedding, np.ndarray) else embedding
        points.append(
            rest.PointStruct(
                id=str(uuid.uuid4()),
                vector=vector,
                payload=payload,
            )
        )

    if not points:
        # Nothing to store: skip the collection check and the upsert, which
        # would otherwise be sent without any points.
        print(f"No embeddings to store for session_id: {session_id}")
        return

    await initialize_collection()

    client.upsert(
        collection_name=collection_name,
        wait=True,
        points=points,
    )
    print(f"Embeddings stored for session_id: {session_id}")
|
|
|
|
|
def search_embeddings(session_id, query_embedding, limit=10):
    """
    Run a vector search restricted to the points of one session.

    Parameters:
    - session_id (str): Unique identifier for the session/user; only points
      whose ``session_id`` payload field matches this value are searched.
    - query_embedding (numpy array or list): The query vector. A numpy array
      is converted to a plain list before it is sent.
    - limit (int): The number of top results to return.

    Returns:
    - The value returned by ``client.search`` for this query, with payloads
      included.
    """
    query_vector = (
        query_embedding.tolist()
        if isinstance(query_embedding, np.ndarray)
        else query_embedding
    )

    # Restrict the search to points stored for this session.
    session_condition = rest.FieldCondition(
        key='session_id',
        match=rest.MatchValue(value=session_id),
    )
    session_only = rest.Filter(must=[session_condition])

    return client.search(
        collection_name=collection_name,
        query_vector=query_vector,
        query_filter=session_only,
        limit=limit,
        with_payload=True,
    )
|
|
|
|
|
def delete_embeddings(session_id):
    """
    Remove every point stored for the given session_id.

    The ids of the matching points are collected by scrolling through the
    collection in pages of 100, and are then deleted in a single request.

    Parameters:
    - session_id (str): Unique identifier for the session/user.

    Returns:
    - True when the points were deleted or when none were found, False when
      the delete request raised an exception (the error is printed).
    """
    # The same filter is reused for every page of the scroll.
    session_filter = rest.Filter(
        must=[
            rest.FieldCondition(
                key='session_id',
                match=rest.MatchValue(value=session_id),
            )
        ]
    )

    matching_ids = []
    next_offset = None
    while True:
        batch, next_offset = client.scroll(
            collection_name=collection_name,
            scroll_filter=session_filter,
            limit=100,
            offset=next_offset,
            with_payload=False,
        )
        matching_ids.extend(record.id for record in batch)
        # Stop on an empty page or when no further offset is returned.
        if not batch or next_offset is None:
            break

    if not matching_ids:
        print(f"No embeddings found for session_id: {session_id}")
        return True

    try:
        client.delete(
            collection_name=collection_name,
            points_selector=rest.PointIdsList(points=matching_ids),
        )
        print(f"Deleted embeddings for session_id: {session_id}")
        return True
    except Exception as e:
        print("Error in deleting embeddings : ", e)
        return False
|
|