"""Streamlit demo displaying the nearest neighbors of annotated logos.

Logo annotations and precomputed nearest-neighbor results are downloaded as
gzipped JSON files from static.openfoodfacts.org; the corresponding logo crops
are cut out of product images from the Open Food Facts image server.
"""
import gzip
import io
import json
import random
import re
import tempfile
from typing import Dict, List, Optional

from PIL import Image
import requests
import streamlit as st

# Shared HTTP session so repeated requests reuse pooled connections.
http_session = requests.Session()

API_URL = "https://world.openfoodfacts.org/api/v0"
PRODUCT_URL = API_URL + "/product"
OFF_IMAGE_BASE_URL = "https://static.openfoodfacts.org/images/products"
# Splits a barcode into three 3-character folders followed by the remainder,
# e.g. "3017620422003" -> "301", "762", "042", "2003".
BARCODE_PATH_REGEX = re.compile(r"^(...)(...)(...)(.*)$")


@st.cache(allow_output_mutation=True)
def load_nn_data(url: str) -> Dict[int, Dict]:
    """Download and parse the gzipped nearest-neighbor JSON file at *url*.

    The file holds one JSON object whose keys are logo IDs; keys are converted
    to ``int``. ``display_predictions`` reads ``ids``, ``distances`` and
    ``annotation`` from each value.

    Raises:
        requests.HTTPError: if the download fails.
    """
    r = http_session.get(url)
    r.raise_for_status()
    with gzip.open(io.BytesIO(r.content), "rt") as f:
        return {int(key): value for key, value in json.load(f).items()}


@st.cache(allow_output_mutation=True)
def load_logo_data(url: str) -> Dict[int, Dict]:
    """Download and parse the gzipped JSON Lines logo annotation file at *url*.

    Each non-blank line is a JSON object; the result maps ``int(item["id"])``
    to the whole item. ``display_predictions`` reads ``barcode``, ``image_id``
    and ``bounding_box`` from each item.

    Raises:
        requests.HTTPError: if the download fails.
    """
    r = http_session.get(url)
    r.raise_for_status()
    logos: Dict[int, Dict] = {}
    with gzip.open(io.BytesIO(r.content), "rt") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:  # json.loads("") would raise; skip blank lines
                continue
            item = json.loads(line)
            logos[int(item["id"])] = item
    return logos


def get_image_from_url(
    image_url: str,
    error_raise: bool = False,
    session: Optional[requests.Session] = None,
) -> Optional[Image.Image]:
    """Download *image_url* and open it as a PIL image.

    Args:
        image_url: URL of the image to fetch.
        error_raise: if True, raise ``requests.HTTPError`` for 4xx/5xx
            responses instead of returning None.
        session: HTTP session used for the request; a one-off
            ``requests.get`` is used when None.

    Returns:
        The opened image, or None when the response status is not 200.
    """
    if session is not None:
        r = session.get(image_url)
    else:
        r = requests.get(image_url)

    if error_raise:
        r.raise_for_status()

    if r.status_code != 200:
        return None

    # Decode straight from memory: no temporary file is needed, and the
    # returned image keeps a reference to the buffer for its lazy loading.
    return Image.open(io.BytesIO(r.content))


def split_barcode(barcode: str) -> List[str]:
    """Split *barcode* into the folder components used in OFF image paths.

    Barcodes of at least 9 digits are split with ``BARCODE_PATH_REGEX``
    (empty trailing groups dropped); shorter ones are returned whole.

    Raises:
        ValueError: if *barcode* is not made only of digits.
    """
    if not barcode.isdigit():
        raise ValueError("unknown barcode format: {}".format(barcode))

    match = BARCODE_PATH_REGEX.fullmatch(barcode)

    if match:
        return [x for x in match.groups() if x]

    return [barcode]


def get_cropped_image(
    barcode: str, image_id: str, bounding_box
) -> Optional[Image.Image]:
    """Fetch image *image_id* of product *barcode* and crop it to *bounding_box*.

    *bounding_box* is unpacked as ``(ymin, xmin, ymax, xmax)``. The values are
    multiplied by the image width/height, so they are presumably relative
    coordinates in [0, 1] — TODO confirm against the annotation producer.

    Returns:
        The cropped image, or None when the image could not be downloaded.
    """
    image_path = generate_image_path(barcode, image_id)
    url = OFF_IMAGE_BASE_URL + image_path
    image = get_image_from_url(url, session=http_session)

    if image is None:
        return None

    ymin, xmin, ymax, xmax = bounding_box
    left, right = xmin * image.width, xmax * image.width
    top, bottom = ymin * image.height, ymax * image.height
    return image.crop((left, top, right, bottom))


