Raphaël Bournhonesque commited on
Commit
b6476a0
1 Parent(s): 97f2f31

first commit

Browse files
Files changed (2) hide show
  1. app.py +148 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gzip
2
+ import io
3
+ import json
4
+ import random
5
+ import re
6
+ import tempfile
7
+ from typing import Dict, List, Optional
8
+
9
+ from PIL import Image
10
+ import requests
11
+ import streamlit as st
12
+
13
+
14
+ http_session = requests.Session()
15
+
16
+ API_URL = "https://world.openfoodfacts.org/api/v0"
17
+ PRODUCT_URL = API_URL + "/product"
18
+ OFF_IMAGE_BASE_URL = "https://static.openfoodfacts.org/images/products"
19
+ BARCODE_PATH_REGEX = re.compile(r"^(...)(...)(...)(.*)$")
20
+
21
+
22
+ @st.cache()
23
+ def load_nn_data(url: str):
24
+ r = http_session.get(url)
25
+ with gzip.open(io.BytesIO(r.content), "rt") as f:
26
+ content = f.read()
27
+ return {int(key): value for key, value in json.loads(content).items()}
28
+
29
+
30
+ @st.cache()
31
+ def load_logo_data(url: str):
32
+ r = http_session.get(url)
33
+ with gzip.open(io.BytesIO(r.content), "rt") as f:
34
+ content = f.read()
35
+ return {
36
+ int(item["id"]): item for item in (json.loads(x) for x in map(str.strip, content))
37
+ }
38
+
39
+
40
+ def get_image_from_url(
41
+ image_url: str,
42
+ error_raise: bool = False,
43
+ session: Optional[requests.Session] = None,
44
+ ) -> Optional[Image.Image]:
45
+ if session:
46
+ r = http_session.get(image_url)
47
+ else:
48
+ r = requests.get(image_url)
49
+
50
+ if error_raise:
51
+ r.raise_for_status()
52
+
53
+ if r.status_code != 200:
54
+ return None
55
+
56
+ with tempfile.NamedTemporaryFile() as f:
57
+ f.write(r.content)
58
+ image = Image.open(f.name)
59
+
60
+ return image
61
+
62
+
63
+ def split_barcode(barcode: str) -> List[str]:
64
+ if not barcode.isdigit():
65
+ raise ValueError("unknown barcode format: {}".format(barcode))
66
+
67
+ match = BARCODE_PATH_REGEX.fullmatch(barcode)
68
+
69
+ if match:
70
+ return [x for x in match.groups() if x]
71
+
72
+ return [barcode]
73
+
74
+
75
+ def get_cropped_image(barcode: str, image_id: str, bounding_box):
76
+ image_path = generate_image_path(barcode, image_id)
77
+ url = OFF_IMAGE_BASE_URL + image_path
78
+ image = get_image_from_url(url, session=http_session)
79
+
80
+ if image is None:
81
+ return
82
+
83
+ ymin, xmin, ymax, xmax = bounding_box
84
+ (left, right, top, bottom) = (
85
+ xmin * image.width,
86
+ xmax * image.width,
87
+ ymin * image.height,
88
+ ymax * image.height,
89
+ )
90
+ return image.crop((left, top, right, bottom))
91
+
92
+
93
+ def generate_image_path(barcode: str, image_id: str) -> str:
94
+ splitted_barcode = split_barcode(barcode)
95
+ return "/{}/{}.jpg".format("/".join(splitted_barcode), image_id)
96
+
97
+
98
+ def display_predictions(
99
+ logo_data: Dict,
100
+ nn_data: Dict,
101
+ logo_id: Optional[int] = None,
102
+ ):
103
+ if not logo_id:
104
+ logo_id = random.choice(list(nn_data.keys()))
105
+
106
+ st.write(f"Logo ID: {logo_id}")
107
+ logo = logo_data[logo_id]
108
+ logo_nn_data = nn_data[logo_id]
109
+ nn_ids = logo_nn_data["ids"]
110
+ nn_distances = logo_nn_data["distances"]
111
+ annotation = logo_nn_data["annotation"]
112
+
113
+ cropped_image = get_cropped_image(
114
+ logo["barcode"], logo["image_id"], logo["bounding_box"]
115
+ )
116
+
117
+ if cropped_image is None:
118
+ return
119
+ st.image(cropped_image, annotation, width=200)
120
+
121
+ cropped_images: List[Image.Image] = []
122
+ captions: List[str] = []
123
+ for closest_id, distance in zip(nn_ids, nn_distances):
124
+ closest_logo = logo_data[closest_id]
125
+
126
+ cropped_image = get_cropped_image(
127
+ closest_logo["barcode"],
128
+ closest_logo["image_id"],
129
+ closest_logo["bounding_box"],
130
+ )
131
+ if cropped_image is None:
132
+ continue
133
+
134
+ if cropped_image.height > cropped_image.width:
135
+ cropped_image = cropped_image.rotate(90)
136
+
137
+ cropped_images.append(cropped_image)
138
+ captions.append(f"distance: {distance}")
139
+
140
+ if cropped_images:
141
+ st.image(cropped_images, captions, width=200)
142
+
143
+
144
+ st.sidebar.title("Logo Nearest Neighbor Demo")
145
+ # st.sidebar.write("")
146
+ nn_data = load_nn_data("https://static.openfoodfacts.org/data/logos/exact_100_neighbours.json.gz")
147
+ logo_data = load_logo_data("https://static.openfoodfacts.org/data/logos/logo_annotations.jsonl.gz")
148
+ display_predictions(logo_data=logo_data, nn_data=nn_data)
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ Pillow==9.3.0
2
+ requests==2.28.1
3
+ streamlit==1.15.1