File size: 5,063 Bytes
bc492d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import onnxruntime as ort
import numpy as np
import json
from PIL import Image

def preprocess_image(img_path, target_size=512, keep_aspect=True):
    """
    Load an image, convert it to RGB and turn it into a square model input.

    Args:
      img_path: Anything accepted by ``PIL.Image.open`` (path or file object).
      target_size: Side length, in pixels, of the square output.
      keep_aspect: If True, scale the image (Lanczos) so its longer side is
        ``target_size`` and centre it on a black ``target_size`` square
        (letterboxing). If False, stretch it directly to
        ``target_size x target_size``.

    Returns:
      A float32 array of shape (1, 3, target_size, target_size) with pixel
      values scaled to [0, 1].
    """
    # Open via a context manager so the underlying file handle is released;
    # convert() returns a new, fully loaded image that outlives the file.
    with Image.open(img_path) as src:
        img = src.convert("RGB")

    if keep_aspect:
        w, h = img.size
        # Exact integer floor division keeps the original truncating
        # semantics without float round-off (e.g. 383.999 -> 383), and
        # max(1, ...) stops extreme aspect ratios from producing a
        # zero-sized side, which resize() rejects.
        if w > h:
            new_w = target_size
            new_h = max(1, target_size * h // w)
        else:
            new_h = target_size
            new_w = max(1, target_size * w // h)

        img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
        # Centre the resized image on a black square canvas.
        canvas = Image.new("RGB", (target_size, target_size), (0, 0, 0))
        canvas.paste(img, ((target_size - new_w) // 2, (target_size - new_h) // 2))
        img = canvas
    else:
        # Direct stretch to target_size x target_size (aspect ratio not kept).
        img = img.resize((target_size, target_size), Image.Resampling.LANCZOS)

    # HWC uint8 -> HWC float32 scaled to [0, 1]
    arr = np.asarray(img, dtype=np.float32) / 255.0
    # HWC -> CHW, then add a leading batch axis: (1, 3, target_size, target_size)
    return np.expand_dims(arr.transpose(2, 0, 1), axis=0)

def onnx_inference(img_paths,
                   onnx_path="camie_refined_no_flash.onnx",
                   threshold=0.325,
                   metadata_file="metadata.json",
                   providers=None):
    """
    Run the ONNX tagger on a batch of images and threshold the refined logits.

    Args:
      img_paths: Iterable of image paths. All images are preprocessed to
        (3, 512, 512) and sent to the model as a single batch.
      onnx_path: Path to the exported ONNX model file.
      threshold: Probability threshold. A tag is predicted when
        sigmoid(refined_logit) >= threshold.
      metadata_file: Path to a JSON file whose "idx_to_tag" entry maps
        str(tag_index) to a tag name.
      providers: onnxruntime execution providers in priority order, e.g.
        ["CUDAExecutionProvider", "CPUExecutionProvider"]. Defaults to
        ["CPUExecutionProvider"].

    Returns:
      A list with one dict per input image, in input order:
        {
          "initial_logits": np.ndarray of shape (N_tags,),
          "refined_logits": np.ndarray of shape (N_tags,),
          "predicted_indices": np.ndarray of tag indices whose refined
            probability reached ``threshold``,
          "predicted_tags": list of the corresponding tag names,
        }
      If ``img_paths`` is empty, an empty list is returned and neither the
      model nor the metadata is loaded.
    """
    img_paths = list(img_paths)
    if not img_paths:
        # np.concatenate([]) would raise ValueError; there is nothing to run.
        return []

    # 1) Initialize the ONNX Runtime session
    if providers is None:
        providers = ["CPUExecutionProvider"]
    session = ort.InferenceSession(onnx_path, providers=providers)

    # 2) Load the index -> tag-name mapping (keys are looked up as str(index))
    with open(metadata_file, "r", encoding="utf-8") as f:
        idx_to_tag = json.load(f)["idx_to_tag"]

    # 3) Stack the per-image (1, 3, 512, 512) tensors into (batch_size, 3, 512, 512)
    batch_input = np.concatenate(
        [preprocess_image(p, target_size=512, keep_aspect=True) for p in img_paths],
        axis=0,
    )

    # 4) Run inference. The model is expected to return exactly two outputs:
    #    initial and refined tag logits, each (batch_size, N_tags).
    input_name = session.get_inputs()[0].name
    initial_preds, refined_preds = session.run(None, {input_name: batch_input})

    # 5) Numerically stable sigmoid for the whole batch. It is mathematically
    #    equal to 1 / (1 + exp(-x)), but exp() never overflows on large
    #    negative logits.
    refined_probs = 0.5 * (1.0 + np.tanh(0.5 * refined_preds))

    batch_results = []
    for init_logit, ref_logit, ref_prob in zip(initial_preds, refined_preds, refined_probs):
        pred_indices = np.flatnonzero(ref_prob >= threshold)
        batch_results.append({
            "initial_logits": init_logit,
            "refined_logits": ref_logit,
            "predicted_indices": pred_indices,
            "predicted_tags": [idx_to_tag[str(idx)] for idx in pred_indices],
        })

    return batch_results

if __name__ == "__main__":
    # Example usage: tag a few local images with the default model/metadata
    # and print a short summary for each one.
    images = ["image1.jpg", "image2.jpg", "image3.jpg"]
    results = onnx_inference(images,
                             onnx_path="camie_refined_no_flash.onnx",
                             threshold=0.325,
                             metadata_file="metadata.json")

    # Results come back in input order, so iterate the two lists in lockstep.
    for img_path, res in zip(images, results):
        print(f"Image: {img_path}")
        print(f"  # of predicted tags above threshold: {len(res['predicted_indices'])}")
        print(f"  Some predicted tags: {res['predicted_tags'][:10]}  (Show up to 10)")
        print()