import json

import gradio as gr
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from PIL import Image

# Load the model and metadata once at startup (same as before).
MODEL_REPO = "AngelBottomless/camie-tagger-onnxruntime"
MODEL_FILE = "camie_tagger_initial.onnx"
META_FILE = "metadata.json"

model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE, cache_dir=".")
meta_path = hf_hub_download(repo_id=MODEL_REPO, filename=META_FILE, cache_dir=".")
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
with open(meta_path, "r", encoding="utf-8") as f:
    metadata = json.load(f)

# Preprocessing function (same as before).
def preprocess_image(pil_image: Image.Image) -> np.ndarray:
    img = pil_image.convert("RGB").resize((512, 512))
    arr = np.array(img).astype(np.float32) / 255.0  # scale pixels to [0, 1]
    arr = np.transpose(arr, (2, 0, 1))              # HWC -> CHW
    arr = np.expand_dims(arr, 0)                    # add batch dimension
    return arr

# Inference function with an output-format option.
def tag_image(pil_image: Image.Image, output_format: str) -> str:
    # Run the model (the initial head is unused here; we score the refined logits).
    input_tensor = preprocess_image(pil_image)
    input_name = session.get_inputs()[0].name
    initial_logits, refined_logits = session.run(None, {input_name: input_tensor})
    probs = 1 / (1 + np.exp(-refined_logits))  # sigmoid over the refined logits
    probs = probs[0]  # drop the batch dimension

    idx_to_tag = metadata["idx_to_tag"]
    tag_to_category = metadata.get("tag_to_category", {})
    category_thresholds = metadata.get("category_thresholds", {})
    default_threshold = 0.325

    results_by_cat = {}  # tags grouped by category, for the detailed output
    scored_tags = []     # (display tag, prob) pairs, for the prompt-style output

    # Collect every tag whose probability clears its category threshold.
    for idx, prob in enumerate(probs):
        tag = idx_to_tag[str(idx)]
        cat = tag_to_category.get(tag, "unknown")
        thresh = category_thresholds.get(cat, default_threshold)
        if float(prob) >= thresh:
            results_by_cat.setdefault(cat, []).append((tag, float(prob)))
            scored_tags.append((tag.replace("_", " "), float(prob)))

    if output_format == "Prompt-style Tags":
        if not scored_tags:
            return "No tags predicted."
        # Sort by probability, descending, so the most relevant tags lead the prompt.
        scored_tags.sort(key=lambda x: x[1], reverse=True)
        return ", ".join(tag for tag, _ in scored_tags)
    else:  # Detailed output
        if not results_by_cat:
            return "No tags predicted for this image."
        lines = ["**Predicted Tags by Category:**  \n"]  # two trailing spaces = Markdown line break
        for cat, tag_list in results_by_cat.items():
            # Sort tags within this category by probability, descending.
            tag_list.sort(key=lambda x: x[1], reverse=True)
            lines.append(f"**Category: {cat}** – {len(tag_list)} tags")
            for tag, prob in tag_list:
                tag_pretty = tag.replace("_", " ")
                lines.append(f"- {tag_pretty} (Prob: {prob:.3f})")
            lines.append("")  # blank line between categories
        return "\n".join(lines)
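# Optional: a quick sanity check of tag_image outside the UI. This is a sketch
# that assumes a local test image named "example1.jpg" exists (the filename is
# a placeholder); uncomment to verify the ONNX session and thresholds before
# wiring up Gradio:
#
# _test_img = Image.open("example1.jpg")
# print(tag_image(_test_img, "Prompt-style Tags"))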
# Build the Gradio Blocks UI.
demo = gr.Blocks(theme=gr.themes.Soft())  # a built-in theme for nicer styling
with demo:
    # Header section.
    gr.Markdown(
        "# 🏷️ Camie Tagger – Anime Image Tagging\n"
        "This demo uses an ONNX export of Camie Tagger to label anime illustrations "
        "with tags. Upload an image and click **Tag Image** to see predictions."
    )
    gr.Markdown(
        "*(Note: the model predicts a large number of tags across categories such as "
        "character, general, and artist. You can choose a concise prompt-style output "
        "or a detailed category-wise breakdown.)*"
    )
    # Input/output section.
    with gr.Row():
        # Left column: image input and format selection.
        with gr.Column():
            image_in = gr.Image(type="pil", label="Input Image")
            format_choice = gr.Radio(
                choices=["Prompt-style Tags", "Detailed Output"],
                value="Prompt-style Tags",
                label="Output Format",
            )
            tag_button = gr.Button("🔍 Tag Image")
        # Right column: output display (Markdown supports bold, lists, etc.).
        with gr.Column():
            output_box = gr.Markdown("")
    # Example images. Both inputs are listed so that cached examples can call
    # tag_image with a complete argument set. (Ensure these files exist in the Space.)
    gr.Examples(
        examples=[
            ["example1.jpg", "Prompt-style Tags"],
            ["example2.png", "Detailed Output"],
        ],
        inputs=[image_in, format_choice],
        outputs=output_box,
        fn=tag_image,
        cache_examples=True,
    )
    # Wire the button click to the inference function.
    tag_button.click(fn=tag_image, inputs=[image_in, format_choice], outputs=output_box)
    # Footer / info.
    gr.Markdown(
        "----\n"
        "**Model:** [Camie Tagger ONNX](https://huggingface.co/AngelBottomless/camie-tagger-onnxruntime)"
        " • **Base Model:** Camais03/camie-tagger (61% F1 on 70k tags)"
        " • **ONNX Runtime:** for efficient CPU inference"
        " • *Demo built with Gradio Blocks.*"
    )

# Launch the app (handled automatically on Spaces).
demo.launch()
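# If you run this script outside Spaces and want a temporary public URL,
# Gradio's launch() accepts share=True (a standard flag, shown here only as
# an alternative to the default call above):
#
# demo.launch(share=True)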