import onnxruntime as ort
import numpy as np
import json
from PIL import Image


def preprocess_image(img_path, target_size=512, keep_aspect=True):
    """
    Load an image and turn it into a model-ready input tensor.

    The image is converted to RGB and then either letterboxed (aspect ratio
    preserved, centred on a black square canvas) or stretched directly to
    (target_size, target_size). Pixel values are scaled to [0, 1].

    Args:
        img_path: Anything accepted by ``PIL.Image.open`` (path or file object).
        target_size: Side length, in pixels, of the square output image.
        keep_aspect: If True, resize so the longer side equals ``target_size``
            and pad the shorter side with black; otherwise resize directly.

    Returns:
        float32 numpy array of shape (1, 3, target_size, target_size).
    """
    # Context manager releases the file handle; Pillow keeps it open after
    # load() for multi-frame formats, so relying on GC would leak handles.
    with Image.open(img_path) as src:
        img = src.convert("RGB")

    if keep_aspect:
        # Preserve aspect ratio, pad black
        w, h = img.size
        aspect = w / h
        if aspect > 1:
            new_w = target_size
            # Clamp to 1: very wide images would otherwise truncate to a
            # zero height, which Image.resize rejects.
            new_h = max(1, int(new_w / aspect))
        else:
            new_h = target_size
            new_w = max(1, int(new_h * aspect))

        img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)

        # Centre the resized image on a black square canvas.
        background = Image.new("RGB", (target_size, target_size), (0, 0, 0))
        paste_x = (target_size - new_w) // 2
        paste_y = (target_size - new_h) // 2
        background.paste(img, (paste_x, paste_y))
        img = background
    else:
        # Direct resize to (target_size, target_size); aspect ratio not kept.
        img = img.resize((target_size, target_size), Image.Resampling.LANCZOS)

    arr = np.asarray(img, dtype=np.float32) / 255.0  # HWC, scaled to [0, 1]
    arr = np.transpose(arr, (2, 0, 1))  # HWC -> CHW
    return np.expand_dims(arr, axis=0)  # add batch dim -> (1, 3, H, W)


def _sigmoid(x):
    """Numerically stable logistic sigmoid.

    Uses sigmoid(x) == 0.5 * (1 + tanh(x / 2)), which cannot overflow,
    unlike 1 / (1 + exp(-x)) for large negative x.
    """
    return 0.5 * (1.0 + np.tanh(0.5 * x))


def onnx_inference(img_paths, onnx_path="camie_refined_no_flash.onnx", threshold=0.325,
                   metadata_file="metadata.json", providers=None, target_size=512):
    """
    Load the ONNX model, run it on a batch of images and threshold the results.

    All images are preprocessed with ``preprocess_image`` (aspect ratio kept)
    and run through the model as a single batch.

    Args:
        img_paths: Iterable of image paths.
        onnx_path: Path to the exported ONNX model file.
        threshold: Probability threshold; a tag is predicted when the sigmoid
            of its refined logit is >= this value.
        metadata_file: Path to a JSON file whose ``"idx_to_tag"`` entry maps
            stringified tag indices to tag names (e.g. ``{"0": "brown_hair"}``).
        providers: ONNX Runtime execution providers, e.g.
            ``["CUDAExecutionProvider", "CPUExecutionProvider"]``.
            Defaults to ``["CPUExecutionProvider"]``.
        target_size: Square input size passed to ``preprocess_image``.

    Returns:
        A list with one dict per input image, in input order::

            {
                "initial_logits": np.ndarray of shape (N_tags,),
                "refined_logits": np.ndarray of shape (N_tags,),
                "predicted_indices": np.ndarray of tag indices whose refined
                                     probability is >= threshold,
                "predicted_tags": list of tag names for predicted_indices,
            }

        An empty list is returned when ``img_paths`` is empty.

    Raises:
        ValueError: If the model does not return exactly two outputs.
        KeyError: If the metadata lacks ``"idx_to_tag"`` or a predicted index.
    """
    img_paths = list(img_paths)
    if not img_paths:
        # np.concatenate on an empty list raises an opaque error; nothing to do.
        return []

    # 1) Initialize ONNX runtime session
    if providers is None:
        providers = ["CPUExecutionProvider"]
    session = ort.InferenceSession(onnx_path, providers=providers)

    # 2) Pre-load metadata
    with open(metadata_file, "r", encoding="utf-8") as f:
        metadata = json.load(f)
    idx_to_tag = metadata["idx_to_tag"]

    # 3) Preprocess every image and stack along the batch dimension
    #    => shape (batch_size, 3, target_size, target_size)
    batch_input = np.concatenate(
        [preprocess_image(p, target_size=target_size, keep_aspect=True) for p in img_paths],
        axis=0,
    )

    # 4) Run inference on the model's first input
    input_name = session.get_inputs()[0].name
    # The model is expected to return [initial_tags, refined_tags],
    # presumably each of shape (batch_size, N_tags).
    initial_preds, refined_preds = session.run(None, {input_name: batch_input})

    # 5) Convert each sample's refined logits to thresholded predictions
    batch_results = []
    for init_logit, ref_logit in zip(initial_preds, refined_preds):
        ref_prob = _sigmoid(ref_logit)
        pred_indices = np.flatnonzero(ref_prob >= threshold)
        batch_results.append({
            "initial_logits": init_logit,
            "refined_logits": ref_logit,
            "predicted_indices": pred_indices,
            # metadata keys are strings, so stringify the numpy indices
            "predicted_tags": [idx_to_tag[str(idx)] for idx in pred_indices],
        })

    return batch_results


if __name__ == "__main__":
    # Example usage
    images = ["image1.jpg", "image2.jpg", "image3.jpg"]
    results = onnx_inference(images,
                             onnx_path="camie_refined_no_flash.onnx",
                             threshold=0.325,
                             metadata_file="metadata.json")

    for img_path, res in zip(images, results):
        print(f"Image: {img_path}")
        print(f" # of predicted tags above threshold: {len(res['predicted_indices'])}")
        print(f" Some predicted tags: {res['predicted_tags'][:10]} (Show up to 10)")
        print()