File size: 5,063 Bytes
bc492d2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import onnxruntime as ort
import numpy as np
import json
from PIL import Image
def preprocess_image(img_path, target_size=512, keep_aspect=True):
    """Load an image and convert it into a normalized float32 NCHW batch of one.

    The image at ``img_path`` is converted to RGB and brought to
    ``(target_size, target_size)``:

    * ``keep_aspect=True``: resized (Lanczos) so its longer side equals
      ``target_size``, then centered on a black square canvas.
    * ``keep_aspect=False``: stretched directly to the square size (Lanczos).

    Pixel values are scaled from [0, 255] to [0, 1].

    Args:
        img_path: Path to the image file to load.
        target_size: Side length, in pixels, of the square output.
        keep_aspect: Pad to a square (True) instead of stretching (False).

    Returns:
        A float32 ``np.ndarray`` of shape ``(1, 3, target_size, target_size)``.
    """
    # Use a context manager so the file handle opened by Image.open is released
    # once convert() has decoded the pixels into an independent image.
    with Image.open(img_path) as src:
        img = src.convert("RGB")

    if keep_aspect:
        # Preserve aspect ratio, pad with black.
        w, h = img.size
        aspect = w / h
        if aspect > 1:
            # Landscape: the width fills the canvas.
            new_w = target_size
            new_h = int(new_w / aspect)
        else:
            # Portrait or square: the height fills the canvas.
            new_h = target_size
            new_w = int(new_h * aspect)
        # Very thin images would otherwise truncate to a 0-pixel side,
        # which PIL's resize rejects.
        new_w = max(1, new_w)
        new_h = max(1, new_h)
        img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)

        # Center the resized image on a black square canvas.
        background = Image.new("RGB", (target_size, target_size), (0, 0, 0))
        paste_x = (target_size - new_w) // 2
        paste_y = (target_size - new_h) // 2
        background.paste(img, (paste_x, paste_y))
        img = background
    else:
        # Simple direct resize to (target_size, target_size).
        img = img.resize((target_size, target_size), Image.Resampling.LANCZOS)

    # HWC uint8 -> float32 scaled to [0, 1].
    arr = np.array(img).astype("float32") / 255.0
    # HWC -> CHW, then add the batch dimension: (1, 3, target_size, target_size).
    arr = np.transpose(arr, (2, 0, 1))
    arr = np.expand_dims(arr, axis=0)
    return arr
def onnx_inference(img_paths,
                   onnx_path="camie_refined_no_flash.onnx",
                   threshold=0.325,
                   metadata_file="metadata.json",
                   providers=None):
    """Run the exported ONNX tagger on a list of images and threshold the output.

    Each image is preprocessed with ``preprocess_image`` (512x512, aspect
    preserved). All images are stacked into one batch and run through the
    model in a single call. A tag is predicted when the sigmoid of its
    *refined* logit is ``>= threshold``.

    Args:
        img_paths: List of paths to images.
        onnx_path: Path to the exported ONNX model file.
        threshold: Probability threshold for deciding if a tag is predicted.
        metadata_file: Path to metadata.json containing an ``"idx_to_tag"``
            mapping from stringified tag index to tag name.
        providers: Optional list of ONNX Runtime execution providers, e.g.
            ``["CUDAExecutionProvider", "CPUExecutionProvider"]``.
            Defaults to ``["CPUExecutionProvider"]``.

    Returns:
        A list with one dict per input image, in input order, each containing:
            "initial_logits": np.ndarray of shape (N_tags,),
            "refined_logits": np.ndarray of shape (N_tags,),
            "predicted_indices": np.ndarray of tag indices whose refined
                probability is >= threshold,
            "predicted_tags": list of tag names for those indices.
        An empty list if ``img_paths`` is empty.
    """
    # NOTE(review): the signature was truncated in this copy of the file; the
    # parameter names come from the docstring, the defaults are reconstructed
    # and should be confirmed against callers.

    # 1) Preprocess each image. Do this first so an empty input never pays
    #    for loading the model.
    batch_tensors = [
        preprocess_image(img_path, target_size=512, keep_aspect=True)
        for img_path in img_paths
    ]
    if not batch_tensors:
        # np.concatenate would raise on an empty list.
        return []
    # Concatenate along the batch dimension => (batch_size, 3, 512, 512).
    batch_input = np.concatenate(batch_tensors, axis=0)

    # 2) Load the index -> tag-name mapping.
    with open(metadata_file, "r", encoding="utf-8") as f:
        metadata = json.load(f)
    idx_to_tag = metadata["idx_to_tag"]  # e.g. {"0": "brown_hair", "1": "blue_eyes", ...}

    # 3) Initialize the ONNX Runtime session and run inference.
    if providers is None:
        providers = ["CPUExecutionProvider"]
    session = ort.InferenceSession(onnx_path, providers=providers)
    input_name = session.get_inputs()[0].name
    # Passing None as output names returns every model output.
    outputs = session.run(None, {input_name: batch_input})
    # The model is expected to return [initial_tags, refined_tags] logits,
    # each presumably of shape (batch_size, N_tags).
    initial_preds, refined_preds = outputs

    # 4) Sigmoid over the whole batch at once. For very negative logits
    #    exp(-x) overflows to inf, which still gives the correct limit 0.0,
    #    so the overflow warning is suppressed.
    with np.errstate(over="ignore"):
        refined_probs = 1.0 / (1.0 + np.exp(-refined_preds))

    # 5) Build one result per image.
    batch_results = []
    for init_logit, ref_logit, ref_prob in zip(initial_preds, refined_preds, refined_probs):
        pred_indices = np.where(ref_prob >= threshold)[0]
        batch_results.append({
            "initial_logits": init_logit,
            "refined_logits": ref_logit,
            "predicted_indices": pred_indices,
            # Metadata keys are stringified indices.
            "predicted_tags": [idx_to_tag[str(idx)] for idx in pred_indices],
        })
    return batch_results
if __name__ == "__main__":
    # Example usage: tag a few local images and print a short summary for each.
    images = ["image1.jpg", "image2.jpg", "image3.jpg"]
    # NOTE(review): the call arguments were truncated in this copy of the file;
    # the model path, threshold and metadata path below are placeholders —
    # point them at your exported model and its metadata.
    results = onnx_inference(
        images,
        onnx_path="camie_refined_no_flash.onnx",
        threshold=0.325,
        metadata_file="metadata.json",
    )
    for img_path, res in zip(images, results):
        print(f"Image: {img_path}")
        print(f" # of predicted tags above threshold: {len(res['predicted_indices'])}")
        print(f" Some predicted tags: {res['predicted_tags'][:10]} (Show up to 10)")