grimbano committed on
Commit
da50ab5
·
1 Parent(s): d151d1e

feat(init): :tada: Initial commit pushing up app data

Browse files
Dockerfile CHANGED
@@ -1,4 +1,4 @@
1
- FROM python:3.9-slim
2
 
3
  WORKDIR /app
4
 
 
1
+ FROM python:3.11-slim
2
 
3
  WORKDIR /app
4
 
embeddings/pokemon_embeddings_pkmn.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fd76097a052a82770458398d85021a58e2d511e916be96219099070a2f4af247
3
+ size 7265971
requirements.txt CHANGED
@@ -1,3 +1,68 @@
1
- altair
2
- pandas
3
- streamlit
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ altair==5.5.0
2
+ annotated-types==0.7.0
3
+ anyio==4.9.0
4
+ attrs==25.3.0
5
+ blinker==1.9.0
6
+ cachetools==5.5.2
7
+ certifi==2025.4.26
8
+ charset-normalizer==3.4.2
9
+ click==8.2.1
10
+ colorama==0.4.6
11
+ fastapi==0.109.2
12
+ filelock==3.18.0
13
+ fsspec==2025.5.1
14
+ gitdb==4.0.12
15
+ gitpython==3.1.44
16
+ h11==0.16.0
17
+ huggingface-hub==0.32.2
18
+ idna==3.10
19
+ jinja2==3.1.6
20
+ jsonschema==4.24.0
21
+ jsonschema-specifications==2025.4.1
22
+ markdown-it-py==3.0.0
23
+ markupsafe==3.0.2
24
+ mdurl==0.1.2
25
+ mpmath==1.3.0
26
+ narwhals==1.41.0
27
+ networkx==3.4.2
28
+ numpy==2.2.6
29
+ packaging==23.2
30
+ pandas==2.2.3
31
+ pillow==11.2.1
32
+ protobuf==4.25.8
33
+ pyarrow==20.0.0
34
+ pydantic==2.11.5
35
+ pydantic-core==2.33.2
36
+ pydeck==0.9.1
37
+ pygments==2.19.1
38
+ python-dateutil==2.9.0.post0
39
+ python-multipart==0.0.9
40
+ pytz==2025.2
41
+ pyyaml==6.0.2
42
+ referencing==0.36.2
43
+ regex==2024.11.6
44
+ requests==2.31.0
45
+ rich==13.9.4
46
+ rpds-py==0.25.1
47
+ safetensors==0.5.3
48
+ setuptools==69.2.0
49
+ six==1.17.0
50
+ smmap==5.0.2
51
+ sniffio==1.3.1
52
+ starlette==0.36.3
53
+ streamlit==1.32.0
54
+ sympy==1.14.0
55
+ tenacity==8.5.0
56
+ tokenizers==0.21.1
57
+ toml==0.10.2
58
+ torch==2.7.0
59
+ torchvision==0.22.0
60
+ tornado==6.5.1
61
+ tqdm==4.67.1
62
+ transformers==4.52.3
63
+ typing-extensions==4.13.2
64
+ typing-inspection==0.4.1
65
+ tzdata==2025.2
66
+ urllib3==2.4.0
67
+ uvicorn==0.27.1
68
+ watchdog==6.0.0
src/similarity.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import pickle
3
+ from PIL import Image
4
+ import io
5
+ import os
6
+ import requests
7
+ import base64
8
+ from collections import defaultdict
9
+ from transformers import ViTModel, ViTImageProcessor
10
+
11
+ import warnings
12
+ warnings.filterwarnings('ignore')
13
+
14
+ import logging
15
+ # logging.disable(logging.WARNING)
16
+ transformers_logger = logging.getLogger('transformers')
17
+ transformers_logger.setLevel(logging.ERROR)
18
+
19
+
20
+
21
# Work from this module's own directory so the relative database path below
# resolves the same way no matter where the process was started from.
_module_dir = os.path.dirname(os.path.abspath(__file__))
os.chdir(_module_dir)

# Embeddings pickle location, relative to the working directory set above.
DB_PATH_STRUCTURE = 'embeddings/pokemon_embeddings_pkmn.pkl'


# --- Device Selection ---
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
32
+
33
+
34
+
35
+
36
+ # --- Load Pretrained Model ---
37
# Hugging Face Hub identifier of the checkpoint shared by the model and its
# image processor.
MODEL_NAME = 'imjeffhi/pokemon_classifier'


