AngelBottomless committed (verified)
Commit 7ec5b17 · Parent(s): 086766e

Thanks to GPT

Files changed (1):
  app.py +74 -42
app.py CHANGED
@@ -5,65 +5,97 @@ from PIL import Image
  import json
  from huggingface_hub import hf_hub_download

- # Load the ONNX model and metadata once at startup (optimizes performance)
  MODEL_REPO = "AngelBottomless/camie-tagger-onnxruntime"
- MODEL_FILE = "camie_tagger_initial.onnx"  # using the smaller initial model for speed
  META_FILE = "metadata.json"
-
- # Download model and metadata from HF Hub (cache_dir="." will cache in the Space)
  model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE, cache_dir=".")
  meta_path = hf_hub_download(repo_id=MODEL_REPO, filename=META_FILE, cache_dir=".")
  session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
  metadata = json.load(open(meta_path, "r", encoding="utf-8"))
-
- # Preprocessing: resize image to 512x512 and normalize to match training
  def preprocess_image(pil_image: Image.Image) -> np.ndarray:
      img = pil_image.convert("RGB").resize((512, 512))
-     arr = np.array(img).astype(np.float32) / 255.0  # scale pixel values to [0,1]
-     arr = np.transpose(arr, (2, 0, 1))  # HWC -> CHW
-     arr = np.expand_dims(arr, 0)  # add batch dimension -> (1,3,512,512)
      return arr

- # Inference: run the ONNX model and collect tags above threshold
- def predict_tags(pil_image: Image.Image) -> str:
-     # 1. Preprocess image to a numpy tensor
      input_tensor = preprocess_image(pil_image)
-     # 2. Run model (both initial and refined logits are output)
      input_name = session.get_inputs()[0].name
      initial_logits, refined_logits = session.run(None, {input_name: input_tensor})
-     # 3. Convert logits to probabilities (sigmoid, since this is multi-label)
-     probs = 1 / (1 + np.exp(-refined_logits))  # shape (1, 70527)
-     probs = probs[0]  # remove batch dim -> (70527,)
-     # 4. Thresholding: keep tags whose probability >= category threshold (or default)
-     idx_to_tag = metadata["idx_to_tag"]  # map index -> tag string
-     tag_to_category = metadata.get("tag_to_category", {})  # map tag -> category
-     category_thresholds = metadata.get("category_thresholds", {})  # category-specific thresholds
      default_threshold = 0.325
-     predicted_tags = []
      for idx, prob in enumerate(probs):
          tag = idx_to_tag[str(idx)]
          cat = tag_to_category.get(tag, "unknown")
-         threshold = category_thresholds.get(cat, default_threshold)
-         if prob >= threshold:
-             # Include this tag; replace underscores with spaces for readability
-             predicted_tags.append(tag.replace("_", " "))
-     # 5. Return tags as a comma-separated string, sorted alphabetically for consistency
-     if not predicted_tags:
-         return "No tags found."
-     predicted_tags.sort()
-     return ", ".join(predicted_tags)
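For a quick sanity check of the sigmoid-plus-threshold step that both the old and new functions share, here is a minimal standalone sketch; the logits, categories, and threshold values are made up for illustration and do not come from the real metadata:

    import numpy as np

    # Toy stand-in for refined_logits: one image, three tags (values are made up)
    logits = np.array([[2.0, -1.0, 0.5]])
    probs = (1 / (1 + np.exp(-logits)))[0]        # sigmoid -> approx [0.881, 0.269, 0.622]
    cats = ["character", "general", "general"]    # hypothetical per-tag categories
    category_thresholds = {"character": 0.5}      # hypothetical category-specific cutoff
    default_threshold = 0.325
    kept = [i for i, p in enumerate(probs)
            if p >= category_thresholds.get(cats[i], default_threshold)]
    print(kept)  # -> [0, 2]: tag 1 falls below the 0.325 default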

- # Create a simple Gradio interface
- demo = gr.Interface(
-     fn=predict_tags,
-     inputs=gr.Image(type="pil", label="Upload Image"),
-     outputs=gr.Textbox(label="Predicted Tags", lines=3),
-     title="Camie Tagger (ONNX) – Simple Demo",
-     description="Upload an anime/manga illustration to get relevant tags predicted by the Camie Tagger model.",
-     # You can optionally add example images if available in the Space directory:
-     examples=[["example1.jpg"], ["example2.png"]]  # (filenames should exist in the Space)
- )

- # Launch the app (on HF Spaces, calling demo.launch() is typically not required; the Space runs the app automatically)
  demo.launch()

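Both the old predict_tags and the new tag_image read the same three fields from metadata.json. A sketch of the shape the code expects; the tag names and threshold values below are illustrative, not taken from the real file:

    # Illustrative (made-up) shape of metadata.json as the code expects it:
    metadata_example = {
        # maps stringified tag index -> tag name (the code indexes with str(idx))
        "idx_to_tag": {"0": "1girl", "1": "short_hair"},
        # maps tag name -> category used for the threshold lookup
        "tag_to_category": {"1girl": "general", "short_hair": "general"},
        # per-category cutoffs; missing categories fall back to default_threshold
        "category_thresholds": {"character": 0.5, "general": 0.325},
    }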
  import json
  from huggingface_hub import hf_hub_download

+ # Load model and metadata at startup (same as before)
  MODEL_REPO = "AngelBottomless/camie-tagger-onnxruntime"
+ MODEL_FILE = "camie_tagger_initial.onnx"
  META_FILE = "metadata.json"
  model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE, cache_dir=".")
  meta_path = hf_hub_download(repo_id=MODEL_REPO, filename=META_FILE, cache_dir=".")
  session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
  metadata = json.load(open(meta_path, "r", encoding="utf-8"))
+ # Preprocessing function (same as before)
  def preprocess_image(pil_image: Image.Image) -> np.ndarray:
      img = pil_image.convert("RGB").resize((512, 512))
+     arr = np.array(img).astype(np.float32) / 255.0
+     arr = np.transpose(arr, (2, 0, 1))
+     arr = np.expand_dims(arr, 0)
      return arr

+ # Inference function with output format option
+ def tag_image(pil_image: Image.Image, output_format: str) -> str:
+     # Run model inference
      input_tensor = preprocess_image(pil_image)
      input_name = session.get_inputs()[0].name
      initial_logits, refined_logits = session.run(None, {input_name: input_tensor})
+     probs = 1 / (1 + np.exp(-refined_logits))
+     probs = probs[0]
+     idx_to_tag = metadata["idx_to_tag"]
+     tag_to_category = metadata.get("tag_to_category", {})
+     category_thresholds = metadata.get("category_thresholds", {})
      default_threshold = 0.325
+     results_by_cat = {}  # tags per category (for the detailed output)
+     prompt_tags = []     # pretty tag names (for the prompt-style output)
+     # Collect tags above their thresholds
      for idx, prob in enumerate(probs):
          tag = idx_to_tag[str(idx)]
          cat = tag_to_category.get(tag, "unknown")
+         thresh = category_thresholds.get(cat, default_threshold)
+         if float(prob) >= thresh:
+             # add to category dictionary
+             results_by_cat.setdefault(cat, []).append((tag, float(prob)))
+             # add to prompt list
+             prompt_tags.append(tag.replace("_", " "))
+     if output_format == "Prompt-style Tags":
+         if not prompt_tags:
+             return "No tags predicted."
+         # Join tags with commas, sorted by probability so the most relevant come first
+         prob_by_pretty = {tg.replace("_", " "): p
+                           for tags in results_by_cat.values() for (tg, p) in tags}
+         prompt_tags.sort(key=lambda t: prob_by_pretty[t], reverse=True)
+         return ", ".join(prompt_tags)
+     else:  # Detailed output
+         if not results_by_cat:
+             return "No tags predicted for this image."
+         lines = []
+         lines.append("**Predicted Tags by Category:**  \n")  # two trailing spaces force a Markdown line break
+         for cat, tag_list in results_by_cat.items():
+             # sort tags in this category by probability, descending
+             tag_list.sort(key=lambda x: x[1], reverse=True)
+             lines.append(f"**Category: {cat}** – {len(tag_list)} tags")
+             for tag, prob in tag_list:
+                 tag_pretty = tag.replace("_", " ")
+                 lines.append(f"- {tag_pretty} (Prob: {prob:.3f})")
+             lines.append("")  # blank line between categories
+         return "\n".join(lines)
+
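To exercise tag_image outside the UI, for example in a quick local test, something like the following would work; test.jpg is a hypothetical local file, not part of the Space:

    from PIL import Image

    img = Image.open("test.jpg")                # hypothetical local test image
    print(tag_image(img, "Prompt-style Tags"))  # comma-separated, highest probability first
    print(tag_image(img, "Detailed Output"))    # Markdown breakdown by category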
+ # Build the Gradio Blocks UI
+ demo = gr.Blocks(theme=gr.themes.Soft())  # using a built-in theme for nicer styling
+
+ with demo:
+     # Header section
+     gr.Markdown("# 🏷️ Camie Tagger – Anime Image Tagging\nThis demo uses an ONNX model of Camie Tagger to label anime illustrations with tags. Upload an image and click **Tag Image** to see predictions.")
+     gr.Markdown("*(Note: The model will predict a large number of tags across categories like character, general, artist, etc. You can choose a concise prompt-style output or a detailed category-wise breakdown.)*")
+     # Input/Output section
+     with gr.Row():
+         # Left column: image input and format selection
+         with gr.Column():
+             image_in = gr.Image(type="pil", label="Input Image")
+             format_choice = gr.Radio(choices=["Prompt-style Tags", "Detailed Output"], value="Prompt-style Tags", label="Output Format")
+             tag_button = gr.Button("🔍 Tag Image")
+         # Right column: output display
+         with gr.Column():
+             output_box = gr.Markdown("")  # will display the result in Markdown (supports bold, lists, etc.)
+     # Example images (if available in the repo)
+     gr.Examples(
+         examples=[["example1.jpg", "Prompt-style Tags"], ["example2.png", "Prompt-style Tags"]],  # example files must exist in the Space
+         inputs=[image_in, format_choice],  # must match tag_image's two inputs so cached examples can run
+         outputs=output_box,
+         fn=tag_image,
+         cache_examples=True
+     )
+     # Link the button click to the function
+     tag_button.click(fn=tag_image, inputs=[image_in, format_choice], outputs=output_box)
+     # Footer / info
+     gr.Markdown("----\n**Model:** [Camie Tagger ONNX](https://huggingface.co/AngelBottomless/camie-tagger-onnxruntime) • **Base Model:** Camais03/camie-tagger (61% F1 on 70k tags) • **ONNX Runtime:** for efficient CPU inference • *Demo built with Gradio Blocks.*")

+ # Launch the app (automatically handled in Spaces)
  demo.launch()
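On Spaces the trailing demo.launch() runs automatically with the file; for local development, guarding it is a common Gradio pattern, sketched here as an assumption rather than part of this commit:

    # Common local entry-point pattern (not part of this commit):
    if __name__ == "__main__":
        demo.launch()  # serves the UI at http://127.0.0.1:7860 by default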