momenaca's picture
update structure of the project for clarity
3650955
raw
history blame
1.19 kB
from qdrant_client.http import models
import pickle as pickle
import torch
import io
# Target device for everything this module loads: the first CUDA GPU when
# PyTorch reports one as available, otherwise the CPU.
if torch.cuda.is_available():
    device_str = "cuda:0"
else:
    device_str = "cpu"

# The same choice as a ``torch.device`` object, for callers that need the
# object form rather than the string.
device = torch.device(device_str)
class Device_Unpickler(pickle.Unpickler):
    """Unpickler that redirects torch's byte-loading hook to the local device.

    Whenever the pickle stream asks for ``torch.storage._load_from_bytes``,
    a replacement loader is handed back that calls ``torch.load`` with
    ``map_location=device_str``. Every other global is resolved by the
    standard ``pickle.Unpickler`` machinery.
    """

    def find_class(self, module, name):
        """Resolve ``module.name`` for the unpickler.

        Returns the device-aware loader for ``torch.storage._load_from_bytes``
        and defers to ``pickle.Unpickler.find_class`` for anything else.
        """
        is_torch_bytes_loader = (module, name) == (
            "torch.storage",
            "_load_from_bytes",
        )
        if not is_torch_bytes_loader:
            return super().find_class(module, name)

        def _load_on_local_device(raw_bytes):
            # ``device_str`` is read when the loader runs, not when it is
            # created, so it reflects the module-level value at load time.
            return torch.load(io.BytesIO(raw_bytes), map_location=device_str)

        return _load_on_local_device
def pickle_to_document_store(path):
    """Load a pickled document store from ``path`` for use on the local device.

    The file is read with ``Device_Unpickler``, after which the store's
    ``embeddings.encode_kwargs["device"]`` entry is set to ``device_str``.

    NOTE: unpickling can execute arbitrary code; only call this on files
    from a trusted source.

    Args:
        path: Location of the pickle file to open in binary mode.

    Returns:
        The unpickled document store object.
    """
    with open(path, "rb") as pickle_file:
        unpickler = Device_Unpickler(pickle_file)
        document_store = unpickler.load()

    # Point the embedding encoder at the device chosen for this process.
    encode_kwargs = document_store.embeddings.encode_kwargs
    encode_kwargs["device"] = device_str
    return document_store
def get_qdrant_filters(filter_dict: dict):
    """Translate a metadata filter dict into a Qdrant ``Filter``.

    Each key of ``filter_dict`` is a metadata field name and each value is
    the list of accepted values for that field, formatted like::

        filter_dict = {'file_name': ['file1', 'file2'], 'sub_type': ['text']}

    Every field contributes one condition on ``metadata.<field>`` that
    matches any of the listed values; all conditions are placed in the
    filter's ``must`` list.
    """
    conditions = []
    for field, accepted_values in filter_dict.items():
        any_of_values = models.MatchAny(any=accepted_values)
        conditions.append(
            models.FieldCondition(key=f"metadata.{field}", match=any_of_values)
        )
    return models.Filter(must=conditions)