Spaces:
Runtime error
Runtime error
Raphaël Bournhonesque
commited on
Commit
•
b6476a0
1
Parent(s):
97f2f31
first commit
Browse files- app.py +148 -0
- requirements.txt +3 -0
app.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gzip
|
2 |
+
import io
|
3 |
+
import json
|
4 |
+
import random
|
5 |
+
import re
|
6 |
+
import tempfile
|
7 |
+
from typing import Dict, List, Optional
|
8 |
+
|
9 |
+
from PIL import Image
|
10 |
+
import requests
|
11 |
+
import streamlit as st
|
12 |
+
|
13 |
+
|
14 |
+
http_session = requests.Session()
|
15 |
+
|
16 |
+
API_URL = "https://world.openfoodfacts.org/api/v0"
|
17 |
+
PRODUCT_URL = API_URL + "/product"
|
18 |
+
OFF_IMAGE_BASE_URL = "https://static.openfoodfacts.org/images/products"
|
19 |
+
BARCODE_PATH_REGEX = re.compile(r"^(...)(...)(...)(.*)$")
|
20 |
+
|
21 |
+
|
22 |
+
@st.cache()
|
23 |
+
def load_nn_data(url: str):
|
24 |
+
r = http_session.get(url)
|
25 |
+
with gzip.open(io.BytesIO(r.content), "rt") as f:
|
26 |
+
content = f.read()
|
27 |
+
return {int(key): value for key, value in json.loads(content).items()}
|
28 |
+
|
29 |
+
|
30 |
+
@st.cache()
|
31 |
+
def load_logo_data(url: str):
|
32 |
+
r = http_session.get(url)
|
33 |
+
with gzip.open(io.BytesIO(r.content), "rt") as f:
|
34 |
+
content = f.read()
|
35 |
+
return {
|
36 |
+
int(item["id"]): item for item in (json.loads(x) for x in map(str.strip, content))
|
37 |
+
}
|
38 |
+
|
39 |
+
|
40 |
+
def get_image_from_url(
|
41 |
+
image_url: str,
|
42 |
+
error_raise: bool = False,
|
43 |
+
session: Optional[requests.Session] = None,
|
44 |
+
) -> Optional[Image.Image]:
|
45 |
+
if session:
|
46 |
+
r = http_session.get(image_url)
|
47 |
+
else:
|
48 |
+
r = requests.get(image_url)
|
49 |
+
|
50 |
+
if error_raise:
|
51 |
+
r.raise_for_status()
|
52 |
+
|
53 |
+
if r.status_code != 200:
|
54 |
+
return None
|
55 |
+
|
56 |
+
with tempfile.NamedTemporaryFile() as f:
|
57 |
+
f.write(r.content)
|
58 |
+
image = Image.open(f.name)
|
59 |
+
|
60 |
+
return image
|
61 |
+
|
62 |
+
|
63 |
+
def split_barcode(barcode: str) -> List[str]:
|
64 |
+
if not barcode.isdigit():
|
65 |
+
raise ValueError("unknown barcode format: {}".format(barcode))
|
66 |
+
|
67 |
+
match = BARCODE_PATH_REGEX.fullmatch(barcode)
|
68 |
+
|
69 |
+
if match:
|
70 |
+
return [x for x in match.groups() if x]
|
71 |
+
|
72 |
+
return [barcode]
|
73 |
+
|
74 |
+
|
75 |
+
def get_cropped_image(barcode: str, image_id: str, bounding_box):
|
76 |
+
image_path = generate_image_path(barcode, image_id)
|
77 |
+
url = OFF_IMAGE_BASE_URL + image_path
|
78 |
+
image = get_image_from_url(url, session=http_session)
|
79 |
+
|
80 |
+
if image is None:
|
81 |
+
return
|
82 |
+
|
83 |
+
ymin, xmin, ymax, xmax = bounding_box
|
84 |
+
(left, right, top, bottom) = (
|
85 |
+
xmin * image.width,
|
86 |
+
xmax * image.width,
|
87 |
+
ymin * image.height,
|
88 |
+
ymax * image.height,
|
89 |
+
)
|
90 |
+
return image.crop((left, top, right, bottom))
|
91 |
+
|
92 |
+
|
93 |
+
def generate_image_path(barcode: str, image_id: str) -> str:
|
94 |
+
splitted_barcode = split_barcode(barcode)
|
95 |
+
return "/{}/{}.jpg".format("/".join(splitted_barcode), image_id)
|
96 |
+
|
97 |
+
|
98 |
+
def display_predictions(
|
99 |
+
logo_data: Dict,
|
100 |
+
nn_data: Dict,
|
101 |
+
logo_id: Optional[int] = None,
|
102 |
+
):
|
103 |
+
if not logo_id:
|
104 |
+
logo_id = random.choice(list(nn_data.keys()))
|
105 |
+
|
106 |
+
st.write(f"Logo ID: {logo_id}")
|
107 |
+
logo = logo_data[logo_id]
|
108 |
+
logo_nn_data = nn_data[logo_id]
|
109 |
+
nn_ids = logo_nn_data["ids"]
|
110 |
+
nn_distances = logo_nn_data["distances"]
|
111 |
+
annotation = logo_nn_data["annotation"]
|
112 |
+
|
113 |
+
cropped_image = get_cropped_image(
|
114 |
+
logo["barcode"], logo["image_id"], logo["bounding_box"]
|
115 |
+
)
|
116 |
+
|
117 |
+
if cropped_image is None:
|
118 |
+
return
|
119 |
+
st.image(cropped_image, annotation, width=200)
|
120 |
+
|
121 |
+
cropped_images: List[Image.Image] = []
|
122 |
+
captions: List[str] = []
|
123 |
+
for closest_id, distance in zip(nn_ids, nn_distances):
|
124 |
+
closest_logo = logo_data[closest_id]
|
125 |
+
|
126 |
+
cropped_image = get_cropped_image(
|
127 |
+
closest_logo["barcode"],
|
128 |
+
closest_logo["image_id"],
|
129 |
+
closest_logo["bounding_box"],
|
130 |
+
)
|
131 |
+
if cropped_image is None:
|
132 |
+
continue
|
133 |
+
|
134 |
+
if cropped_image.height > cropped_image.width:
|
135 |
+
cropped_image = cropped_image.rotate(90)
|
136 |
+
|
137 |
+
cropped_images.append(cropped_image)
|
138 |
+
captions.append(f"distance: {distance}")
|
139 |
+
|
140 |
+
if cropped_images:
|
141 |
+
st.image(cropped_images, captions, width=200)
|
142 |
+
|
143 |
+
|
144 |
+
st.sidebar.title("Logo Nearest Neighbor Demo")
|
145 |
+
# st.sidebar.write("")
|
146 |
+
nn_data = load_nn_data("https://static.openfoodfacts.org/data/logos/exact_100_neighbours.json.gz")
|
147 |
+
logo_data = load_logo_data("https://static.openfoodfacts.org/data/logos/logo_annotations.jsonl.gz")
|
148 |
+
display_predictions(logo_data=logo_data, nn_data=nn_data)
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
Pillow==9.3.0
|
2 |
+
requests==2.28.1
|
3 |
+
streamlit==1.15.1
|