AngelBottomless commited on
Commit
648fb46
·
verified ·
1 Parent(s): 7111b78

Upload 3 files

Browse files
Files changed (3) hide show
  1. camie_tagger_initial.onnx +3 -0
  2. infer.py +80 -0
  3. metadata.json +0 -0
camie_tagger_initial.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1ed6e97bf389857516873416affedc0572bb15c1a39531db2e8f92dfd5abdf0d
3
+ size 855879045
infer.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import onnxruntime as ort
2
+ import numpy as np
3
+ import json
4
+ from PIL import Image
5
+
6
+ # 1) Load ONNX model
7
+ session = ort.InferenceSession("camie_tagger_initial.onnx", providers=["CPUExecutionProvider"])
8
+
9
+ # 2) Preprocess your image (512x512, etc.)
10
+ def preprocess_image(img_path):
11
+ """
12
+ Loads and resizes an image to 512x512, converts it to float32 [0..1],
13
+ and returns a (1,3,512,512) NumPy array (NCHW format).
14
+ """
15
+ img = Image.open(img_path).convert("RGB").resize((512, 512))
16
+ x = np.array(img).astype(np.float32) / 255.0
17
+ x = np.transpose(x, (2, 0, 1)) # HWC -> CHW
18
+ x = np.expand_dims(x, 0) # add batch dimension -> (1,3,512,512)
19
+ return x
20
+
21
+ # Example input
22
+ def inference(input_path):
23
+ input_tensor = preprocess_image(input_path)
24
+
25
+ # 3) Run inference
26
+ input_name = session.get_inputs()[0].name
27
+ outputs = session.run(None, {input_name: input_tensor})
28
+ initial_logits, refined_logits = outputs # shape: (1, 70527) each
29
+
30
+ # 4) Convert logits to probabilities via sigmoid
31
+ refined_probs = 1 / (1 + np.exp(-refined_logits)) # shape: (1, 70527)
32
+
33
+ # 5) Load metadata & retrieve threshold info
34
+ with open("metadata.json", "r", encoding="utf-8") as f:
35
+ metadata = json.load(f)
36
+
37
+ # Dictionary of idx->tag_name, e.g. { "0": "brown_hair", "1": "blue_eyes", ... }
38
+ idx_to_tag = metadata["idx_to_tag"]
39
+
40
+ # Dictionary of tag->category, e.g. { "brown_hair": "character", "landscape": "general", ... }
41
+ tag_to_category = metadata.get("tag_to_category", {})
42
+
43
+ # Dictionary of category->threshold, e.g. { "character": 0.30, "general": 0.325, ... }
44
+ # If not present or incomplete, we'll use a default threshold of 0.325
45
+ category_thresholds = metadata.get("category_thresholds", {})
46
+ default_threshold = 0.325
47
+
48
+ # 6) Collect predictions by category
49
+ # We'll loop through all tags and check if the probability is above the category-specific threshold
50
+ results_by_category = {}
51
+
52
+ num_tags = refined_probs.shape[1] # 70527
53
+ for i in range(num_tags):
54
+ prob = float(refined_probs[0, i]) # get probability for this tag
55
+ tag_name = idx_to_tag[str(i)] # convert index -> tag name (keys in idx_to_tag are strings)
56
+
57
+ # Find category; if not in 'tag_to_category', label it "unknown"
58
+ category = tag_to_category.get(tag_name, "unknown")
59
+
60
+ # Find threshold for this category; fallback to default
61
+ cat_threshold = category_thresholds.get(category, default_threshold)
62
+
63
+ # Check if prob meets or exceeds the threshold
64
+ if prob >= cat_threshold:
65
+ if category not in results_by_category:
66
+ results_by_category[category] = []
67
+ # Store the tag name + its probability
68
+ results_by_category[category].append((tag_name, prob))
69
+
70
+ # 7) Print out the predicted tags category-wise
71
+ print("Predicted Tags by Category:\n")
72
+
73
+ for cat, tags_list in results_by_category.items():
74
+ print(f"Category: {cat} | Predicted {len(tags_list)} tags")
75
+ for tname, tprob in sorted(tags_list, key=lambda x: x[1], reverse=True):
76
+ print(f" Tag: {tname:30s} Prob: {tprob:.4f}")
77
+ print()
78
+
79
+ if __name__ == "__main__":
80
+ inference("example_image.jpg")
metadata.json ADDED
The diff for this file is too large to render. See raw diff