"""Gradio demo: semantic search over a sample of H&M fashion products.

A sample of the captioned fashion dataset is indexed with Superlinked using
an image-embedding space and a text-similarity space, and a small Gradio UI
lets the user query by text, by image, or by a weighted combination.

NOTE: importing this module loads the dataset, builds the index and ingests
the sample rows (module-level side effects kept from the original script).
"""

import logging
from collections import Counter
from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
import torch
from beartype.typing import Any, Hashable
from datasets import load_dataset
from matplotlib.figure import Figure
from PIL import Image as PILImage
from requests import RequestException
from superlinked import framework as sl
from torchvision import transforms

logger = logging.getLogger(__name__)

# Constants
DATASET_ID = "tomytjandra/h-and-m-fashion-caption"
VIT_MODEL_ID = "hf-hub:laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
NUM_SAMPLES = 200
SEED = 42
LIMIT = 3

# Load and prepare dataset
fashion_dataset = load_dataset(DATASET_ID)
fashion_sample_dataset = (
    fashion_dataset["train"].shuffle(seed=SEED).select(range(NUM_SAMPLES))
)

# Organize metadata: the row position in the sample doubles as the product
# id, which is how search results are mapped back to dataset rows later.
fashion_json_data = [
    {**item, "id": i} for i, item in enumerate(fashion_sample_dataset)
]
fashion_df = pd.DataFrame(fashion_json_data)
fashion_df["description"] = fashion_df["text"]
json_data = fashion_df.to_dict(orient="records")


# Superlinked setup
class Image(sl.Schema):
    """Schema of one indexed product: id, image blob and text description."""

    id: sl.IdField
    image: sl.Blob
    description: sl.String


image = Image()

image_embedding_space = sl.ImageSpace(
    image=image.image,
    model=VIT_MODEL_ID,
    model_handler=sl.ModelHandler.OPEN_CLIP,
)
description_space = sl.TextSimilaritySpace(
    text=image.description,
    model="Alibaba-NLP/gte-large-en-v1.5",
)
composite_index = sl.Index([image_embedding_space, description_space])

source = sl.InMemorySource(image)
executor = sl.InMemoryExecutor(sources=[source], indices=[composite_index])
app = executor.run()
source.put(json_data)

# Query construction
combined_query = (
    sl.Query(
        composite_index,
        weights={
            description_space: sl.Param("description_weight"),
            image_embedding_space: sl.Param("image_embedding_weight"),
        },
    )
    .find(image)
    .similar(description_space, sl.Param("text_search"))
    .similar(image_embedding_space.image, sl.Param("image_search"))
    .similar(
        image_embedding_space.description,
        sl.Param("text_in_image_search"),
    )
    .select_all()
    .limit(LIMIT)  # previously a literal 3 that ignored the LIMIT constant
)


def _error_result(message):
    """Return the result payload used to report *message* to the UI."""
    return {
        "images": [],
        "descriptions": [],
        "similarity_plot": None,
        "error": message,
    }


def _as_pil(image_input):
    """Convert a NumPy array to a PIL image; pass other values through."""
    if isinstance(image_input, np.ndarray):
        return PILImage.fromarray(image_input)
    return image_input


def _results_to_dataframe(results):
    """Convert a query result to a DataFrame.

    Uses the result's own ``to_pandas`` when it has one (what the original
    script called) and otherwise falls back to ``sl.PandasConverter``.
    NOTE(review): which of the two exists depends on the installed
    Superlinked version -- confirm against the pinned release.
    """
    if hasattr(results, "to_pandas"):
        return results.to_pandas()
    return sl.PandasConverter.to_pandas(results)


