import gradio as gr
import onnxruntime as ort
import numpy as np
from PIL import Image
import json
from huggingface_hub import hf_hub_download
# Load the ONNX model and tag metadata from the Hugging Face Hub at startup
MODEL_REPO = "AngelBottomless/camie-tagger-onnxruntime"
MODEL_FILE = "camie_tagger_initial.onnx"
META_FILE = "metadata.json"
model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE, cache_dir=".")
meta_path = hf_hub_download(repo_id=MODEL_REPO, filename=META_FILE, cache_dir=".")
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
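# Note: with the onnxruntime-gpu package installed, CUDA can be preferred and
# ONNX Runtime falls back through the provider list automatically, e.g.:
# session = ort.InferenceSession(model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])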
with open(meta_path, "r", encoding="utf-8") as f:
    metadata = json.load(f)
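# Optional startup sanity check (sketch): print the graph's I/O signature to
# confirm the expected single image input and the two logit outputs.
# for node in session.get_inputs() + session.get_outputs():
#     print(node.name, node.shape, node.type)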
# Preprocess a PIL image into the model's expected input tensor
def preprocess_image(pil_image: Image.Image) -> np.ndarray:
    # Resize to the model's fixed 512x512 input and scale pixels to [0, 1]
    img = pil_image.convert("RGB").resize((512, 512))
    arr = np.array(img).astype(np.float32) / 255.0
    # HWC -> CHW, then add a batch dimension: final shape (1, 3, 512, 512)
    arr = np.transpose(arr, (2, 0, 1))
    arr = np.expand_dims(arr, 0)
    return arr
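# Example (sketch, assuming some local file "sample.jpg" exists):
#   batch = preprocess_image(Image.open("sample.jpg"))
#   assert batch.shape == (1, 3, 512, 512) and batch.dtype == np.float32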
# Inference function with output format option
def tag_image(pil_image: Image.Image, output_format: str) -> str:
# Run model inference
input_tensor = preprocess_image(pil_image)
input_name = session.get_inputs()[0].name
initial_logits, refined_logits = session.run(None, {input_name: input_tensor})
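    # Sigmoid turns the refined logits into independent per-tag probabilities
    # (multi-label classification: many tags can be active at once)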
probs = 1 / (1 + np.exp(-refined_logits))
probs = probs[0]
idx_to_tag = metadata["idx_to_tag"]
tag_to_category = metadata.get("tag_to_category", {})
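    # Categories may carry their own confidence thresholds in the metadata;
    # tags in categories without one fall back to the default below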
category_thresholds = metadata.get("category_thresholds", {})
default_threshold = 0.325
    results_by_cat = {}  # tags grouped by category (for the detailed output)
    prompt_tags = []     # (tag, probability) pairs (for the prompt-style output)
# Collect tags above thresholds
for idx, prob in enumerate(probs):
tag = idx_to_tag[str(idx)]
cat = tag_to_category.get(tag, "unknown")
thresh = category_thresholds.get(cat, default_threshold)
if float(prob) >= thresh:
# add to category dictionary
results_by_cat.setdefault(cat, []).append((tag, float(prob)))
            # add to prompt list, keeping the probability for sorting
            prompt_tags.append((tag.replace("_", " "), float(prob)))
if output_format == "Prompt-style Tags":
if not prompt_tags:
return "No tags predicted."
        # Sort by probability (descending) so the most confident tags come first
        prompt_tags.sort(key=lambda x: x[1], reverse=True)
        return ", ".join(t for t, _ in prompt_tags)
else: # Detailed output
if not results_by_cat:
return "No tags predicted for this image."
lines = []
lines.append("**Predicted Tags by Category:** \n") # (Markdown newline: two spaces + newline)
        for cat, tag_list in sorted(results_by_cat.items()):  # stable category order
# sort tags in this category by probability descending
tag_list.sort(key=lambda x: x[1], reverse=True)
lines.append(f"**Category: {cat}** – {len(tag_list)} tags")
for tag, prob in tag_list:
tag_pretty = tag.replace("_", " ")
lines.append(f"- {tag_pretty} (Prob: {prob:.3f})")
lines.append("") # blank line between categories
return "\n".join(lines)
# Build the Gradio Blocks UI
with gr.Blocks(theme=gr.themes.Soft()) as demo:  # built-in theme for nicer styling
# Header Section
gr.Markdown("# 🏷️ Camie Tagger – Anime Image Tagging\nThis demo uses an ONNX model of Camie Tagger to label anime illustrations with tags. Upload an image and click **Tag Image** to see predictions.")
gr.Markdown("*(Note: The model will predict a large number of tags across categories like character, general, artist, etc. You can choose a concise prompt-style output or a detailed category-wise breakdown.)*")
# Input/Output Section
with gr.Row():
# Left column: Image input and format selection
with gr.Column():
image_in = gr.Image(type="pil", label="Input Image")
format_choice = gr.Radio(choices=["Prompt-style Tags", "Detailed Output"], value="Prompt-style Tags", label="Output Format")
tag_button = gr.Button("πŸ” Tag Image")
# Right column: Output display
with gr.Column():
output_box = gr.Markdown("") # will display the result in Markdown (supports bold, lists, etc.)
    # Example images (tag_image takes two inputs, so each example supplies both
    # the file and an output format; ensure these files exist in the Space)
    gr.Examples(
        examples=[["example1.jpg", "Prompt-style Tags"], ["example2.png", "Detailed Output"]],
        inputs=[image_in, format_choice],
        outputs=output_box,
        fn=tag_image,
        cache_examples=True
    )
# Link the button click to the function
tag_button.click(fn=tag_image, inputs=[image_in, format_choice], outputs=output_box)
# Footer/Info
gr.Markdown("----\n**Model:** [Camie Tagger ONNX](https://huggingface.co/AngelBottomless/camie-tagger-onnxruntime)   β€’   **Base Model:** Camais03/camie-tagger (61% F1 on 70k tags)   β€’   **ONNX Runtime:** for efficient CPU inference​:contentReference[oaicite:6]{index=6}   β€’   *Demo built with Gradio Blocks.*")
# Launch the app (automatically handled in Spaces)
demo.launch()