Raphaël Bournhonesque commited on
Commit
5c3d937
·
1 Parent(s): 4f19c28

add approximate method

Browse files
Files changed (1) hide show
  1. app.py +21 -4
app.py CHANGED
@@ -118,7 +118,10 @@ def display_predictions(
118
 
119
  cropped_images: List[Image.Image] = []
120
  captions: List[str] = []
121
- for closest_id, distance in zip(nn_ids, nn_distances):
 
 
 
122
  closest_logo = logo_data[closest_id]
123
 
124
  cropped_image = get_cropped_image(
@@ -143,12 +146,26 @@ st.sidebar.title("Logo Nearest Neighbors Demo")
143
  st.sidebar.write(
144
  "Get first 100 nearest neighbors for a random annotated logo.\n\n"
145
  "CLIP model is used to generate embeddings, and nearest neighbors "
146
- "are computed using a brute-force approach (no approximation)."
 
 
 
 
 
 
 
 
 
 
147
  )
148
  nn_data = load_nn_data(
149
- "https://static.openfoodfacts.org/data/logos/exact_100_neighbours.json.gz"
150
  )
151
  logo_data = load_logo_data(
152
  "https://static.openfoodfacts.org/data/logos/logo_annotations.jsonl.gz"
153
  )
154
- display_predictions(logo_data=logo_data, nn_data=nn_data)
 
 
 
 
 
118
 
119
  cropped_images: List[Image.Image] = []
120
  captions: List[str] = []
121
+ progress_bar = st.progress(0)
122
+
123
+ for i, (closest_id, distance) in enumerate(zip(nn_ids, nn_distances)):
124
+ progress_bar.progress((i + 1) / len(nn_ids))
125
  closest_logo = logo_data[closest_id]
126
 
127
  cropped_image = get_cropped_image(
 
146
  st.sidebar.write(
147
  "Get first 100 nearest neighbors for a random annotated logo.\n\n"
148
  "CLIP model is used to generate embeddings, and nearest neighbors "
149
+ "are computed either using a brute-force approach or with ANN."
150
+ )
151
+ logo_id = st.sidebar.number_input("logo ID", step=1) or None
152
+ approximate = (
153
+ st.sidebar.checkbox(
154
+ "ANN (HNSW)",
155
+ value=False,
156
+ help="Display approximate neighbors (instead of real "
157
+ "neighbors computed using brute-force approach",
158
+ )
159
+ or None
160
  )
161
  nn_data = load_nn_data(
162
+ f"https://static.openfoodfacts.org/data/logos/{'hnsw_50_closest_neighbours' if approximate else 'exact_100_neighbours'}.json.gz"
163
  )
164
  logo_data = load_logo_data(
165
  "https://static.openfoodfacts.org/data/logos/logo_annotations.jsonl.gz"
166
  )
167
+ if approximate:
168
+ st.write("Using approximate nearest neighbors method")
169
+ else:
170
+ st.write("Using exact (brute-force) nearest neighbors method")
171
+ display_predictions(logo_data=logo_data, nn_data=nn_data, logo_id=logo_id)