def get_model() -> ViTModel:
    """
    Load the pretrained ViT checkpoint used to embed images.

    The checkpoint is loaded through ``ViTModel``, moved to the module-level
    ``device`` and switched to evaluation mode.

    Returns:
        ViTModel: The model, on ``device`` and in evaluation mode.
    """
    model = ViTModel.from_pretrained(MODEL_NAME).to(device)
    return model.eval()


# --- Image Preprocessing ---
# Image processor that ships with the same checkpoint. It is loaded directly
# from the checkpoint name: building a second model only to read its
# `name_or_path` would load the weights twice at import time.
transform = ViTImageProcessor.from_pretrained(MODEL_NAME)
62
+
63
+
64
+
65
+
66
+
67
class PokemonSimilarity:
    """
    Find the Pokemon whose stored embeddings are closest to an input image.

    The engine holds the embedding model (``self.model``) and the embeddings
    database (``self.db``), a mapping from label to a collection of stored
    embeddings for that label.
    """

    # Seconds to wait for a remote image before giving up, so that an
    # unresponsive server cannot block the caller forever.
    REQUEST_TIMEOUT_SECONDS = 10

    def __init__(self, suppress_init_logs: bool = True) -> None:
        """
        Load the model and the embeddings database.

        Args:
            suppress_init_logs: When True, the ``transformers`` logger is set
                to ERROR and the root logger to WARNING while loading; the
                previous levels are restored afterwards.
        """
        self.original_transformers_level = transformers_logger.level
        self.original_root_level = logging.root.level

        if suppress_init_logs:
            # Temporarily raise the logging levels to silence start-up
            # messages emitted while the model and database are loaded.
            transformers_logger.setLevel(logging.ERROR)
            logging.root.setLevel(logging.WARNING)

        try:
            self.model = get_model()
            self.db = self._load_db()
        finally:
            # Always restore the original levels, even if loading failed.
            transformers_logger.setLevel(self.original_transformers_level)
            logging.root.setLevel(self.original_root_level)

    def _load_db(self) -> dict | None:
        """
        Load the pickled embeddings database.

        The file is looked up relative to the current working directory, in
        the parent directory first and then in the working directory itself
        (the parent copy wins when both exist).

        Returns:
            The unpickled object; ``find_closest_pokemon`` iterates it as a
            mapping ``label -> iterable of embeddings``.

        Raises:
            FileNotFoundError: If the file exists in neither location.
            OSError: If the file exists but cannot be read or unpickled.
        """
        candidates = (f'../{DB_PATH_STRUCTURE}', DB_PATH_STRUCTURE)
        db_path = next((path for path in candidates if os.path.exists(path)), None)

        if db_path is None:
            # FileNotFoundError is an OSError, so callers catching the
            # broader type keep working.
            raise FileNotFoundError(
                'Error loading embeddings database: file not found '
                f'(looked in: {", ".join(candidates)})'
            )

        try:
            with open(db_path, 'rb') as f:
                # SECURITY: pickle.load can execute arbitrary code; only
                # load an embeddings file that comes from a trusted source.
                return pickle.load(f)
        except Exception as e:
            raise OSError(f'Error loading embeddings database: {e}') from e

    def load_image(self, image_input) -> Image.Image:
        """
        Load an image from any of the supported input formats.

        Supported inputs:
            - PIL Image objects
            - bytes / bytearray objects holding an encoded image
            - strings naming an existing local file
            - strings starting with ``http://`` or ``https://``
            - Base64 encoded image strings (an optional prefix such as
              ``data:image/jpeg;base64,`` is stripped)

        Args:
            image_input: Image in one of the formats above.

        Returns:
            PIL.Image: The loaded image in RGB format.

        Raises:
            Image.UnidentifiedImageError: If a local file is not an image.
            requests.RequestException: If downloading a URL fails.
            ValueError: For any other loading failure or unsupported type.
        """
        if isinstance(image_input, Image.Image):
            return image_input.convert('RGB')

        if isinstance(image_input, (bytes, bytearray)):
            try:
                return Image.open(io.BytesIO(image_input)).convert('RGB')
            except Exception as e:
                raise ValueError(f'Error loading image from bytes object: {e}') from e

        if isinstance(image_input, str):
            # A string is tried as a local path first, then as a URL, and
            # finally as Base64 data.
            if os.path.exists(image_input):
                try:
                    with Image.open(image_input) as img:
                        return img.convert('RGB')
                except Image.UnidentifiedImageError as e:
                    raise Image.UnidentifiedImageError(
                        f"Cannot identify image file at path '{image_input}': {e}"
                    ) from e
                except Exception as e:
                    raise ValueError(
                        f"Error loading image from local file path '{image_input}': {e}"
                    ) from e

            if image_input.startswith(('http://', 'https://')):
                try:
                    response = requests.get(image_input, timeout=self.REQUEST_TIMEOUT_SECONDS)
                    response.raise_for_status()  # Raise on 4xx / 5xx answers
                    return Image.open(io.BytesIO(response.content)).convert('RGB')
                except requests.RequestException as e:
                    raise requests.RequestException(
                        f"Error loading image from URL '{image_input}': {e}"
                    ) from e
                except Exception as e:
                    raise ValueError(
                        f"Error processing image from URL '{image_input}': {e}"
                    ) from e

            # Drop everything up to the first comma, if there is one.
            _, separator, tail = image_input.partition(',')
            base64_data = tail if separator else image_input
            try:
                decoded_image = base64.b64decode(base64_data)
                return Image.open(io.BytesIO(decoded_image)).convert('RGB')
            except (ValueError, OSError) as e:
                # binascii.Error is a ValueError (invalid Base64) and
                # UnidentifiedImageError is an OSError (valid Base64 that is
                # not an image); both mean the string is unusable.
                raise ValueError(
                    'Unsupported image string: it is not an existing file path, '
                    f'an http(s) URL, or a Base64 encoded image: {e}'
                ) from e

        raise ValueError(f'Unsupported image input format: {type(image_input)}. Expected URL, Base64 string, bytes, or PIL Image.')

    def get_embedding(self, image) -> torch.Tensor:
        """
        Generate a feature vector for the input image using the model.

        Args:
            image (PIL.Image): Input image to generate the embedding for.

        Returns:
            torch.Tensor: The model's ``last_hidden_state`` with every
            dimension after the first flattened into one.
        """
        inputs = transform(images=image, return_tensors="pt").to(device)

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            last_hidden_state = self.model(**inputs).last_hidden_state

        return last_hidden_state.reshape(last_hidden_state.shape[0], -1)

    def cosine_similarity(self, a, b) -> float:
        """
        Calculate the cosine similarity between two vectors.

        ``b`` is converted to a tensor with the dtype and device of ``a``,
        so stored embeddings need not live on the same device as the query.

        Args:
            a: First vector, shaped so that dimension 1 is the feature axis.
            b: Second vector, same shape convention as ``a``.

        Returns:
            float: Cosine similarity score.
        """
        a = torch.as_tensor(a)
        b = torch.as_tensor(b, dtype=a.dtype, device=a.device)

        return float(torch.nn.functional.cosine_similarity(a, b, dim=1))

    @staticmethod
    def _majority_vote(similarities, k: int = 5):
        """
        Pick the winning label among the ``k`` most similar entries.

        Args:
            similarities: Iterable of ``(label, similarity)`` pairs.
            k: Number of top-scoring entries that take part in the vote.

        Returns:
            The label with the most votes; ties are broken by the highest
            similarity, then by the best-ranked entry.

        Raises:
            ValueError: If there are no entries to vote on.
        """
        import heapq

        # Same result as sorting descending and slicing, without sorting
        # the whole list.
        top = heapq.nlargest(k, similarities, key=lambda item: item[1])
        if not top:
            raise ValueError('No embeddings available to compare against.')

        votes = {}
        best_similarity = {}
        for label, similarity in top:
            votes[label] = votes.get(label, 0) + 1
            best_similarity[label] = max(best_similarity.get(label, similarity), similarity)

        # max() keeps the first maximal key, i.e. the best-ranked label.
        return max(votes, key=lambda label: (votes[label], best_similarity[label]))

    def find_closest_pokemon(self, image_input):
        """
        Return the label of the Pokemon most similar to the input image.

        Steps:
            1. Load the input image
            2. Generate its embedding
            3. Compare it with every embedding in the database
            4. Majority-vote among the 5 most similar entries

        Args:
            image_input: Image in various formats (URL, base64, bytes, PIL Image)

        Returns:
            str: Name of the most similar Pokemon

        Raises:
            ValueError: If the database holds no embeddings, or the image
                cannot be loaded.
        """
        image = self.load_image(image_input)
        input_emb = self.get_embedding(image)

        # Score the query against every stored embedding.
        similarities = []
        for label, emb_list in self.db.items():
            for emb in emb_list:
                similarities.append((label, self.cosine_similarity(input_emb, emb)))

        return self._majority_vote(similarities, k=5)
