grimbano commited on
Commit
832cf3d
·
1 Parent(s): d03083f

fix: :memo: Refactor docs to fix problems

Browse files
Files changed (4) hide show
  1. .python-version +0 -1
  2. Dockerfile +1 -1
  3. requirements.txt +4 -11
  4. src/similarity.py +4 -4
.python-version DELETED
@@ -1 +0,0 @@
1
- 3.11
 
 
Dockerfile CHANGED
@@ -30,8 +30,8 @@ COPY requirements.txt ./
30
  COPY src/ ./src/
31
  COPY embeddings/ ./embeddings/
32
 
33
- RUN pip3 install --user -r requirements.txt
34
  ENV PATH=$HOME/.local/bin:$PATH
 
35
 
36
  EXPOSE 8501
37
 
 
30
  COPY src/ ./src/
31
  COPY embeddings/ ./embeddings/
32
 
 
33
  ENV PATH=$HOME/.local/bin:$PATH
34
+ RUN pip3 install --user -r requirements.txt
35
 
36
  EXPOSE 8501
37
 
requirements.txt CHANGED
@@ -1,12 +1,5 @@
1
- numpy==1.26.4
2
- packaging==23.2
3
- setuptools==69.2.0
4
- fastapi==0.109.2
5
- uvicorn==0.27.1
6
- python-multipart==0.0.9
7
- requests==2.31.0
8
- torch==2.2.1
9
- torchvision==0.17.1
10
- Pillow==10.2.0
11
- streamlit==1.32.0
12
  transformers
 
1
+ pillow
2
+ requests
3
+ streamlit
4
+ torch
 
 
 
 
 
 
 
5
  transformers
src/similarity.py CHANGED
@@ -35,7 +35,6 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
35
  # --- Load Pretrained Model ---
36
  def get_model() -> ViTModel:
37
  """
38
- TODO: Implement model loading
39
  - Load a pretrained model (e.g., ResNet18)
40
  - Remove the classification head
41
  - Set the model to evaluation mode
@@ -44,6 +43,10 @@ def get_model() -> ViTModel:
44
  Returns:
45
  torch.nn.Module: The prepared model
46
  """
 
 
 
 
47
 
48
  model = ViTModel.from_pretrained('imjeffhi/pokemon_classifier').to(device)
49
 
@@ -55,7 +58,6 @@ def get_model() -> ViTModel:
55
 
56
 
57
  # --- Image Preprocessing ---
58
- # TODO: Define your image transformation pipeline
59
  # Hint: Consider resizing, normalization, and tensor conversion
60
  transform = ViTImageProcessor.from_pretrained(get_model().name_or_path)
61
 
@@ -66,7 +68,6 @@ transform = ViTImageProcessor.from_pretrained(get_model().name_or_path)
66
  class PokemonSimilarity:
67
  def __init__(self, suppress_init_logs: bool = True) -> None:
68
  """
69
- TODO: Initialize the similarity engine
70
  - Load the model
71
  - Load the database of Pokemon embeddings
72
  """
@@ -92,7 +93,6 @@ class PokemonSimilarity:
92
 
93
  def _load_db(self) -> dict | None:
94
  """
95
- TODO: Implement database loading
96
  - Look for the embeddings file in different possible locations
97
  - Load the pickle file containing Pokemon embeddings
98
  - Handle cases where the file is not found
 
35
  # --- Load Pretrained Model ---
36
  def get_model() -> ViTModel:
37
  """
 
38
  - Load a pretrained model (e.g., ResNet18)
39
  - Remove the classification head
40
  - Set the model to evaluation mode
 
43
  Returns:
44
  torch.nn.Module: The prepared model
45
  """
46
+ # Ensure cache directory exists
47
+ cache_dir = os.getenv('TORCH_HOME', os.path.expanduser('~/.cache/torch'))
48
+ os.makedirs(cache_dir, exist_ok=True)
49
+ transformers_logger.info(f"Using cache directory: {cache_dir}")
50
 
51
  model = ViTModel.from_pretrained('imjeffhi/pokemon_classifier').to(device)
52
 
 
58
 
59
 
60
  # --- Image Preprocessing ---
 
61
  # Hint: Consider resizing, normalization, and tensor conversion
62
  transform = ViTImageProcessor.from_pretrained(get_model().name_or_path)
63
 
 
68
  class PokemonSimilarity:
69
  def __init__(self, suppress_init_logs: bool = True) -> None:
70
  """
 
71
  - Load the model
72
  - Load the database of Pokemon embeddings
73
  """
 
93
 
94
  def _load_db(self) -> dict | None:
95
  """
 
96
  - Look for the embeddings file in different possible locations
97
  - Load the pickle file containing Pokemon embeddings
98
  - Handle cases where the file is not found