def process_search_results(results_df, dataset, similarity_threshold=0.5):
    """Filter search results by score and build the UI payload.

    Args:
        results_df: DataFrame with at least ``id`` and ``similarity_score``
            columns; ``id`` must be convertible to a row index of *dataset*.
        dataset: Indexable collection whose rows are mappings with an
            ``image`` entry and, optionally, ``text``/``category``/
            ``price_range``/``colors``/``brand``.
        similarity_threshold: Minimum score a row needs to be kept.

    Returns:
        Dict with keys ``images``, ``descriptions``, ``similarity_plot`` and
        ``error``; ``error`` is ``None`` when at least one row survived.
    """
    keep_mask = results_df["similarity_score"] >= similarity_threshold
    filtered_df = results_df[keep_mask]
    if filtered_df.empty:
        return _error_result("No results meet the similarity threshold")

    images = []
    descriptions = []
    scores = []
    for _, row in filtered_df.iterrows():
        product_id = int(row["id"])
        score = float(row["similarity_score"])

        # Fetch the row once and read the image from it. The original
        # indexed dataset["image"][product_id], which materializes the whole
        # image column for every single result.
        product_info = dataset[product_id]
        img = _as_pil(product_info["image"])

        images.append(img)
        scores.append(score)
        descriptions.append(
            {
                "Product ID": str(product_id),
                "Description": product_info.get("text", "N/A"),
                "Category": product_info.get("category", "N/A"),
                "Similarity Score": f"{score:.3f}",
                "Price Range": product_info.get("price_range", "N/A"),
                "Colors": product_info.get("colors", []),
                "Brand": product_info.get("brand", "N/A"),
            }
        )

    similarity_plot = create_similarity_visualization(
        scores,
        descriptions,
        similarity_threshold=similarity_threshold,
    )
    return {
        "images": images,
        "descriptions": descriptions,
        "similarity_plot": similarity_plot,
        "error": None,
    }


def create_similarity_visualization(
    scores, descriptions, similarity_threshold=0.5
):
    """Plot per-result similarity scores and the category breakdown.

    Args:
        scores: Similarity scores, one per result.
        descriptions: Result dicts; their ``"Category"`` entry feeds the pie.
        similarity_threshold: Height of the dashed threshold line. Defaults
            to 0.5, the value the original hard-coded.

    Returns:
        A ``matplotlib.figure.Figure`` with a bar chart above a pie chart.
    """
    # Build the figure without pyplot: pyplot keeps a global reference to
    # every figure it creates, so a long-running server would accumulate one
    # unclosed figure per search.
    fig = Figure(figsize=(10, 8))
    # gridspec_kw works on older matplotlib too, unlike the direct
    # height_ratios= keyword.
    ax1, ax2 = fig.subplots(2, 1, gridspec_kw={"height_ratios": [2, 1]})

    bars = ax1.bar(range(len(scores)), scores)
    ax1.set_title("Similarity Scores Distribution")
    ax1.set_xlabel("Result Index")
    ax1.set_ylabel("Similarity Score")
    for bar in bars:
        height = bar.get_height()
        ax1.text(
            bar.get_x() + bar.get_width() / 2.0,
            height,
            f"{height:.3f}",
            ha="center",
            va="bottom",
        )
    # Draw the threshold actually used for filtering, not a fixed 0.5.
    ax1.axhline(
        y=similarity_threshold, color="r", linestyle="--", label="Threshold"
    )
    ax1.legend()

    category_counts = Counter(d.get("Category") for d in descriptions)
    if category_counts:
        ax2.pie(
            list(category_counts.values()),
            labels=list(category_counts.keys()),
            autopct="%1.1f%%",
        )
    ax2.set_title("Category Distribution")

    fig.tight_layout()
    return fig


