divyanshujainlivein committed on
Commit
ab0b995
·
verified ·
1 Parent(s): bd329fa

Upload requirements.txt

Browse files
Files changed (1) hide show
  1. requirements.txt +365 -0
requirements.txt ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Here are the contents for the requirements.txt file:
2
+
3
+ ```
4
+ datasets
5
+ superlinked==18.3.0
6
+ google-generativeai
7
+ gradio
8
+ pillow
9
+ torch
10
+ torchvision
11
+ matplotlib
12
+ pandas
13
+ beartype
14
+ requests
15
+ ```
16
+
17
+ Now for the main app.py file:
18
+
19
+ ```python
20
+ import gradio as gr
21
+ import numpy as np
22
+ from PIL import Image as PILImage
23
+ import torch
24
+ from torchvision import transforms
25
+ import matplotlib.pyplot as plt
26
+ import pandas as pd
27
+ from io import BytesIO
28
+ import requests
29
+ from beartype.typing import Any, Hashable
30
+ from requests import RequestException
31
+ from superlinked import framework as sl
32
+ from datasets import load_dataset
33
+
34
# Constants
DATASET_ID = "tomytjandra/h-and-m-fashion-caption"
VIT_MODEL_ID = "hf-hub:laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
NUM_SAMPLES = 1000
SEED = 42
LIMIT = 3

# Load the dataset and keep a reproducible random sample of the train split.
fashion_dataset = load_dataset(DATASET_ID)
fashion_sample_dataset = (
    fashion_dataset["train"]
    .shuffle(seed=SEED)
    .select(range(NUM_SAMPLES))
)

# Build one record per sampled row, tagging each with its position in the
# sample as "id" so search hits can be mapped back to fashion_sample_dataset.
fashion_json_data = [
    {**record, "id": position}
    for position, record in enumerate(fashion_sample_dataset)
]
fashion_df = pd.DataFrame(fashion_json_data)
# The schema field is called "description"; the dataset column is "text".
fashion_df["description"] = fashion_df["text"]
json_data = fashion_df.to_dict(orient="records")
52
+
53
# Superlinked setup
class Image(sl.Schema):
    """Schema of one indexed product: an id, an image blob and a text description."""

    id: sl.IdField
    image: sl.Blob
    description: sl.String


image = Image()

# One space per modality; both are combined in a single index below.
image_embedding_space = sl.ImageSpace(
    image=image.image,
    model=VIT_MODEL_ID,
    model_handler=sl.ModelHandler.OPEN_CLIP,
)
description_space = sl.TextSimilaritySpace(
    text=image.description,
    model="Alibaba-NLP/gte-large-en-v1.5",
)
composite_index = sl.Index([image_embedding_space, description_space])

# The executor is started before the records are pushed through the source.
source = sl.InMemorySource(image)
executor = sl.InMemoryExecutor(sources=[source], indices=[composite_index])
app = executor.run()
source.put(json_data)
67
+
68
# Query construction
# A single parameterised query serves every search mode: callers supply the
# weights and whichever of the *_search parameters apply to their mode.
combined_query = (
    sl.Query(
        composite_index,
        weights={
            description_space: sl.Param("description_weight"),
            image_embedding_space: sl.Param("image_embedding_weight"),
        },
    )
    .find(image)
    .similar(description_space, sl.Param("text_search"))
    .similar(image_embedding_space.image, sl.Param("image_search"))
    .similar(image_embedding_space.description, sl.Param("text_in_image_search"))
    .select_all()
    # Use the module-level LIMIT constant instead of repeating the literal 3.
    .limit(LIMIT)
)
84
+
85
def process_search_results(results_df, dataset, similarity_threshold=0.5):
    """Filter search hits by score and package them for the UI.

    Args:
        results_df: DataFrame of search hits. Must contain an ``id`` column
            (convertible to ``int``) and a ``similarity_score`` column.
        dataset: Row-indexable collection. ``dataset[i]`` must return a
            mapping that supports ``.get`` and has an ``"image"`` entry.
        similarity_threshold: Hits scoring below this value are dropped.

    Returns:
        A dict with the keys ``images`` (list of images), ``descriptions``
        (list of per-product detail dicts), ``similarity_plot`` (a figure,
        or ``None`` when nothing passed the filter) and ``error`` (a message
        string, or ``None`` on success).
    """
    keep = results_df["similarity_score"] >= similarity_threshold
    filtered_df = results_df[keep]

    if filtered_df.empty:
        return {
            "images": [],
            "descriptions": [],
            "similarity_plot": None,
            "error": "No results meet the similarity threshold",
        }

    images = []
    descriptions = []
    scores = []

    for _, row in filtered_df.iterrows():
        product_id = int(row["id"])
        score = float(row["similarity_score"])

        # Row-first access: fetch the single record once and take the image
        # from it, instead of reading the whole "image" column for every hit.
        product_info = dataset[product_id]
        img = product_info["image"]
        if isinstance(img, np.ndarray):
            img = PILImage.fromarray(img)

        images.append(img)
        descriptions.append(
            {
                "Product ID": str(product_id),
                "Description": product_info.get("text", "N/A"),
                "Category": product_info.get("category", "N/A"),
                "Similarity Score": f"{score:.3f}",
                "Price Range": product_info.get("price_range", "N/A"),
                "Colors": product_info.get("colors", []),
                "Brand": product_info.get("brand", "N/A"),
            }
        )
        scores.append(score)

    # Pass the active threshold so the plotted reference line matches the
    # filter that was actually applied above.
    similarity_plot = create_similarity_visualization(
        scores, descriptions, threshold=similarity_threshold
    )

    return {
        "images": images,
        "descriptions": descriptions,
        "similarity_plot": similarity_plot,
        "error": None,
    }


def create_similarity_visualization(scores, descriptions, threshold=0.5):
    """Plot per-result similarity scores and the category split.

    Args:
        scores: Similarity scores, one per result, in display order.
        descriptions: Detail dicts, one per result; only ``"Category"`` is read.
        threshold: Height of the dashed reference line on the bar chart.
            Defaults to 0.5, the value that used to be hard-coded.

    Returns:
        The matplotlib figure containing the bar chart (top) and the
        category pie chart (bottom).
    """
    from collections import Counter

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8), height_ratios=[2, 1])

    bars = ax1.bar(range(len(scores)), scores)
    ax1.set_title('Similarity Scores Distribution')
    ax1.set_xlabel('Result Index')
    ax1.set_ylabel('Similarity Score')

    # Annotate every bar with its score, centred just above the bar.
    for bar in bars:
        height = bar.get_height()
        ax1.text(
            bar.get_x() + bar.get_width() / 2.,
            height,
            f'{height:.3f}',
            ha='center',
            va='bottom',
        )

    ax1.axhline(y=threshold, color='r', linestyle='--', label='Threshold')
    ax1.legend()

    # Counter tallies in one pass and keeps first-seen order, so the pie
    # slices come out in a stable order between calls.
    category_counts = Counter(d.get("Category") for d in descriptions)
    ax2.pie(
        list(category_counts.values()),
        labels=list(category_counts.keys()),
        autopct='%1.1f%%',
    )
    ax2.set_title('Category Distribution')

    fig.tight_layout()
    # Drop the figure from pyplot's global registry so figures do not
    # accumulate across searches; the returned Figure object stays usable.
    plt.close(fig)
    return fig
162
+
163
def search_products(search_text, search_image, search_type, weight_text=1.0, weight_image=1.0, similarity_threshold=0.5):
    """Run a text, image or combined search and return UI-ready results.

    Args:
        search_text: Free-text query; required for "Text Only".
        search_image: Query image (PIL image or numpy array); required for
            "Image Only".
        search_type: "Text Only", "Image Only", or anything else, which is
            treated as a combined search.
        weight_text: Weight of the description space in a combined search.
        weight_image: Weight of the image space in a combined search.
        similarity_threshold: Minimum similarity score a hit must reach.

    Returns:
        A dict with the keys ``images``, ``descriptions``,
        ``similarity_plot`` and ``error``. Any failure, including invalid
        input, is reported through ``error`` with the other entries empty;
        this function does not raise.
    """
    try:
        if search_type == "Text Only" and not search_text:
            raise ValueError("Please enter search text")
        if search_type == "Image Only" and search_image is None:
            raise ValueError("Please upload an image")

        is_combined = search_type not in ("Text Only", "Image Only")
        # A combined search with no input at all would run an unconstrained
        # query, so reject it the same way the single-mode searches do.
        if is_combined and not search_text and search_image is None:
            raise ValueError("Please enter search text or upload an image")

        if search_type == "Text Only":
            params = {"description_weight": 1, "text_search": search_text}
        else:
            if isinstance(search_image, np.ndarray):
                search_image = PILImage.fromarray(search_image)

            if search_type == "Image Only":
                params = {"image_embedding_weight": 1, "image_search": search_image}
            else:  # Combined Search
                params = {
                    "description_weight": weight_text,
                    "image_embedding_weight": weight_image,
                }
                # Only pass the inputs the user actually provided, mirroring
                # how the single-mode searches omit the unused parameter.
                if search_text:
                    params["text_search"] = search_text
                if search_image is not None:
                    params["image_search"] = search_image

        results = app.query(combined_query, **params)

        # NOTE(review): some superlinked releases expose the conversion as
        # sl.PandasConverter.to_pandas(result) rather than a result method -
        # confirm against the pinned version.
        if hasattr(results, "to_pandas"):
            results_df = results.to_pandas()
        else:
            results_df = sl.PandasConverter.to_pandas(results)

        return process_search_results(results_df, fashion_sample_dataset, similarity_threshold)

    except Exception as e:
        # UI boundary: surface the message in the status box instead of
        # letting the exception escape into the Gradio event handler.
        return {
            "images": [],
            "descriptions": [],
            "similarity_plot": None,
            "error": str(e),
        }
210
+
211
def create_interface():
    """Assemble the Gradio Blocks UI for the search demo.

    The left column holds the query inputs and settings, the second row
    shows the results, and a hidden status box is revealed on errors.

    Returns:
        The ``gr.Blocks`` app; the caller is expected to ``launch()`` it.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# Fashion Product Semantic Search")

        with gr.Row():
            with gr.Column(scale=1):
                text_input = gr.Textbox(
                    label="Search Text",
                    placeholder="Enter product description...",
                )
                image_input = gr.Image(
                    label="Search Image",
                    type="pil",
                )
                search_type = gr.Radio(
                    choices=["Text Only", "Image Only", "Combined Search"],
                    label="Search Type",
                    value="Text Only",
                )

                with gr.Accordion("Advanced Settings", open=False):
                    similarity_threshold = gr.Slider(
                        minimum=0, maximum=1, value=0.5,
                        label="Similarity Threshold",
                    )
                    # Weight sliders only matter for combined search, so the
                    # row starts hidden and is toggled by the radio below.
                    with gr.Row(visible=False) as weight_controls:
                        text_weight = gr.Slider(
                            minimum=0, maximum=2, value=1,
                            label="Text Weight",
                        )
                        image_weight = gr.Slider(
                            minimum=0, maximum=2, value=1,
                            label="Image Weight",
                        )

                search_button = gr.Button("Search", variant="primary")

        with gr.Row():
            with gr.Column(scale=2):
                gallery = gr.Gallery(
                    label="Search Results",
                    columns=3,
                    height="400px",
                )
                product_details = gr.JSON(
                    label="Product Details",
                )
                similarity_plot = gr.Plot(
                    label="Similarity Analysis",
                )

        error_display = gr.Textbox(
            label="Status",
            visible=False,
        )

        # Parameter names are kept distinct from the component variables
        # above so the handler never shadows the components it updates.
        def handle_search(text, image, mode, text_w, image_w, threshold):
            """Run the search and map its result dict onto the output components."""
            results = search_products(text, image, mode, text_w, image_w, threshold)

            if results["error"]:
                return {
                    error_display: gr.update(value=results["error"], visible=True),
                    gallery: None,
                    product_details: None,
                    similarity_plot: None,
                }

            return {
                error_display: gr.update(visible=False),
                gallery: results["images"],
                product_details: results["descriptions"] if results["descriptions"] else None,
                similarity_plot: results["similarity_plot"],
            }

        # gr.update(...) replaces the per-component gr.Row.update(...)
        # classmethod, which no longer exists in Gradio 4 and later.
        search_type.change(
            fn=lambda mode: gr.update(visible=mode == "Combined Search"),
            inputs=[search_type],
            outputs=[weight_controls],
        )

        search_button.click(
            fn=handle_search,
            inputs=[
                text_input,
                image_input,
                search_type,
                text_weight,
                image_weight,
                similarity_threshold,
            ],
            outputs=[
                error_display,
                gallery,
                product_details,
                similarity_plot,
            ],
        )

    return demo
310
+
311
if __name__ == "__main__":
    # Build the UI and start the Gradio server only when run as a script.
    demo = create_interface()
    demo.launch()
314
+ ```
315
+
316
+ For the README.md file:
317
+
318
+ ```markdown
319
+ # Fashion Product Semantic Search
320
+
321
+ This is a Gradio application for semantic search of fashion products using both text and image inputs.
322
+
323
+ ## Features
324
+
325
+ - Text-based search
326
+ - Image-based search
327
+ - Combined text and image search
328
+ - Adjustable similarity thresholds
329
+ - Detailed product information display
330
+ - Similarity score visualization
331
+
332
+ ## Requirements
333
+
334
+ See requirements.txt for detailed dependencies.
335
+
336
+ ## Usage
337
+
338
+ The application provides a web interface where users can:
339
+ 1. Enter text descriptions
340
+ 2. Upload images
341
+ 3. Choose search type (Text Only, Image Only, or Combined Search)
342
+ 4. Adjust advanced settings
343
+ 5. View search results with detailed product information
344
+ ```
345
+
346
+ And finally, for the .gitignore file:
347
+
348
+ ```
349
+ __pycache__/
350
+ *.pyc
351
+ .ipynb_checkpoints/
352
+ .DS_Store
353
+ ```
354
+
355
+ These files should be organized in your project directory as follows:
356
+
357
+ ```
358
+ project_directory/
359
+ ├── app.py
360
+ ├── requirements.txt
361
+ ├── README.md
362
+ └── .gitignore
363
+ ```
364
+
365
+ You can now use these files to deploy your application on Hugging Face Spaces or run it locally.