275
+
276
+
277
+
278
+
279
+
280
if __name__ == "__main__":
    # Manual smoke test: classify one sample image fetched from the web.
    sample_image = 'https://alfabetajuega.com/hero/2019/03/Squirtle-Looking-Happy.jpg?width=1200&aspect_ratio=16:9'
    engine = PokemonSimilarity()
    closest = engine.find_closest_pokemon(sample_image)
    print(closest)
284
+
src/streamlit_app.py CHANGED
@@ -1,40 +1,113 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from PIL import Image
3
+ from similarity import PokemonSimilarity
4
+ import logging
5
 
6
+
7
# Labels used in captions and log messages to say where the image came from.
INPUT_UPLOAD = 'upload'
INPUT_URL = 'url'

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


# Set page config first
st.set_page_config(
    page_title='Pokemon Similarity Finder',
    page_icon='🎮',
    layout='centered'
)


@st.cache_resource
def get_similarity_engine() -> PokemonSimilarity:
    """Build the similarity engine; ``st.cache_resource`` caches the result."""
    logger.info('Initializing similarity engine...')
    engine = PokemonSimilarity()
    logger.info('Similarity engine initialized successfully')
    return engine


similarity_engine = get_similarity_engine()


# Title and description
st.title('🎮 Pokemon Similarity Finder')
st.markdown("""
Upload an image of a Pokemon or provide an image URL and we'll find the closest match in our database!
""")


