#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
from pinecone import Pinecone, ServerlessSpec
from pinecone_text.sparse import BM25Encoder
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import torch
from io import BytesIO
from base64 import b64encode
from tqdm.auto import tqdm
# constants.py is expected to define PINECONE_API_KEY and PINECONE_ENVIRONMENT
from constants import PINECONE_API_KEY, PINECONE_ENVIRONMENT

# initialize connection to Pinecone (get an API key at app.pinecone.io)
api_key = PINECONE_API_KEY or os.getenv("PINECONE_API_KEY")
# find your environment next to the API key in the Pinecone console
env = PINECONE_ENVIRONMENT or os.getenv("PINECONE_ENVIRONMENT")

class SearchItem:
    """Hybrid (dense CLIP + sparse BM25) fashion-product search over a Pinecone index."""

    def __init__(self, api_key=None, env=None,
                 device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.api_key = api_key
        self.environment = env
        self.pinecone_instance = self.connect_to_pinecone(self.api_key, self.environment)
        self.index = self.pinecone_instance.Index('clip')
        self.images, self.metadata = self.load_fashion_dataset()
        self.clip_model = self.initialize_clip_model(device=device)
        self.bm25 = self.initialize_bm25_encoder(self.metadata)

    def connect_to_pinecone(self, api_key, env):
        """Create a Pinecone client, falling back to environment variables."""
        api_key = api_key or os.getenv('PINECONE_API_KEY')
        env = env or os.getenv('PINECONE_ENVIRONMENT')
        if not api_key or not env:
            raise ValueError("Pinecone API key and environment are required.")
        pinecone_instance = Pinecone(api_key=api_key, environment=env)
        return pinecone_instance
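    # Hypothetical sketch (not part of the original flow): this code assumes an
    # index named 'clip' already exists. If it did not, it could be created with
    # the imported ServerlessSpec, e.g.:
    #
    #   if 'clip' not in pinecone_instance.list_indexes().names():
    #       pinecone_instance.create_index(
    #           name='clip',
    #           dimension=512,          # CLIP ViT-B/32 embedding size
    #           metric='dotproduct',    # dot product is needed for hybrid (sparse + dense) queries
    #           spec=ServerlessSpec(cloud='aws', region='us-east-1'),
    #       )
    #
    # The cloud/region values above are placeholders, not taken from the original file.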

    def load_fashion_dataset(self):
        """Load the fashion product images dataset and split it into
        images (PIL objects) and tabular metadata (a pandas DataFrame)."""
        fashion = load_dataset("ashraq/fashion-product-images-small", split="train")
        images = fashion["image"]
        metadata = fashion.remove_columns("image").to_pandas()
        return images, metadata

    def initialize_clip_model(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
        model = SentenceTransformer('sentence-transformers/clip-ViT-B-32', device=device)
        return model
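    # Note: clip-ViT-B-32 maps text (and images) into a shared 512-dimensional
    # space, so the Pinecone index queried below is assumed to be 512-dimensional.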

    def initialize_bm25_encoder(self, metadata):
        """Fit a BM25 sparse encoder on the product display names."""
        bm25 = BM25Encoder()
        bm25.fit(metadata['productDisplayName'])
        return bm25
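    # For reference, BM25Encoder.encode_queries returns a sparse vector in the
    # dict format Pinecone expects, e.g. (illustrative values only):
    #   {'indices': [1012, 4432, ...], 'values': [0.41, 0.27, ...]}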

    @staticmethod
    def hybrid_scale(dense, sparse, alpha=0.05):
        """Scale dense and sparse vectors for hybrid search using a convex
        combination: alpha * dense + (1 - alpha) * sparse.

        Args:
            dense: list of floats representing the dense (CLIP) vector
            sparse: a dict with `indices` and `values` (BM25 sparse vector)
            alpha: float between 0 and 1, where 0 == sparse only
                and 1 == dense only
        """
        if alpha < 0 or alpha > 1:
            raise ValueError("Alpha must be between 0 and 1")
        # Scale sparse and dense vectors to create hybrid search vectors
        hsparse = {
            'indices': sparse['indices'],
            'values': [v * (1 - alpha) for v in sparse['values']]
        }
        hdense = [v * alpha for v in dense]
        return hdense, hsparse
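    # Worked example (illustrative values, not from the original file):
    # with alpha = 0.05, dense = [1.0, 2.0] and sparse values = [3.0],
    # hybrid_scale returns hdense = [0.05, 0.1] and sparse values = [2.85],
    # i.e. the default weighting leans heavily toward the BM25 (keyword) side.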

if __name__ == "__main__":
    fashion_processor = SearchItem(api_key, env)
    query = "blue shoes"
    # create sparse (BM25) and dense (CLIP) query vectors
    sparse = fashion_processor.bm25.encode_queries(query)
    dense = fashion_processor.clip_model.encode(query).tolist()
    # weight the two vectors, then run a hybrid query against the index
    hdense, hsparse = fashion_processor.hybrid_scale(dense, sparse)
    result = fashion_processor.index.query(
        top_k=5,
        vector=hdense,
        sparse_vector=hsparse,
        include_metadata=True
    )
    # look up the matched images by their record ids
    imgs = [fashion_processor.images[int(r["id"])] for r in result["matches"]]
    print('Ok')
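    # Hedged sketch: the otherwise unused BytesIO/b64encode imports suggest the
    # matched images are meant to be serialized for display (e.g. in a web UI).
    # Assuming each entry in `imgs` is a PIL.Image, as the fashion dataset
    # provides, they could be encoded as base64 PNG strings like this:
    encoded = []
    for img in imgs:
        buffer = BytesIO()
        img.save(buffer, format="PNG")
        encoded.append(b64encode(buffer.getvalue()).decode("utf-8"))
    print(f"Encoded {len(encoded)} result images to base64")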