#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os

import torch
from datasets import load_dataset
from pinecone import Pinecone
from pinecone_text.sparse import BM25Encoder
from sentence_transformers import SentenceTransformer

from constants import PINECONE_API_KEY, PINECONE_ENVIRONMENT
# Initialize the connection to Pinecone (get an API key at app.pinecone.io).
# Fall back to environment variables when the constants are unset.
api_key = PINECONE_API_KEY or os.getenv("PINECONE_API_KEY")
# Find your environment next to the API key in the Pinecone console.
env = PINECONE_ENVIRONMENT or os.getenv("PINECONE_ENVIRONMENT")

class SearchItem:
    def __init__(self, api_key=None, env=None,
                 device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.api_key = api_key
        self.environment = env
        self.pinecone_instance = self.connect_to_pinecone(self.api_key, self.environment)
        self.index = self.pinecone_instance.Index('clip')
        self.images, self.metadata = self.load_fashion_dataset()
        self.clip_model = self.initialize_clip_model(device=device)
        self.bm25 = self.initialize_bm25_encoder(self.metadata)

    def connect_to_pinecone(self, api_key, env):
        api_key = api_key or os.getenv('PINECONE_API_KEY')
        env = env or os.getenv('PINECONE_ENVIRONMENT')
        if not api_key or not env:
            raise ValueError("Pinecone API key and environment are required.")
        # The serverless client only needs the API key at construction time;
        # the environment/region matters when an index is created.
        return Pinecone(api_key=api_key)

    def load_fashion_dataset(self):
        fashion = load_dataset("ashraq/fashion-product-images-small", split="train")
        images = fashion["image"]
        metadata = fashion.remove_columns("image").to_pandas()
        return images, metadata
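
    # Note: the metadata frame keeps the dataset's text columns (e.g.
    # productDisplayName, baseColour, articleType); productDisplayName is
    # the field the BM25 encoder is fitted on below.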

    def initialize_clip_model(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
        # CLIP ViT-B/32 maps text and images into a shared 512-dim embedding space.
        model = SentenceTransformer('sentence-transformers/clip-ViT-B-32', device=device)
        return model

    def initialize_bm25_encoder(self, metadata):
        # Fit BM25 term statistics on the product titles so that queries can
        # later be encoded as sparse keyword vectors.
        bm25 = BM25Encoder()
        bm25.fit(metadata['productDisplayName'])
        return bm25
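
    # For illustration (the numbers below are made up; real indices and
    # weights depend on the fitted corpus), `encode_queries` returns the
    # sparse format Pinecone expects:
    #   bm25.encode_queries("blue shoes")
    #   -> {'indices': [102, 5437], 'values': [0.53, 0.47]}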

    def hybrid_scale(self, dense, sparse, alpha=0.05):
        """Hybrid vector scaling using a convex combination:
        alpha * dense + (1 - alpha) * sparse.

        Args:
            dense: list of floats, the dense query embedding.
            sparse: dict with `indices` and `values` keys, the sparse vector.
            alpha: float between 0 and 1, where 0 == sparse only
                and 1 == dense only.
        """
        if alpha < 0 or alpha > 1:
            raise ValueError("Alpha must be between 0 and 1")
        # Scale sparse and dense vectors to create hybrid search vectors.
        hsparse = {
            'indices': sparse['indices'],
            'values': [v * (1 - alpha) for v in sparse['values']]
        }
        hdense = [v * alpha for v in dense]
        return hdense, hsparse
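
    # Worked example (illustrative numbers): with alpha = 0.05,
    #   dense  = [0.2, 0.4]                        -> hdense = [0.01, 0.02]
    #   sparse = {'indices': [7], 'values': [0.5]} -> hsparse values = [0.475]
    # so at this alpha the keyword (sparse) signal dominates the ranking.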

if __name__ == "__main__":
    fashion_processor = SearchItem(api_key, env)
    query = "blue shoes"
    # Create sparse (BM25) and dense (CLIP) query vectors, then scale them.
    sparse = fashion_processor.bm25.encode_queries(query)
    dense = fashion_processor.clip_model.encode(query).tolist()
    hdense, hsparse = fashion_processor.hybrid_scale(dense, sparse)
    result = fashion_processor.index.query(
        top_k=5,
        vector=hdense,
        sparse_vector=hsparse,
        include_metadata=True
    )
    # Record ids were assigned from dataset row positions, so they index
    # straight back into the in-memory image list.
    imgs = [fashion_processor.images[int(r["id"])] for r in result["matches"]]
    print(f"Retrieved {len(imgs)} matches for {query!r}")
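
# The query above assumes that the 'clip' index already holds one record per
# dataset row: a CLIP image embedding as the dense vector plus a BM25 sparse
# vector over the product title. A minimal, hypothetical sketch of such an
# indexing pass follows; it is never called here, and it assumes the index
# was created with dimension=512 and metric='dotproduct' (dotproduct is
# required for hybrid dense+sparse queries).
def build_index_sketch(processor, batch_size=200):
    titles = processor.metadata['productDisplayName'].tolist()
    for start in range(0, len(titles), batch_size):
        end = min(start + batch_size, len(titles))
        # CLIP encodes PIL images directly into the same space as text queries.
        dense = processor.clip_model.encode(processor.images[start:end]).tolist()
        sparse = processor.bm25.encode_documents(titles[start:end])
        processor.index.upsert(vectors=[
            {'id': str(i), 'values': d, 'sparse_values': s}
            for i, d, s in zip(range(start, end), dense, sparse)
        ])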