mjwong committed on
Commit
1f094d9
·
verified ·
1 Parent(s): 8567ffe

Upload helpers.py

Browse files
Files changed (1) hide show
  1. helpers.py +61 -0
helpers.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Callable
2
+
3
+ import numpy as np
4
+ from PIL import Image
5
+ import torch
6
+
7
def l2_normalize(embedding: np.ndarray) -> np.ndarray:
    """Scale an embedding so that its L2 (Euclidean) norm equals one.

    Args:
        embedding (np.ndarray): Vector to be rescaled.

    Returns:
        np.ndarray: ``embedding`` divided by its L2 norm. When the norm is
        not strictly positive, the input array is handed back unchanged
        instead of being divided.
    """
    magnitude = np.linalg.norm(embedding)

    # Dividing by a zero norm would yield inf/nan entries, so only rescale
    # when the magnitude is strictly positive.
    if magnitude > 0:
        return embedding / magnitude
    return embedding
22
+
23
def encode_image(
    image: Image.Image,
    preprocess: Callable[[Image.Image], torch.Tensor],
    model: torch.nn.Module,
    device: torch.device,
) -> List[float]:
    """Turn a single image into an L2-normalized embedding.

    The image is run through ``preprocess``, given a leading batch
    dimension, transferred to ``device`` and passed to
    ``model.encode_image`` with gradient tracking disabled. The first row
    of the model output is then copied to the CPU, L2-normalized and
    converted to a plain Python list.

    Args:
        image (Image.Image): Image to embed.
        preprocess (Callable[[Image.Image], torch.Tensor]):
            Callable that converts the image into a tensor.
        model (torch.nn.Module): Model providing an ``encode_image``
            method that accepts the batched tensor.
        device (torch.device): Device the batched tensor is moved to
            before it is given to the model.

    Returns:
        List[float]: The normalized embedding as a list.
    """
    # The model is called on a batch, so wrap the single image in a batch
    # of size one before sending it to the target device.
    pixel_tensor = preprocess(image)
    batch = pixel_tensor.unsqueeze(0).to(device)

    # Inference only: no gradient bookkeeping is needed.
    with torch.no_grad():
        batch_features = model.encode_image(batch)

    # Take the single embedding out of the batch and bring it to the CPU
    # as a NumPy array so it can be normalized and serialized.
    raw_vector = batch_features[0].cpu().numpy()

    return l2_normalize(raw_vector).tolist()