import onnxruntime as ort
import numpy as np
import json
from PIL import Image

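# Load the exported ONNX model once; the CPU session below is reused for every inference call.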
session = ort.InferenceSession("camie_tagger_initial.onnx", providers=["CPUExecutionProvider"])


def preprocess_image(img_path):
    """
    Loads and resizes an image to 512x512, converts it to float32 [0..1],
    and returns a (1, 3, 512, 512) NumPy array (NCHW format).
    """
    img = Image.open(img_path).convert("RGB").resize((512, 512))
    x = np.array(img).astype(np.float32) / 255.0
    x = np.transpose(x, (2, 0, 1))  # HWC -> CHW
    x = np.expand_dims(x, 0)        # add the batch dimension
    return x


def inference(input_path, output_format="verbose"):
    """
    Runs the tagger on a single image and returns either:
      - a verbose category breakdown ("verbose"), or
      - a comma-separated string of predicted tags with underscores
        replaced by spaces ("as_prompt").
    """
    input_tensor = preprocess_image(input_path)

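    # The model takes one image tensor and returns two outputs: initial and refined logits.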
    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: input_tensor})
    initial_logits, refined_logits = outputs

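    # Sigmoid turns each refined logit into an independent per-tag probability in [0, 1].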
    refined_probs = 1 / (1 + np.exp(-refined_logits))

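    # metadata.json maps output indices to tag names, tags to categories,
    # and (optionally) categories to their own probability thresholds.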
    with open("metadata.json", "r", encoding="utf-8") as f:
        metadata = json.load(f)

    idx_to_tag = metadata["idx_to_tag"]
    tag_to_category = metadata.get("tag_to_category", {})
    category_thresholds = metadata.get(
        "category_thresholds",
        {"artist": 0.1, "character": 0.2, "meta": 0.3, "style": 0.1}
    )
    default_threshold = 0.325

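    # Keep every tag whose probability clears its category threshold
    # (categories without an explicit threshold fall back to the default).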
    results_by_category = {}
    num_tags = refined_probs.shape[1]

    for i in range(num_tags):
        prob = float(refined_probs[0, i])
        tag_name = idx_to_tag[str(i)]
        category = tag_to_category.get(tag_name, "unknown")
        cat_threshold = category_thresholds.get(category, default_threshold)

        if prob >= cat_threshold:
            if category not in results_by_category:
                results_by_category[category] = []
            results_by_category[category].append((tag_name, prob))

    if output_format == "as_prompt":
        # Flatten all categories into one comma-separated prompt string.
        all_predicted_tags = []
        for cat, tags_list in results_by_category.items():
            for tname, tprob in tags_list:
                tag_name_spaces = tname.replace("_", " ")
                all_predicted_tags.append(tag_name_spaces)

        prompt_string = ", ".join(all_predicted_tags)
        return prompt_string

    else:
        # Verbose output: list tags per category, sorted by descending probability.
        lines = []
        lines.append("Predicted Tags by Category:\n")
        for cat, tags_list in results_by_category.items():
            lines.append(f"Category: {cat} | Predicted {len(tags_list)} tags")

            for tname, tprob in sorted(tags_list, key=lambda x: x[1], reverse=True):
                lines.append(f"  Tag: {tname:30s} Prob: {tprob:.4f}")
            lines.append("")

        verbose_output = "\n".join(lines)
        return verbose_output


if __name__ == "__main__":
    result = inference("path/to/image", output_format="as_prompt")
    print(result)