def generate_image_path(barcode: str, image_id: str) -> str:
    """Return the image path of *image_id* for *barcode*, e.g. ``/301/762/042/2003/1.jpg``.

    Raises:
        ValueError: propagated from ``split_barcode`` for non-digit barcodes.
    """
    splitted_barcode = split_barcode(barcode)
    return "/{}/{}.jpg".format("/".join(splitted_barcode), image_id)


def display_predictions(
    logo_data: Dict,
    nn_data: Dict,
    logo_id: Optional[int] = None,
):
    """Render a logo crop followed by the crops of its nearest neighbors.

    Args:
        logo_data: logo ID -> annotation item (``barcode``, ``image_id``,
            ``bounding_box``), as built by ``load_logo_data``.
        nn_data: logo ID -> dict with ``ids``, ``distances`` and
            ``annotation``, as built by ``load_nn_data``.
        logo_id: logo to display; a random ID from *nn_data* is used when it
            is falsy (None or 0).
    """
    if not logo_id:
        if not nn_data:
            st.error("No nearest-neighbor data available")
            return
        logo_id = random.choice(list(nn_data.keys()))

    st.write(f"Logo ID: {logo_id}")

    # A user-supplied ID may be missing from either dataset: report it in the
    # UI instead of crashing with a KeyError.
    if logo_id not in logo_data or logo_id not in nn_data:
        st.error(f"Unknown logo ID: {logo_id}")
        return

    logo = logo_data[logo_id]
    logo_nn_data = nn_data[logo_id]
    neighbors = list(zip(logo_nn_data["ids"], logo_nn_data["distances"]))
    annotation = logo_nn_data["annotation"]

    cropped_image = get_cropped_image(
        logo["barcode"], logo["image_id"], logo["bounding_box"]
    )

    if cropped_image is None:
        st.warning(f"Could not download the image of logo {logo_id}")
        return

    st.image(cropped_image, annotation, width=200)

    cropped_images: List[Image.Image] = []
    captions: List[str] = []
    progress_bar = st.progress(0)
    for i, (closest_id, distance) in enumerate(neighbors):
        progress_bar.progress((i + 1) / len(neighbors))
        closest_logo = logo_data[closest_id]
        neighbor_image = get_cropped_image(
            closest_logo["barcode"],
            closest_logo["image_id"],
            closest_logo["bounding_box"],
        )

        if neighbor_image is None:
            continue

        # Show tall crops in landscape orientation. expand=True enlarges the
        # canvas so the rotated crop is not clipped to the original size.
        if neighbor_image.height > neighbor_image.width:
            neighbor_image = neighbor_image.rotate(90, expand=True)

        cropped_images.append(neighbor_image)
        captions.append(f"distance: {distance}")

    if cropped_images:
        st.image(cropped_images, captions, width=200)


st.sidebar.title("Logo Nearest Neighbors Demo")
st.sidebar.write(
    "Get first 100 nearest neighbors for a random annotated logo.\n\n"
    "CLIP model is used to generate embeddings, and nearest neighbors "
    "are computed either using a brute-force approach or with ANN."
)
# `or None` maps an ID of 0 to None, so display_predictions picks a random logo.
logo_id = st.sidebar.number_input("logo ID", step=1) or None
approximate = st.sidebar.checkbox(
    "ANN (HNSW)",
    value=False,
    help="Display approximate neighbors (instead of real "
    "neighbors computed using brute-force approach)",
)
nn_file_name = (
    "hnsw_50_closest_neighbours" if approximate else "exact_100_neighbours"
)
nn_data = load_nn_data(
    f"https://static.openfoodfacts.org/data/logos/{nn_file_name}.json.gz"
)
logo_data = load_logo_data(
    "https://static.openfoodfacts.org/data/logos/logo_annotations.jsonl.gz"
)

if approximate:
    st.write("Using approximate nearest neighbors method")
else:
    st.write("Using exact (brute-force) nearest neighbors method")

display_predictions(logo_data=logo_data, nn_data=nn_data, logo_id=logo_id)