# --- Input Method Selection ---
input_method = st.radio(
    'Choose input method:',
    ('Upload Image', 'Image URL'),
    horizontal=True
)


# --- Initialize variables for shared logic ---
input_type = None        # INPUT_UPLOAD or INPUT_URL once an input is given
image_input = None       # what is handed to the similarity engine
image_to_process = None  # what is handed to st.image for the preview


if input_method == 'Upload Image':
    uploaded_file = st.file_uploader('Choose a Pokemon image...', type=['jpg', 'jpeg', 'png'])

    if uploaded_file is not None:
        logger.info(f'File uploaded: {uploaded_file.name}')
        try:
            # Image.open raises on a file that is not a valid image, so it is
            # guarded here to show an error instead of crashing the page.
            image_to_process = Image.open(uploaded_file)
        except Exception as e:
            logger.error(f'Error loading image: {str(e)}')
            st.error(f'❌ Error loading image: {str(e)}')
            st.info('Please make sure you have uploaded a valid image file.')
        else:
            input_type = INPUT_UPLOAD
            image_input = uploaded_file.getvalue()

elif input_method == 'Image URL':
    image_url = st.text_input('Enter Image URL:')

    if image_url:
        logger.info(f'Image URL provided: {image_url}')
        input_type = INPUT_URL
        image_input = image_url
        image_to_process = image_url


if image_to_process is not None:
    try:
        st.image(image_to_process, caption=f'Image from {input_type}', use_column_width=True)
        logger.info(f'Successfully displayed {input_type} image')

    except Exception as e:
        logger.error(f'Error loading image: {str(e)}')
        st.error(f'❌ Error loading image: {str(e)}')
        st.info('Please make sure you have uploaded a valid image file.')

    # Add a button to trigger the similarity search
    if st.button('Find Similar Pokemon', use_container_width=True):
        logger.info('Find Similar Pokemon button clicked')
        predicted_pokemon = None  # Stays None when the search fails

        with st.spinner('Analyzing image...'):
            try:
                logger.info(f'Finding closest Pokemon match using {input_type} input...')
                predicted_pokemon = similarity_engine.find_closest_pokemon(image_input)

            except Exception as e:
                logger.error(f'Error during Pokemon matching: {str(e)}')
                st.error(f'❌ An error occurred: {str(e)}')
                st.info('Please try uploading a different image, using a different URL, or try again later.')

        if predicted_pokemon:
            logger.info(f'Found closest Pokemon: {predicted_pokemon}')
            st.success(f'🎯 The closest Pokemon is: **{predicted_pokemon.title()}**')
            st.balloons()