|
import gradio as gr |
|
import onnxruntime as ort |
|
import numpy as np |
|
from PIL import Image |
|
import json |
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
# Hugging Face Hub repository that hosts the ONNX export of the tagger.
MODEL_REPO: str = "AngelBottomless/camie-tagger-onnxruntime"
# ONNX graph file downloaded from MODEL_REPO.
MODEL_FILE: str = "camie_tagger_initial.onnx"
# JSON metadata downloaded from MODEL_REPO (tag vocabulary and optional
# per-category thresholds used by the prediction code).
META_FILE: str = "metadata.json"
|
|
|
|
|
# Fetch the ONNX model and its metadata JSON from the Hub (cache_dir="." keeps
# the downloads next to the script) and build a CPU-only inference session.
model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE, cache_dir=".")

meta_path = hf_hub_download(repo_id=MODEL_REPO, filename=META_FILE, cache_dir=".")

session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])

# Read the metadata inside a context manager so the file handle is closed
# deterministically instead of lingering until garbage collection.
with open(meta_path, "r", encoding="utf-8") as meta_file:
    metadata = json.load(meta_file)
|
|
|
|
|
def preprocess_image(pil_image: Image.Image) -> np.ndarray:
    """Turn a PIL image into the tensor fed to the ONNX session.

    The image is converted to RGB and resized to 512x512. Its 8-bit pixel
    values are scaled to [0, 1] as float32, and the channel axis is moved
    first (HWC -> CHW). A leading batch axis of size 1 is then added, giving
    shape (1, 3, 512, 512).
    """
    resized = pil_image.convert("RGB").resize((512, 512))
    scaled = np.asarray(resized, dtype=np.float32) / 255.0
    channels_first = scaled.transpose(2, 0, 1)
    return channels_first[np.newaxis, ...]
|
|
|
|
|
def predict_tags(pil_image: Image.Image) -> str:
    """Run the tagger on one image and return its tags as a single string.

    Probabilities are the sigmoid of the model's refined logits. A tag is kept
    when its probability reaches the threshold of its category, taken from
    ``metadata["category_thresholds"]``. Categories without an entry, and tags
    without a category, fall back to 0.325. Kept tags have underscores
    replaced by spaces and are sorted alphabetically.

    Args:
        pil_image: The uploaded image, or ``None`` when the input is empty.

    Returns:
        One of three strings:
        - the comma-separated tag list;
        - "No tags found." when nothing passes its threshold;
        - "Please upload an image." when no image was provided.
    """
    # Gradio hands us None when the user submits without uploading anything.
    if pil_image is None:
        return "Please upload an image."

    input_tensor = preprocess_image(pil_image)
    input_name = session.get_inputs()[0].name
    # The graph returns two outputs; only the refined logits drive tagging.
    _initial_logits, refined_logits = session.run(None, {input_name: input_tensor})

    # Sigmoid over the first batch item. Very negative logits make exp()
    # overflow to inf, which still gives the correct probability of 0, so the
    # overflow warning is silenced instead of being logged on every request.
    with np.errstate(over="ignore"):
        probs = 1.0 / (1.0 + np.exp(-refined_logits[0]))

    idx_to_tag = metadata["idx_to_tag"]
    tag_to_category = metadata.get("tag_to_category", {})
    category_thresholds = metadata.get("category_thresholds", {})
    default_threshold = 0.325

    # Every per-tag threshold is either default_threshold or a category
    # threshold. A probability below the smallest of these can never pass, so
    # pre-filtering with NumPy keeps the Python loop to plausible candidates.
    lowest_threshold = min([default_threshold, *category_thresholds.values()])
    candidate_indices = np.flatnonzero(probs >= lowest_threshold)

    predicted_tags = []
    for idx in candidate_indices:
        # idx_to_tag is keyed by the stringified index (JSON object keys).
        tag = idx_to_tag[str(idx)]
        category = tag_to_category.get(tag, "unknown")
        threshold = category_thresholds.get(category, default_threshold)
        if probs[idx] >= threshold:
            predicted_tags.append(tag.replace("_", " "))

    if not predicted_tags:
        return "No tags found."

    predicted_tags.sort()
    return ", ".join(predicted_tags)
|
|
|
|
|
# Gradio UI: one image input mapped to a comma-separated tag string.
demo = gr.Interface(
    fn=predict_tags,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Textbox(label="Predicted Tags", lines=3),
    title="Camie Tagger (ONNX) – Simple Demo",
    description="Upload an anime/manga illustration to get relevant tags predicted by the Camie Tagger model.",
    # NOTE(review): these example files are expected alongside this script — confirm they exist.
    examples=[["example1.jpg"], ["example2.png"]],
)

# Launch only when run as a script (`python app.py`). Importing the module,
# e.g. from Gradio's reload mode or from tests, then does not start a server.
if __name__ == "__main__":
    demo.launch()
|
|