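"""Fashion product semantic search demo.

A Gradio app that searches a sample of the H&M fashion-caption dataset with
Superlinked, combining an open-CLIP image embedding space with a GTE text
embedding space, and visualizes the similarity scores of the returned items.
"""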
import gradio as gr
import numpy as np
from PIL import Image as PILImage
import matplotlib.pyplot as plt
import pandas as pd
from superlinked import framework as sl
from datasets import load_dataset
# Constants
DATASET_ID = "tomytjandra/h-and-m-fashion-caption"
VIT_MODEL_ID = "hf-hub:laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
NUM_SAMPLES = 200
SEED = 42
LIMIT = 3

# Load and prepare dataset
fashion_dataset = load_dataset(DATASET_ID)
fashion_sample_dataset = fashion_dataset["train"].shuffle(seed=SEED).select(range(NUM_SAMPLES))

# Organize metadata: assign a stable integer id to each sampled record
fashion_json_data = list(fashion_sample_dataset)
for i, item in enumerate(fashion_json_data):
    item["id"] = i

fashion_df = pd.DataFrame(fashion_json_data)
fashion_df["description"] = fashion_df["text"]
json_data = fashion_df.to_dict(orient="records")
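# Note: the sampled records are expected to expose at least "image" and "text"
# columns; optional fields referenced later (category, price_range, colors,
# brand) fall back to "N/A" / empty defaults if the dataset does not provide them.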
# Superlinked setup
class Image(sl.Schema):
    id: sl.IdField
    image: sl.Blob
    description: sl.String

image = Image()

image_embedding_space = sl.ImageSpace(
    image=image.image,
    model=VIT_MODEL_ID,
    model_handler=sl.ModelHandler.OPEN_CLIP,
)
description_space = sl.TextSimilaritySpace(
    text=image.description,
    model="Alibaba-NLP/gte-large-en-v1.5",
)
composite_index = sl.Index([image_embedding_space, description_space])

source = sl.InMemorySource(image)
executor = sl.InMemoryExecutor(sources=[source], indices=[composite_index])
app = executor.run()
source.put(json_data)
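# Ingestion note: each record must carry the schema fields ("id", "image",
# "description"). Putting the data triggers embedding in both spaces
# (open-CLIP for images, GTE for text), so the first run may take a while
# while the models are downloaded and applied.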
# Query construction
combined_query = (
    sl.Query(
        composite_index,
        weights={
            description_space: sl.Param("description_weight"),
            image_embedding_space: sl.Param("image_embedding_weight"),
        },
    )
    .find(image)
    .similar(description_space, sl.Param("text_search"))
    .similar(image_embedding_space.image, sl.Param("image_search"))
    .similar(image_embedding_space.description, sl.Param("text_in_image_search"))
    .select_all()
    .limit(LIMIT)
)
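# Illustrative usage (mirrors search_products below): a combined search would
# be issued roughly as
#   app.query(combined_query,
#             description_weight=1.0, image_embedding_weight=1.0,
#             text_search="floral summer dress", image_search=pil_image)
# where "floral summer dress" and pil_image are placeholder inputs.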
def process_search_results(results_df, dataset, similarity_threshold=0.5):
    """Filter results by similarity score and build gallery-ready descriptions."""
    # Keep only results whose similarity score clears the threshold
    filtered_df = results_df[results_df['similarity_score'] >= similarity_threshold]
    if filtered_df.empty:
        return {
            "images": [],
            "descriptions": [],
            "similarity_plot": None,
            "error": "No results meet the similarity threshold"
        }

    images = []
    descriptions = []
    scores = []
    for _, row in filtered_df.iterrows():
        product_id = int(row['id'])
        img = dataset["image"][product_id]
        if isinstance(img, np.ndarray):
            img = PILImage.fromarray(img)
        product_info = dataset[product_id]
        description = {
            "Product ID": str(product_id),
            "Description": product_info.get("text", "N/A"),
            "Category": product_info.get("category", "N/A"),
            "Similarity Score": f"{float(row['similarity_score']):.3f}",
            "Price Range": product_info.get("price_range", "N/A"),
            "Colors": product_info.get("colors", []),
            "Brand": product_info.get("brand", "N/A")
        }
        images.append(img)
        descriptions.append(description)
        scores.append(float(row['similarity_score']))

    similarity_plot = create_similarity_visualization(scores, descriptions)
    return {
        "images": images,
        "descriptions": descriptions,
        "similarity_plot": similarity_plot,
        "error": None
    }
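# Example: process_search_results(results.to_pandas(), fashion_sample_dataset, 0.5)
# returns a dict with "images", "descriptions", "similarity_plot", and "error"
# keys, which handle_search (defined below) maps onto the Gradio components.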
def create_similarity_visualization(scores, descriptions):
    """Plot per-result similarity scores and the category mix of the results."""
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8), height_ratios=[2, 1])

    # Bar chart of per-result similarity scores
    bars = ax1.bar(range(len(scores)), scores)
    ax1.set_title('Similarity Scores Distribution')
    ax1.set_xlabel('Result Index')
    ax1.set_ylabel('Similarity Score')
    for bar in bars:
        height = bar.get_height()
        ax1.text(bar.get_x() + bar.get_width() / 2., height,
                 f'{height:.3f}', ha='center', va='bottom')
    ax1.axhline(y=0.5, color='r', linestyle='--', label='Threshold')
    ax1.legend()

    # Pie chart of the category distribution among the results
    categories = [d.get("Category") for d in descriptions]
    unique_categories = list(set(categories))
    category_counts = [categories.count(cat) for cat in unique_categories]
    ax2.pie(category_counts, labels=unique_categories, autopct='%1.1f%%')
    ax2.set_title('Category Distribution')

    plt.tight_layout()
    return fig
def search_products(search_text, search_image, search_type,
                    weight_text=1.0, weight_image=1.0, similarity_threshold=0.5):
    """Run the Superlinked query for the selected search mode and post-process the results."""
    try:
        # Validate inputs for the selected mode
        if search_type == "Text Only" and not search_text:
            raise ValueError("Please enter search text")
        if search_type == "Image Only" and search_image is None:
            raise ValueError("Please upload an image")

        if search_type == "Text Only":
            results = app.query(
                combined_query,
                description_weight=1,
                text_search=search_text
            )
        elif search_type == "Image Only":
            if isinstance(search_image, np.ndarray):
                search_image = PILImage.fromarray(search_image)
            results = app.query(
                combined_query,
                image_embedding_weight=1,
                image_search=search_image
            )
        else:  # Combined Search
            if isinstance(search_image, np.ndarray):
                search_image = PILImage.fromarray(search_image)
            results = app.query(
                combined_query,
                description_weight=weight_text,
                image_embedding_weight=weight_image,
                text_search=search_text,
                image_search=search_image
            )

        results_df = results.to_pandas()
        return process_search_results(results_df, fashion_sample_dataset, similarity_threshold)
    except Exception as e:
        return {
            "images": [],
            "descriptions": [],
            "similarity_plot": None,
            "error": str(e)
        }
def create_interface():
    with gr.Blocks() as demo:
        gr.Markdown("# Fashion Product Semantic Search")

        with gr.Row():
            with gr.Column(scale=1):
                text_input = gr.Textbox(
                    label="Search Text",
                    placeholder="Enter product description..."
                )
                image_input = gr.Image(
                    label="Search Image",
                    type="pil"
                )
                search_type = gr.Radio(
                    choices=["Text Only", "Image Only", "Combined Search"],
                    label="Search Type",
                    value="Text Only"
                )
                with gr.Accordion("Advanced Settings", open=False):
                    similarity_threshold = gr.Slider(
                        minimum=0, maximum=1, value=0.5,
                        label="Similarity Threshold"
                    )
                with gr.Row(visible=False) as weight_controls:
                    text_weight = gr.Slider(
                        minimum=0, maximum=2, value=1,
                        label="Text Weight"
                    )
                    image_weight = gr.Slider(
                        minimum=0, maximum=2, value=1,
                        label="Image Weight"
                    )
                search_button = gr.Button("Search", variant="primary")

        with gr.Row():
            with gr.Column(scale=2):
                gallery = gr.Gallery(
                    label="Search Results",
                    columns=3,
                    height="400px"
                )
                product_details = gr.JSON(
                    label="Product Details"
                )
                similarity_plot = gr.Plot(
                    label="Similarity Analysis"
                )
                error_display = gr.Textbox(
                    label="Status",
                    visible=False
                )

        def handle_search(text, image, search_type, text_weight, image_weight, threshold):
            results = search_products(text, image, search_type, text_weight, image_weight, threshold)
            if results["error"]:
                return {
                    error_display: gr.update(value=results["error"], visible=True),
                    gallery: None,
                    product_details: None,
                    similarity_plot: None
                }
            return {
                error_display: gr.update(visible=False),
                gallery: results["images"],
                product_details: results["descriptions"] if results["descriptions"] else None,
                similarity_plot: results["similarity_plot"]
            }

        # Show the weight sliders only in combined-search mode.
        # gr.update(...) replaces gr.Row.update(...), which was removed in Gradio 4.
        search_type.change(
            fn=lambda x: gr.update(visible=x == "Combined Search"),
            inputs=[search_type],
            outputs=[weight_controls]
        )

        search_button.click(
            fn=handle_search,
            inputs=[
                text_input,
                image_input,
                search_type,
                text_weight,
                image_weight,
                similarity_threshold
            ],
            outputs=[
                error_display,
                gallery,
                product_details,
                similarity_plot
            ]
        )

    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.launch()