# fashionsearch/app.py
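"""Multi-modal fashion product search demo.

Embeds 200 sampled H&M product images and captions with Superlinked
(OpenCLIP ViT-H/14 for images, gte-large-en-v1.5 for text) and serves a
Gradio UI for text-only, image-only, and combined similarity search.
"""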
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from datasets import load_dataset
from PIL import Image as PILImage
from superlinked import framework as sl
# Constants
DATASET_ID = "tomytjandra/h-and-m-fashion-caption"
VIT_MODEL_ID = "hf-hub:laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
NUM_SAMPLES = 200
SEED = 42
LIMIT = 3
# Load and prepare dataset
fashion_dataset = load_dataset(DATASET_ID)
fashion_sample_dataset = fashion_dataset["train"].shuffle(seed=SEED).select(range(NUM_SAMPLES))
# Organize metadata: assign sequential ids and mirror the caption into a
# "description" column for the text similarity space
fashion_json_data = list(fashion_sample_dataset)
for i, item in enumerate(fashion_json_data):
    item["id"] = i
fashion_df = pd.DataFrame(fashion_json_data)
fashion_df["description"] = fashion_df["text"]
json_data = fashion_df.to_dict(orient="records")
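# Each record should now look roughly like this (values illustrative):
#   {"image": <PIL.Image>, "text": "solid black t-shirt ...",
#    "id": 0, "description": "solid black t-shirt ..."}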
# Superlinked setup: one schema with an id, an image blob, and a text description
class Image(sl.Schema):
    id: sl.IdField
    image: sl.Blob
    description: sl.String

image = Image()
image_embedding_space = sl.ImageSpace(
    image=image.image, model=VIT_MODEL_ID, model_handler=sl.ModelHandler.OPEN_CLIP
)
description_space = sl.TextSimilaritySpace(
    text=image.description, model="Alibaba-NLP/gte-large-en-v1.5"
)
composite_index = sl.Index([image_embedding_space, description_space])
source = sl.InMemorySource(image)
executor = sl.InMemoryExecutor(sources=[source], indices=[composite_index])
app = executor.run()
source.put(json_data)
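# Note: put() embeds every record up front (CLIP for images, GTE for
# descriptions), so startup can take a while on CPU for all 200 samples.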
# Query construction: a weighted combination of the text and image spaces,
# with weights and search inputs left as runtime parameters
combined_query = (
    sl.Query(
        composite_index,
        weights={
            description_space: sl.Param("description_weight"),
            image_embedding_space: sl.Param("image_embedding_weight"),
        },
    )
    .find(image)
    .similar(description_space, sl.Param("text_search"))
    .similar(image_embedding_space.image, sl.Param("image_search"))
    .similar(image_embedding_space.description, sl.Param("text_in_image_search"))
    .select_all()
    .limit(LIMIT)
)
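# Minimal usage sketch (illustrative values; the same pattern is used in
# search_products below):
#   result = app.query(combined_query, description_weight=1, text_search="red dress")
#   result.to_pandas()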
def process_search_results(results_df, dataset, similarity_threshold=0.5):
    """Filter results by similarity score and build gallery entries with details."""
    filtered_df = results_df[results_df["similarity_score"] >= similarity_threshold]
    if filtered_df.empty:
        return {
            "images": [],
            "descriptions": [],
            "similarity_plot": None,
            "error": "No results meet the similarity threshold",
        }
    images = []
    descriptions = []
    scores = []
    for _, row in filtered_df.iterrows():
        product_id = int(row["id"])
        product_info = dataset[product_id]
        # Index the row once instead of decoding the whole image column.
        img = product_info["image"]
        if isinstance(img, np.ndarray):
            img = PILImage.fromarray(img)
        # The H&M caption dataset only provides "image" and "text"; the
        # remaining fields fall back to their defaults.
        description = {
            "Product ID": str(product_id),
            "Description": product_info.get("text", "N/A"),
            "Category": product_info.get("category", "N/A"),
            "Similarity Score": f"{float(row['similarity_score']):.3f}",
            "Price Range": product_info.get("price_range", "N/A"),
            "Colors": product_info.get("colors", []),
            "Brand": product_info.get("brand", "N/A"),
        }
        images.append(img)
        descriptions.append(description)
        scores.append(float(row["similarity_score"]))
    similarity_plot = create_similarity_visualization(scores, descriptions)
    return {
        "images": images,
        "descriptions": descriptions,
        "similarity_plot": similarity_plot,
        "error": None,
    }
def create_similarity_visualization(scores, descriptions):
    """Create a bar chart of similarity scores plus a category pie chart."""
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8), height_ratios=[2, 1])
    bars = ax1.bar(range(len(scores)), scores)
    ax1.set_title("Similarity Scores Distribution")
    ax1.set_xlabel("Result Index")
    ax1.set_ylabel("Similarity Score")
    for bar in bars:
        height = bar.get_height()
        ax1.text(
            bar.get_x() + bar.get_width() / 2.0,
            height,
            f"{height:.3f}",
            ha="center",
            va="bottom",
        )
    # Reference line at the default threshold used by the UI slider.
    ax1.axhline(y=0.5, color="r", linestyle="--", label="Threshold")
    ax1.legend()
    categories = [d.get("Category") for d in descriptions]
    unique_categories = list(set(categories))
    category_counts = [categories.count(cat) for cat in unique_categories]
    ax2.pie(category_counts, labels=unique_categories, autopct="%1.1f%%")
    ax2.set_title("Category Distribution")
    plt.tight_layout()
    return fig
def search_products(search_text, search_image, search_type, weight_text=1.0, weight_image=1.0, similarity_threshold=0.5):
    """Run the combined query for the selected search type, with error handling."""
    try:
        if search_type == "Text Only" and not search_text:
            raise ValueError("Please enter search text")
        if search_type == "Image Only" and search_image is None:
            raise ValueError("Please upload an image")
        if search_type == "Text Only":
            results = app.query(
                combined_query,
                description_weight=1,
                text_search=search_text,
            )
        elif search_type == "Image Only":
            if isinstance(search_image, np.ndarray):
                search_image = PILImage.fromarray(search_image)
            results = app.query(
                combined_query,
                image_embedding_weight=1,
                image_search=search_image,
            )
        else:  # Combined Search
            if isinstance(search_image, np.ndarray):
                search_image = PILImage.fromarray(search_image)
            results = app.query(
                combined_query,
                description_weight=weight_text,
                image_embedding_weight=weight_image,
                text_search=search_text,
                image_search=search_image,
            )
        results_df = results.to_pandas()
        return process_search_results(results_df, fashion_sample_dataset, similarity_threshold)
    except Exception as e:
        return {
            "images": [],
            "descriptions": [],
            "similarity_plot": None,
            "error": str(e),
        }
def create_interface():
    with gr.Blocks() as demo:
        gr.Markdown("# Fashion Product Semantic Search")
        with gr.Row():
            with gr.Column(scale=1):
                text_input = gr.Textbox(
                    label="Search Text",
                    placeholder="Enter product description...",
                )
                image_input = gr.Image(
                    label="Search Image",
                    type="pil",
                )
                search_type = gr.Radio(
                    choices=["Text Only", "Image Only", "Combined Search"],
                    label="Search Type",
                    value="Text Only",
                )
                with gr.Accordion("Advanced Settings", open=False):
                    similarity_threshold = gr.Slider(
                        minimum=0, maximum=1, value=0.5,
                        label="Similarity Threshold",
                    )
                    with gr.Row(visible=False) as weight_controls:
                        text_weight = gr.Slider(
                            minimum=0, maximum=2, value=1,
                            label="Text Weight",
                        )
                        image_weight = gr.Slider(
                            minimum=0, maximum=2, value=1,
                            label="Image Weight",
                        )
                search_button = gr.Button("Search", variant="primary")
        with gr.Row():
            with gr.Column(scale=2):
                gallery = gr.Gallery(
                    label="Search Results",
                    columns=3,
                    height="400px",
                )
                product_details = gr.JSON(
                    label="Product Details",
                )
                similarity_plot = gr.Plot(
                    label="Similarity Analysis",
                )
                error_display = gr.Textbox(
                    label="Status",
                    visible=False,
                )
        def handle_search(text, image, search_type, text_weight, image_weight, threshold):
            results = search_products(text, image, search_type, text_weight, image_weight, threshold)
            if results["error"]:
                return {
                    error_display: gr.update(value=results["error"], visible=True),
                    gallery: None,
                    product_details: None,
                    similarity_plot: None,
                }
            return {
                error_display: gr.update(visible=False),
                gallery: results["images"],
                product_details: results["descriptions"] if results["descriptions"] else None,
                similarity_plot: results["similarity_plot"],
            }
        # gr.Row.update was removed in Gradio 4; gr.update works for any component.
        search_type.change(
            fn=lambda x: gr.update(visible=x == "Combined Search"),
            inputs=[search_type],
            outputs=[weight_controls],
        )
        search_button.click(
            fn=handle_search,
            inputs=[
                text_input,
                image_input,
                search_type,
                text_weight,
                image_weight,
                similarity_threshold,
            ],
            outputs=[
                error_display,
                gallery,
                product_details,
                similarity_plot,
            ],
        )
    return demo
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()