Spaces:
Sleeping
Sleeping
from qdrant_client.http import models | |
import pickle as pickle | |
import torch | |
import io | |
# Choose the compute target once at import time: the first CUDA device when
# torch reports CUDA as available, otherwise the CPU.
if torch.cuda.is_available():
    device_str = "cuda:0"
else:
    device_str = "cpu"
device = torch.device(device_str)
class Device_Unpickler(pickle.Unpickler):
    """Unpickler that reloads embedded torch storages onto ``device_str``.

    A lookup of ``torch.storage._load_from_bytes`` is answered with a
    replacement loader that passes ``map_location=device_str`` to
    ``torch.load``; every other global is resolved by ``pickle.Unpickler``.

    NOTE(review): unpickling can execute arbitrary code -- only use this class
    on files that come from a trusted source.
    """

    def find_class(self, module, name):
        """Return the object the pickle stream refers to as ``module.name``."""
        if (module, name) != ("torch.storage", "_load_from_bytes"):
            return super().find_class(module, name)

        def _load_on_device(b):
            # Re-read the serialized bytes, remapping storages to device_str.
            return torch.load(io.BytesIO(b), map_location=device_str)

        return _load_on_device
def pickle_to_document_store(path):
    """Load a pickled document store from ``path`` and point it at ``device_str``.

    The file is read in binary mode through ``Device_Unpickler``. After
    loading, the ``"device"`` entry of the store's
    ``embeddings.encode_kwargs`` is set to ``device_str``.

    NOTE(review): this unpickles the file's contents, which can execute
    arbitrary code -- only call it on files from a trusted source.

    Args:
        path: Location of the pickle file to open.

    Returns:
        The unpickled document store object.
    """
    with open(path, "rb") as handle:
        store = Device_Unpickler(handle).load()
    # The stream is fully consumed by load(), so the file can be closed
    # before the embedding settings are adjusted.
    store.embeddings.encode_kwargs["device"] = device_str
    return store
def get_qdrant_filters(filter_dict: dict):
    """Translate a metadata filter dict into a Qdrant ``Filter``.

    Each key of ``filter_dict`` names a metadata field and maps to the list of
    accepted values, formatted like:
        filter_dict = {'file_name':['file1', 'file2'],'sub_type':['text']}

    One ``MatchAny`` condition on ``metadata.<field>`` is built per key, and
    all of the conditions are placed under the filter's ``must`` clause.
    """
    conditions = []
    for field, accepted_values in filter_dict.items():
        payload_key = f"metadata.{field}"
        conditions.append(
            models.FieldCondition(
                key=payload_key,
                match=models.MatchAny(any=accepted_values),
            )
        )
    return models.Filter(must=conditions)