# Source: camie_tagger ONNX inference example (file size ~3,803 bytes, commit ca9b012)
import onnxruntime as ort
import numpy as np
import json
from PIL import Image
# 1) Load ONNX model
# Module-level session: created once at import time and shared by inference().
# CPUExecutionProvider keeps this runnable without GPU/CUDA builds of onnxruntime.
# NOTE(review): assumes "camie_tagger_initial.onnx" is in the working directory — confirm.
session = ort.InferenceSession("camie_tagger_initial.onnx", providers=["CPUExecutionProvider"])
# 2) Preprocess your image (512x512 by default)
def preprocess_image(img_path, target_size=(512, 512)):
    """
    Load an image and convert it to the model's input tensor.

    Parameters
    ----------
    img_path : str or Path
        Path to an image file in any format PIL can open.
    target_size : tuple[int, int], optional
        (width, height) to resize to. Defaults to (512, 512), the size
        this model was exported with; parameterized so the same helper
        works for other input resolutions.

    Returns
    -------
    numpy.ndarray
        Float32 array of shape (1, 3, H, W) (NCHW) with values in [0, 1].
    """
    img = Image.open(img_path).convert("RGB").resize(target_size)
    x = np.asarray(img, dtype=np.float32) / 255.0  # HWC, scaled to [0..1]
    x = np.transpose(x, (2, 0, 1))                 # HWC -> CHW
    return np.expand_dims(x, 0)                    # add batch dim -> (1, 3, H, W)
# Example input
def _load_metadata(path="metadata.json"):
    """Load tag metadata: idx->tag map, tag->category map, per-category thresholds."""
    with open(path, "r", encoding="utf-8") as f:
        metadata = json.load(f)
    # e.g. { "0": "brown_hair", "1": "blue_eyes", ... } — keys are *strings*.
    idx_to_tag = metadata["idx_to_tag"]
    tag_to_category = metadata.get("tag_to_category", {})
    # Fallback thresholds mirror the original hard-coded defaults.
    category_thresholds = metadata.get(
        "category_thresholds",
        {"artist": 0.1, "character": 0.2, "meta": 0.3, "style": 0.1},
    )
    return idx_to_tag, tag_to_category, category_thresholds


def _collect_predictions(probs, idx_to_tag, tag_to_category,
                         category_thresholds, default_threshold=0.325):
    """
    Group predicted tags by category.

    Parameters
    ----------
    probs : 1-D sequence of float
        Per-tag sigmoid probabilities (one entry per tag index).
    default_threshold : float
        Used for any category with no entry in `category_thresholds`.

    Returns
    -------
    dict[str, list[tuple[str, float]]]
        category -> [(tag_name, probability), ...] for tags at/above threshold.
    """
    results_by_category = {}
    for i, prob in enumerate(probs):
        prob = float(prob)
        tag_name = idx_to_tag[str(i)]  # metadata uses string keys
        category = tag_to_category.get(tag_name, "unknown")
        if prob >= category_thresholds.get(category, default_threshold):
            results_by_category.setdefault(category, []).append((tag_name, prob))
    return results_by_category


def _format_as_prompt(results_by_category):
    """Flatten all predicted tags into a comma-separated prompt string (underscores -> spaces)."""
    all_predicted_tags = []
    for tags_list in results_by_category.values():
        for tname, _tprob in tags_list:
            all_predicted_tags.append(tname.replace("_", " "))
    return ", ".join(all_predicted_tags)


def _format_verbose(results_by_category):
    """Build a multiline per-category report, tags sorted by descending probability."""
    lines = ["Predicted Tags by Category:\n"]
    for cat, tags_list in results_by_category.items():
        lines.append(f"Category: {cat} | Predicted {len(tags_list)} tags")
        for tname, tprob in sorted(tags_list, key=lambda x: x[1], reverse=True):
            lines.append(f"  Tag: {tname:30s} Prob: {tprob:.4f}")
        lines.append("")  # blank line after each category
    return "\n".join(lines)


def inference(input_path, output_format="verbose"):
    """
    Run the tagger on one image.

    Parameters
    ----------
    input_path : str or Path
        Path to the image to tag.
    output_format : str
        "as_prompt" -> comma-separated tag string (underscores replaced
        with spaces); anything else -> verbose per-category breakdown.

    Returns
    -------
    str
        The formatted prediction string.
    """
    # 1) Preprocess
    input_tensor = preprocess_image(input_path)
    # 2) Run inference (model emits two logit heads; we use the refined one)
    input_name = session.get_inputs()[0].name
    initial_logits, refined_logits = session.run(None, {input_name: input_tensor})
    # 3) Sigmoid: logits -> per-tag probabilities, shape (1, num_tags)
    refined_probs = 1 / (1 + np.exp(-refined_logits))
    # 4) Load metadata & threshold info
    idx_to_tag, tag_to_category, category_thresholds = _load_metadata()
    # 5) Collect predictions by category
    results_by_category = _collect_predictions(
        refined_probs[0], idx_to_tag, tag_to_category, category_thresholds
    )
    # 6) Format per requested output style
    if output_format == "as_prompt":
        return _format_as_prompt(results_by_category)
    return _format_verbose(results_by_category)
if __name__ == "__main__":
    # Demo entry point: tag a sample image and print the prompt-style tag list.
    print(inference("path/to/image", output_format="as_prompt"))