def search_products(
    search_text,
    search_image,
    search_type,
    weight_text=1.0,
    weight_image=1.0,
    similarity_threshold=0.5,
):
    """Run a text, image or combined search and return the UI payload.

    Args:
        search_text: Free-text query; may be empty or ``None``.
        search_image: Query image (PIL image or NumPy array) or ``None``.
        search_type: ``"Text Only"``, ``"Image Only"``; any other value is
            treated as a combined search.
        weight_text: Weight of the description space (combined search only).
        weight_image: Weight of the image space (combined search only).
        similarity_threshold: Minimum similarity score for a result.

    Returns:
        Dict with keys ``images``, ``descriptions``, ``similarity_plot`` and
        ``error``. Problems are reported through ``error`` rather than
        raised, so the UI callback never crashes.
    """
    has_text = bool(search_text and search_text.strip())
    has_image = search_image is not None

    # Validate before querying so user mistakes get a clear message.
    if search_type == "Text Only":
        if not has_text:
            return _error_result("Please enter search text")
        params = {"description_weight": 1, "text_search": search_text}
    elif search_type == "Image Only":
        if not has_image:
            return _error_result("Please upload an image")
        params = {
            "image_embedding_weight": 1,
            "image_search": _as_pil(search_image),
        }
    else:  # Combined Search
        if not (has_text or has_image):
            return _error_result(
                "Please enter search text or upload an image"
            )
        params = {
            "description_weight": weight_text,
            "image_embedding_weight": weight_image,
        }
        # Forward only the inputs that were supplied instead of passing
        # None or an empty string as a query value.
        if has_text:
            params["text_search"] = search_text
        if has_image:
            params["image_search"] = _as_pil(search_image)

    try:
        results = app.query(combined_query, **params)
        results_df = _results_to_dataframe(results)
        return process_search_results(
            results_df, fashion_sample_dataset, similarity_threshold
        )
    except Exception as exc:  # UI boundary: report the failure, don't crash
        logger.exception("Search failed")
        return _error_result(str(exc))


def create_interface():
    """Build and return the Gradio Blocks app for the search demo."""
    with gr.Blocks() as demo:
        gr.Markdown("# Fashion Product Semantic Search")

        with gr.Row():
            with gr.Column(scale=1):
                text_input = gr.Textbox(
                    label="Search Text",
                    placeholder="Enter product description...",
                )
                image_input = gr.Image(label="Search Image", type="pil")
                search_type = gr.Radio(
                    choices=["Text Only", "Image Only", "Combined Search"],
                    label="Search Type",
                    value="Text Only",
                )
                with gr.Accordion("Advanced Settings", open=False):
                    similarity_threshold = gr.Slider(
                        minimum=0,
                        maximum=1,
                        value=0.5,
                        label="Similarity Threshold",
                    )
                    # Only shown while "Combined Search" is selected.
                    with gr.Row(visible=False) as weight_controls:
                        text_weight = gr.Slider(
                            minimum=0,
                            maximum=2,
                            value=1,
                            label="Text Weight",
                        )
                        image_weight = gr.Slider(
                            minimum=0,
                            maximum=2,
                            value=1,
                            label="Image Weight",
                        )
                search_button = gr.Button("Search", variant="primary")

        with gr.Row():
            with gr.Column(scale=2):
                gallery = gr.Gallery(
                    label="Search Results",
                    columns=3,
                    height="400px",
                )
                product_details = gr.JSON(label="Product Details")
                similarity_plot = gr.Plot(label="Similarity Analysis")
                error_display = gr.Textbox(label="Status", visible=False)

        def handle_search(
            text, image, search_type, text_weight, image_weight, threshold
        ):
            """Run the search and map its payload onto the output widgets."""
            results = search_products(
                text, image, search_type, text_weight, image_weight, threshold
            )
            if results["error"]:
                return {
                    error_display: gr.update(
                        value=results["error"], visible=True
                    ),
                    gallery: None,
                    product_details: None,
                    similarity_plot: None,
                }
            return {
                error_display: gr.update(visible=False),
                gallery: results["images"],
                product_details: results["descriptions"] or None,
                similarity_plot: results["similarity_plot"],
            }

        # gr.update is the generic updater already used in handle_search;
        # the per-component gr.Row.update was removed in Gradio 4.
        search_type.change(
            fn=lambda choice: gr.update(visible=choice == "Combined Search"),
            inputs=[search_type],
            outputs=[weight_controls],
        )
        search_button.click(
            fn=handle_search,
            inputs=[
                text_input,
                image_input,
                search_type,
                text_weight,
                image_weight,
                similarity_threshold,
            ],
            outputs=[
                error_display,
                gallery,
                product_details,
                similarity_plot,
            ],
        )

    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.launch()