Upload 9 files
- export.py +30 -0
- infer-refined.py +89 -35
- infer.py +139 -97
- model_code.py +956 -0
- model_config.json +9 -0
- model_info_initial_only.json +9 -0
- model_no_flash.py +195 -0
- thresholds.json +170 -0
export.py
ADDED
@@ -0,0 +1,30 @@
import torch
import torchvision.models as models
from model_code import InitialOnlyImageTagger  # Assume model_code.py classes are accessible
from safetensors.torch import load_file

# Load the trained weights (Initial-only model). Adjust path to your weights file.
# weights_path = "model_initial_only.pt"
safetensors_path = 'model_initial.safetensors'
state_dict = load_file(safetensors_path, device='cpu')
# state_dict = torch.load(weights_path, map_location="cpu")
# Instantiate the model with the same parameters as training
model = InitialOnlyImageTagger(total_tags=70527, dataset=None, pretrained=True)  # dataset not needed for forward
model.load_state_dict(state_dict)
model.eval()  # set to evaluation mode

# Define example input - a dummy image tensor of the expected input shape (1, 3, 512, 512)
dummy_input = torch.randn(1, 3, 512, 512, dtype=torch.float32)

# Export to ONNX
onnx_path = "camie_tagger_initial_v15.onnx"
torch.onnx.export(
    model, dummy_input, onnx_path,
    export_params=True,        # store the trained parameter weights in the model file
    opset_version=13,          # ONNX opset version (13 is widely supported)
    do_constant_folding=True,  # optimize constant expressions
    input_names=["input"],
    output_names=["initial_logits", "refined_logits"],  # model.forward returns two outputs (identical for InitialOnly)
    dynamic_axes={"input": {0: "batch_size"}}            # allow variable batch size
)
print(f"ONNX model saved to: {onnx_path}")
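A quick sanity check after exporting is to run the same dummy input through ONNX Runtime and compare against the PyTorch output. The snippet below is a minimal sketch, not part of this upload; it assumes the onnxruntime package is installed and reuses the model, dummy_input, and onnx_path defined above.

import numpy as np
import onnxruntime as ort

# Load the exported graph and run the dummy input through it
sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
onnx_outputs = sess.run(None, {"input": dummy_input.numpy()})
onnx_initial = onnx_outputs[0]  # first output: initial_logits

# Compare against the PyTorch forward pass (both outputs are identical for InitialOnly)
with torch.no_grad():
    torch_initial, _ = model(dummy_input)

print("max abs difference:", np.abs(onnx_initial - torch_initial.numpy()).max())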
infer-refined.py
CHANGED
@@ -42,73 +42,120 @@ def preprocess_image(img_path, target_size=512, keep_aspect=True):
     arr = np.expand_dims(arr, axis=0)
     return arr
 
+# Example input
+def load_thresholds(threshold_json_path, mode="balanced"):
+    """
+    Loads thresholds from the given JSON file, using a particular mode
+    (e.g. 'balanced', 'high_precision', 'high_recall') for each category.
+
+    Returns:
+        thresholds_by_category (dict): e.g. { "general": 0.328..., "character": 0.304..., ... }
+        fallback_threshold (float): The overall threshold if category not found
+    """
+    with open(threshold_json_path, "r", encoding="utf-8") as f:
+        data = json.load(f)
+
+    # The fallback threshold from the "overall" section for the chosen mode
+    fallback_threshold = data["overall"][mode]["threshold"]
+
+    # Build a dict of thresholds keyed by category
+    thresholds_by_category = {}
+    if "categories" in data:
+        for cat_name, cat_modes in data["categories"].items():
+            # If the chosen mode is present for that category, use it;
+            # otherwise fall back to the "overall" threshold.
+            if mode in cat_modes and "threshold" in cat_modes[mode]:
+                thresholds_by_category[cat_name] = cat_modes[mode]["threshold"]
+            else:
+                thresholds_by_category[cat_name] = fallback_threshold
+
+    return thresholds_by_category, fallback_threshold
+def onnx_inference(
+    img_paths,
+    onnx_path="camie_refined_no_flash.onnx",
+    metadata_file="metadata.json",
+    threshold_json_path="thresholds.json",
+    mode="balanced",
+    target_size=512,
+    keep_aspect=True
+):
     """
     Loads the ONNX model, runs inference on a list of image paths,
+    and applies category-wise thresholds from thresholds.json (per the chosen mode).
+
     Args:
+        img_paths           : List of paths to images.
+        onnx_path           : Path to the exported ONNX model file.
+        metadata_file       : Path to metadata.json that contains idx_to_tag, tag_to_category, etc.
+        threshold_json_path : Path to thresholds.json containing category-wise threshold info.
+        mode                : "balanced", "high_precision", or "high_recall".
+        target_size         : Final size of preprocessed images (512 by default).
+        keep_aspect         : If True, preserve aspect ratio when resizing, pad with black.
+
     Returns:
+        A list of dicts, one per input image, each containing:
         {
             "initial_logits": np.ndarray of shape (N_tags,),
             "refined_logits": np.ndarray of shape (N_tags,),
+            "predicted_indices": list of tag indices that exceeded threshold,
+            "predicted_tags": list of predicted tag strings,
             ...
         }
     """
     # 1) Initialize ONNX runtime session
     session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
+    # For GPU usage, you could do e.g.:
     # session = ort.InferenceSession(onnx_path, providers=["CUDAExecutionProvider"])
 
     # 2) Pre-load metadata
     with open(metadata_file, "r", encoding="utf-8") as f:
         metadata = json.load(f)
+    idx_to_tag = metadata["idx_to_tag"]  # e.g. { "0": "brown_hair", "1": "blue_eyes", ... }
+    tag_to_category = metadata.get("tag_to_category", {})
+
+    # Load thresholds from thresholds.json using the specified mode
+    thresholds_by_category, fallback_threshold = load_thresholds(threshold_json_path, mode)
 
     # 3) Preprocess each image into a batch
     batch_tensors = []
     for img_path in img_paths:
+        x = preprocess_image(img_path, target_size=target_size, keep_aspect=keep_aspect)
         batch_tensors.append(x)
+    # Concatenate along the batch dimension => shape (batch_size, 3, H, W)
     batch_input = np.concatenate(batch_tensors, axis=0)
 
     # 4) Run inference
+    input_name = session.get_inputs()[0].name  # typically "image" or "input"
     outputs = session.run(None, {input_name: batch_input})
     # Typically we get [initial_tags, refined_tags] as output
+    initial_preds, refined_preds = outputs  # shapes => (batch_size, N_tags)
 
+    # 5) Convert logits -> probabilities -> apply category-specific thresholds
     batch_results = []
     for i in range(initial_preds.shape[0]):
         init_logit = initial_preds[i, :]  # shape (N_tags,)
         ref_logit = refined_preds[i, :]   # shape (N_tags,)
+        ref_prob = 1.0 / (1.0 + np.exp(-ref_logit))  # shape (N_tags,)
+
+        predicted_indices = []
+        predicted_tags = []
 
+        # Check each tag against the category threshold
+        for idx in range(ref_logit.shape[0]):
+            tag_name = idx_to_tag[str(idx)]  # Convert index->string->tag name
+            category = tag_to_category.get(tag_name, "general")  # fallback to "general" if missing
+            cat_threshold = thresholds_by_category.get(category, fallback_threshold)
 
+            if ref_prob[idx] >= cat_threshold:
+                predicted_indices.append(idx)
+                predicted_tags.append(tag_name)
 
         # Build result for this image
         result_dict = {
             "initial_logits": init_logit,
             "refined_logits": ref_logit,
+            "predicted_indices": predicted_indices,
+            "predicted_tags": predicted_tags,
         }
         batch_results.append(result_dict)
 
@@ -116,14 +163,21 @@ def onnx_inference(img_paths,
 
 if __name__ == "__main__":
     # Example usage
+    images = ["images.png"]
+    results = onnx_inference(
+        img_paths=images,
+        onnx_path="camie_refined_no_flash_v15.onnx",
+        metadata_file="metadata.json",
+        threshold_json_path="thresholds.json",
+        mode="balanced",  # or "high_precision", "high_recall"
+        target_size=512,
+        keep_aspect=True
+    )
 
     for i, res in enumerate(results):
         print(f"Image: {images[i]}")
         print(f"  # of predicted tags above threshold: {len(res['predicted_indices'])}")
+        # Show first 10 predicted tags (if available)
+        sample_tags = res['predicted_tags'][:10]
+        print("  Sample predicted tags:", sample_tags)
+        print()
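For reference, load_thresholds above implies a particular layout for thresholds.json: an "overall" block plus an optional "categories" block, each keyed by mode with a "threshold" value. The uploaded thresholds.json itself is not reproduced here, so the snippet below is only an illustrative sketch of that assumed structure; the numeric values are placeholders.

example_thresholds = {
    "overall": {
        "balanced":       {"threshold": 0.325},
        "high_precision": {"threshold": 0.50},
        "high_recall":    {"threshold": 0.20},
    },
    "categories": {
        "general":   {"balanced": {"threshold": 0.33}},
        "character": {"balanced": {"threshold": 0.30}},
    },
}

# With this shape, load_thresholds("thresholds.json", mode="balanced") would return
# ({"general": 0.33, "character": 0.30}, 0.325)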
infer.py
CHANGED
@@ -1,98 +1,140 @@
+import onnxruntime as ort
+import numpy as np
+import json
+from PIL import Image
+
+# 1) Load ONNX model
+session = ort.InferenceSession("camie_tagger_initial_v15.onnx", providers=["CPUExecutionProvider"])
+
+# 2) Preprocess your image (512x512, etc.)
+def preprocess_image(img_path):
+    """
+    Loads and resizes an image to 512x512, converts it to float32 [0..1],
+    and returns a (1,3,512,512) NumPy array (NCHW format).
+    """
+    img = Image.open(img_path).convert("RGB").resize((512, 512))
+    x = np.array(img).astype(np.float32) / 255.0
+    x = np.transpose(x, (2, 0, 1))  # HWC -> CHW
+    x = np.expand_dims(x, 0)        # add batch dimension -> (1,3,512,512)
+    return x
+
+# Example input
+def load_thresholds(threshold_json_path, mode="balanced"):
+    """
+    Loads thresholds from the given JSON file, using a particular mode
+    (e.g. 'balanced', 'high_precision', 'high_recall') for each category.
+
+    Returns:
+        thresholds_by_category (dict): e.g. { "general": 0.328..., "character": 0.304..., ... }
+        fallback_threshold (float): The overall threshold if category not found
+    """
+    with open(threshold_json_path, "r", encoding="utf-8") as f:
+        data = json.load(f)
+
+    # The fallback threshold from the "overall" section for the chosen mode
+    fallback_threshold = data["overall"][mode]["threshold"]
+
+    # Build a dict of thresholds keyed by category
+    thresholds_by_category = {}
+    if "categories" in data:
+        for cat_name, cat_modes in data["categories"].items():
+            # If the chosen mode is present for that category, use it;
+            # otherwise fall back to the "overall" threshold.
+            if mode in cat_modes and "threshold" in cat_modes[mode]:
+                thresholds_by_category[cat_name] = cat_modes[mode]["threshold"]
+            else:
+                thresholds_by_category[cat_name] = fallback_threshold
+
+    return thresholds_by_category, fallback_threshold
+
+def inference(
+    input_path,
+    output_format="verbose",
+    mode="balanced",
+    threshold_json_path="thresholds.json",
+    metadata_path="metadata.json"
+):
+    """
+    Run inference on an image using the loaded ONNX model, then apply
+    category-wise thresholds from `thresholds.json` for the chosen mode.
+
+    Arguments:
+        input_path (str)          : Path to the image file for inference.
+        output_format (str)       : Either "verbose" or "as_prompt".
+        mode (str)                : "balanced", "high_precision", or "high_recall"
+        threshold_json_path (str) : Path to the JSON file with category thresholds.
+        metadata_path (str)       : Path to the metadata JSON file with category info.
+
+    Returns:
+        str: The predicted tags in either verbose or comma-separated format.
+    """
+    # 1) Preprocess
+    input_tensor = preprocess_image(input_path)
+
+    # 2) Run inference
+    input_name = session.get_inputs()[0].name
+    outputs = session.run(None, {input_name: input_tensor})
+    initial_logits, refined_logits = outputs  # shape: (1, 70527) each
+
+    # 3) Convert logits to probabilities
+    refined_probs = 1 / (1 + np.exp(-refined_logits))  # shape: (1, 70527)
+
+    # 4) Load metadata & retrieve threshold info
+    with open(metadata_path, "r", encoding="utf-8") as f:
+        metadata = json.load(f)
+
+    idx_to_tag = metadata["idx_to_tag"]  # e.g. { "0": "brown_hair", "1": "blue_eyes", ... }
+    tag_to_category = metadata.get("tag_to_category", {})
+    # Load thresholds from thresholds.json using the specified mode
+    thresholds_by_category, fallback_threshold = load_thresholds(threshold_json_path, mode)
+
+    # 5) Collect predictions by category
+    results_by_category = {}
+    num_tags = refined_probs.shape[1]
+
+    for i in range(num_tags):
+        prob = float(refined_probs[0, i])
+        tag_name = idx_to_tag[str(i)]  # str(i) because metadata uses string keys
+        category = tag_to_category.get(tag_name, "general")
+
+        # Determine the threshold to use for this category
+        cat_threshold = thresholds_by_category.get(category, fallback_threshold)
+
+        if prob >= cat_threshold:
+            if category not in results_by_category:
+                results_by_category[category] = []
+            results_by_category[category].append((tag_name, prob))
+
+    # 6) Depending on output_format, produce different return strings
+    if output_format == "as_prompt":
+        # Flatten all predicted tags across categories
+        all_predicted_tags = []
+        for cat, tags_list in results_by_category.items():
+            # We only need the tag name in as_prompt format
+            for tname, tprob in tags_list:
+                # convert underscores to spaces
+                tag_name_spaces = tname.replace("_", " ")
+                all_predicted_tags.append(tag_name_spaces)
+
+        # Create a comma-separated string
+        prompt_string = ", ".join(all_predicted_tags)
+        return prompt_string
+
+    else:  # "verbose"
+        # We'll build a multiline string describing the predictions
+        lines = []
+        lines.append("Predicted Tags by Category:\n")
+        for cat, tags_list in results_by_category.items():
+            lines.append(f"Category: {cat} | Predicted {len(tags_list)} tags")
+            # Sort descending by probability
+            for tname, tprob in sorted(tags_list, key=lambda x: x[1], reverse=True):
+                lines.append(f"  Tag: {tname:30s} Prob: {tprob:.4f}")
+            lines.append("")  # blank line after each category
+        # Join lines with newlines
+        verbose_output = "\n".join(lines)
+        return verbose_output
+
+if __name__ == "__main__":
+    result = inference("", output_format="as_prompt")
     print(result)
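Both inference scripts read metadata.json for the tag vocabulary. From the way the code accesses it, the file is expected to map string indices to tag names (idx_to_tag) and tag names to categories (tag_to_category). The entries below are only placeholder examples of that assumed shape, not the real contents of the uploaded metadata.

example_metadata = {
    "idx_to_tag": {
        "0": "brown_hair",
        "1": "blue_eyes",
    },
    "tag_to_category": {
        "brown_hair": "general",
        "blue_eyes": "general",
    },
}

# Given such a metadata file, inference("some_image.png", output_format="as_prompt")
# would return a comma-separated prompt string (e.g. "brown hair, blue eyes") for the
# tags that clear their category thresholds.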
model_code.py
ADDED
@@ -0,0 +1,956 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import efficientnet_v2_l, EfficientNet_V2_L_Weights
from PIL import Image
from typing import Optional
import torchvision.transforms as transforms
import os
import json
try:
    from flash_attn import flash_attn_func  # used by FlashAttention.forward below
except ImportError:
    flash_attn_func = None  # flash-attn is optional; only needed for the attention modules

class InitialOnlyImageTagger(nn.Module):
    """
    A lightweight version of ImageTagger that only includes the backbone and initial classifier.
    This model uses significantly less VRAM than the full model.
    """
    def __init__(self, total_tags, dataset, model_name='efficientnet_v2_l',
                 dropout=0.1, pretrained=True):
        super().__init__()
        # Debug and stats flags
        self._flags = {
            'debug': False,
            'model_stats': False
        }

        # Core model config
        self.dataset = dataset
        self.embedding_dim = 1280  # Fixed to EfficientNetV2-L output dimension

        # Initialize backbone
        if model_name == 'efficientnet_v2_l':
            weights = EfficientNet_V2_L_Weights.DEFAULT if pretrained else None
            self.backbone = efficientnet_v2_l(weights=weights)
            self.backbone.classifier = nn.Identity()

        # Spatial pooling only - no projection
        self.spatial_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Initial tag prediction with bottleneck
        self.initial_classifier = nn.Sequential(
            nn.Linear(self.embedding_dim, self.embedding_dim * 2),
            nn.LayerNorm(self.embedding_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(self.embedding_dim * 2, self.embedding_dim),
            nn.LayerNorm(self.embedding_dim),
            nn.GELU(),
            nn.Linear(self.embedding_dim, total_tags)
        )

        # Temperature scaling
        self.temperature = nn.Parameter(torch.ones(1) * 1.5)

    @property
    def debug(self):
        return self._flags['debug']

    @debug.setter
    def debug(self, value):
        self._flags['debug'] = value

    @property
    def model_stats(self):
        return self._flags['model_stats']

    @model_stats.setter
    def model_stats(self, value):
        self._flags['model_stats'] = value

    def preprocess_image(self, image_path, image_size=512):
        """Process an image for inference using same preprocessing as training"""
        if not os.path.exists(image_path):
            raise ValueError(f"Image not found at path: {image_path}")

        # Initialize the same transform used during training
        transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        try:
            with Image.open(image_path) as img:
                # Convert RGBA or Palette images to RGB
                if img.mode in ('RGBA', 'P'):
                    img = img.convert('RGB')

                # Get original dimensions
                width, height = img.size
                aspect_ratio = width / height

                # Calculate new dimensions to maintain aspect ratio
                if aspect_ratio > 1:
                    new_width = image_size
                    new_height = int(new_width / aspect_ratio)
                else:
                    new_height = image_size
                    new_width = int(new_height * aspect_ratio)

                # Resize with LANCZOS filter
                img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

                # Create new image with padding
                new_image = Image.new('RGB', (image_size, image_size), (0, 0, 0))
                paste_x = (image_size - new_width) // 2
                paste_y = (image_size - new_height) // 2
                new_image.paste(img, (paste_x, paste_y))

                # Apply transforms (without normalization)
                img_tensor = transform(new_image)
                return img_tensor
        except Exception as e:
            raise Exception(f"Error processing {image_path}: {str(e)}")

    def forward(self, x):
        """Forward pass with only the initial predictions"""
        # Image Feature Extraction
        features = self.backbone.features(x)
        features = self.spatial_pool(features).squeeze(-1).squeeze(-1)

        # Initial Tag Predictions
        initial_logits = self.initial_classifier(features)
        initial_preds = torch.clamp(initial_logits / self.temperature, min=-15.0, max=15.0)

        # For API compatibility with the full model, return the same predictions twice
        return initial_preds, initial_preds

    def predict(self, image_path, threshold=0.325, category_thresholds=None):
        """
        Run inference on an image with support for category-specific thresholds.
        """
        # Preprocess the image
        img_tensor = self.preprocess_image(image_path).unsqueeze(0)

        # Move to the same device as model and convert to half precision
        device = next(self.parameters()).device
        dtype = next(self.parameters()).dtype  # Match model's precision
        img_tensor = img_tensor.to(device, dtype=dtype)

        # Run inference
        with torch.no_grad():
            initial_preds, _ = self.forward(img_tensor)

        # Apply sigmoid to get probabilities
        initial_probs = torch.sigmoid(initial_preds)

        # Apply thresholds
        if category_thresholds:
            # Create binary prediction tensors
            initial_binary = torch.zeros_like(initial_probs)

            # Apply thresholds by category
            for category, cat_threshold in category_thresholds.items():
                # Create a mask for tags in this category
                category_mask = torch.zeros_like(initial_probs, dtype=torch.bool)

                # Find indices for this category
                for tag_idx in range(initial_probs.size(-1)):
                    try:
                        _, tag_category = self.dataset.get_tag_info(tag_idx)
                        if tag_category == category:
                            category_mask[:, tag_idx] = True
                    except:
                        continue

                # Apply threshold only to tags in this category
                cat_threshold_tensor = torch.tensor(cat_threshold, device=device, dtype=dtype)
                initial_binary[category_mask] = (initial_probs[category_mask] >= cat_threshold_tensor).to(dtype)

            predictions = initial_binary
        else:
            # Use the same threshold for all tags
            threshold_tensor = torch.tensor(threshold, device=device, dtype=dtype)
            predictions = (initial_probs >= threshold_tensor).to(dtype)

        # Return the same probabilities for both initial and refined for API compatibility
        return {
            'initial_probabilities': initial_probs,
            'refined_probabilities': initial_probs,  # Same as initial for compatibility
            'predictions': predictions
        }

    def get_tags_from_predictions(self, predictions, include_probabilities=True):
        """
        Convert model predictions to human-readable tags grouped by category.
        """
        # Get non-zero predictions
        if predictions.dim() > 1:
            predictions = predictions[0]  # Remove batch dimension

        # Get indices of positive predictions
        indices = torch.where(predictions > 0)[0].cpu().tolist()

        # Group by category
        result = {}
        for idx in indices:
            tag_name, category = self.dataset.get_tag_info(idx)

            if category not in result:
                result[category] = []

            if include_probabilities:
                prob = predictions[idx].item()
                result[category].append((tag_name, prob))
            else:
                result[category].append(tag_name)

        # Sort tags by probability within each category
        if include_probabilities:
            for category in result:
                result[category] = sorted(result[category], key=lambda x: x[1], reverse=True)

        return result

class FlashAttention(nn.Module):
    def __init__(self, dim, num_heads=8, dropout=0.1, batch_first=True):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = dim // num_heads
        assert self.head_dim * num_heads == dim, "dim must be divisible by num_heads"

        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.out_proj = nn.Linear(dim, dim, bias=False)

        for proj in [self.q_proj, self.k_proj, self.v_proj, self.out_proj]:
            nn.init.xavier_uniform_(proj.weight, gain=0.1)

        self.scale = self.head_dim ** -0.5
        self.debug = False

    def _debug_print(self, name, tensor):
        """Debug helper"""
        if self.debug:
            print(f"\n{name}:")
            print(f"Shape: {tensor.shape}")
            print(f"Device: {tensor.device}")
            print(f"Dtype: {tensor.dtype}")
            if tensor.is_floating_point():
                with torch.no_grad():
                    print(f"Range: [{tensor.min().item():.3f}, {tensor.max().item():.3f}]")
                    print(f"Mean: {tensor.mean().item():.3f}")
                    print(f"Std: {tensor.std().item():.3f}")

    def _reshape_for_flash(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape input tensor for flash attention format"""
        batch_size, seq_len, _ = x.size()
        x = x.view(batch_size, seq_len, self.num_heads, self.head_dim)
        x = x.transpose(1, 2)  # [B, H, S, D]
        return x.contiguous()

    def forward(self, query: torch.Tensor, key: Optional[torch.Tensor] = None,
                value: Optional[torch.Tensor] = None,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Forward pass with flash attention"""
        if self.debug:
            print("\nFlashAttention Forward Pass")

        batch_size = query.size(0)

        # Use query as key/value if not provided
        key = query if key is None else key
        value = query if value is None else value

        # Project inputs
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)

        if self.debug:
            self._debug_print("Query before reshape", q)

        # Reshape for attention [B, H, S, D]
        q = self._reshape_for_flash(q)
        k = self._reshape_for_flash(k)
        v = self._reshape_for_flash(v)

        if self.debug:
            self._debug_print("Query after reshape", q)

        # Handle masking
        if mask is not None:
            # First convert mask to proper shape based on input dimensionality
            if mask.dim() == 2:  # [B, S]
                mask = mask.view(batch_size, 1, -1, 1)
            elif mask.dim() == 3:  # [B, S, S]
                mask = mask.view(batch_size, 1, mask.size(1), mask.size(2))
            elif mask.dim() == 5:  # [B, 1, S, S, S]
                mask = mask.squeeze(1).view(batch_size, 1, mask.size(2), mask.size(3))

            # Ensure mask is float16 if we're using float16
            mask = mask.to(q.dtype)

            if self.debug:
                self._debug_print("Prepared mask", mask)
                print(f"q shape: {q.shape}, mask shape: {mask.shape}")

            # Create attention mask that covers the full sequence length
            seq_len = q.size(2)
            if mask.size(-1) != seq_len:
                # Pad or trim mask to match sequence length
                new_mask = torch.zeros(batch_size, 1, seq_len, seq_len,
                                       device=mask.device, dtype=mask.dtype)
                min_len = min(seq_len, mask.size(-1))
                new_mask[..., :min_len, :min_len] = mask[..., :min_len, :min_len]
                mask = new_mask

            # Create key padding mask
            key_padding_mask = mask.squeeze(1).sum(-1) > 0
            key_padding_mask = key_padding_mask.view(batch_size, 1, -1, 1)

            # Apply the key padding mask
            k = k * key_padding_mask
            v = v * key_padding_mask

        if self.debug:
            self._debug_print("Query before attention", q)
            self._debug_print("Key before attention", k)
            self._debug_print("Value before attention", v)

        # Run flash attention
        dropout_p = self.dropout if self.training else 0.0
        output = flash_attn_func(
            q, k, v,
            dropout_p=dropout_p,
            softmax_scale=self.scale,
            causal=False
        )

        if self.debug:
            self._debug_print("Output after attention", output)

        # Reshape output [B, H, S, D] -> [B, S, H, D] -> [B, S, D]
        output = output.transpose(1, 2).contiguous()
        output = output.view(batch_size, -1, self.dim)

        # Final projection
        output = self.out_proj(output)

        if self.debug:
            self._debug_print("Final output", output)

        return output

class OptimizedTagEmbedding(nn.Module):
    def __init__(self, num_tags, embedding_dim, num_heads=8, dropout=0.1):
        super().__init__()
        # Single shared embedding for all tags
        self.embedding = nn.Embedding(num_tags, embedding_dim)
        self.attention = FlashAttention(embedding_dim, num_heads, dropout)
        self.norm1 = nn.LayerNorm(embedding_dim)
        self.norm2 = nn.LayerNorm(embedding_dim)

        # Single importance weighting for all tags
        self.tag_importance = nn.Parameter(torch.ones(num_tags) * 0.1)

        # Projection layers for unified tag context
        self.context_proj = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim * 2),
            nn.LayerNorm(embedding_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embedding_dim * 2, embedding_dim),
            nn.LayerNorm(embedding_dim)
        )

        self.importance_scale = nn.Parameter(torch.tensor(0.1))
        self.context_scale = nn.Parameter(torch.tensor(1.0))
        self.debug = False

    def _debug_print(self, name, tensor, extra_info=None):
        """Memory efficient debug printing with type handling"""
        if self.debug:
            print(f"\n{name}:")
            print(f"- Shape: {tensor.shape}")
            if isinstance(tensor, torch.Tensor):
                with torch.no_grad():
                    print(f"- Device: {tensor.device}")
                    print(f"- Dtype: {tensor.dtype}")

                    # Convert to float32 for statistics if needed
                    if tensor.dtype not in [torch.float16, torch.float32, torch.float64]:
                        calc_tensor = tensor.float()
                    else:
                        calc_tensor = tensor

                    try:
                        min_val = calc_tensor.min().item()
                        max_val = calc_tensor.max().item()
                        mean_val = calc_tensor.mean().item()
                        std_val = calc_tensor.std().item()
                        norm_val = torch.norm(calc_tensor).item()

                        print(f"- Value range: [{min_val:.3f}, {max_val:.3f}]")
                        print(f"- Mean: {mean_val:.3f}")
                        print(f"- Std: {std_val:.3f}")
                        print(f"- L2 Norm: {norm_val:.3f}")

                        if extra_info:
                            print(f"- Additional info: {extra_info}")
                    except Exception as e:
                        print(f"- Could not compute statistics: {str(e)}")

    def _debug_tensor(self, name, tensor):
        """Debug helper with dtype-specific analysis"""
        if self.debug and isinstance(tensor, torch.Tensor):
            print(f"\n{name}:")
            print(f"- Shape: {tensor.shape}")
            print(f"- Device: {tensor.device}")
            print(f"- Dtype: {tensor.dtype}")
            with torch.no_grad():
                has_nan = torch.isnan(tensor).any().item() if tensor.is_floating_point() else False
                has_inf = torch.isinf(tensor).any().item() if tensor.is_floating_point() else False
                print(f"- Contains NaN: {has_nan}")
                print(f"- Contains Inf: {has_inf}")

                # Different stats for different dtypes
                if tensor.is_floating_point():
                    print(f"- Range: [{tensor.min().item():.3f}, {tensor.max().item():.3f}]")
                    print(f"- Mean: {tensor.mean().item():.3f}")
                    print(f"- Std: {tensor.std().item():.3f}")
                else:
                    # For integer tensors
                    print(f"- Range: [{tensor.min().item()}, {tensor.max().item()}]")
                    print(f"- Unique values: {tensor.unique().numel()}")

    def _process_category(self, indices, masks):
        """Process a single category of tags"""
        # Get embeddings for this category
        embeddings = self.embedding(indices)

        if self.debug:
            self._debug_tensor("Category embeddings", embeddings)

        # Apply importance weights
        importance = torch.sigmoid(self.tag_importance) * self.importance_scale
        importance = torch.clamp(importance, min=0.01, max=10.0)
        importance_weights = importance[indices].unsqueeze(-1)

        # Apply and normalize
        embeddings = embeddings * importance_weights
        embeddings = self.norm1(embeddings)

        # Apply attention if we have more than one tag
        if embeddings.size(1) > 1:
            if masks is not None:
                attention_mask = torch.einsum('bi,bj->bij', masks, masks)
                attended = self.attention(embeddings, mask=attention_mask)
            else:
                attended = self.attention(embeddings)
            embeddings = self.norm2(attended)

        # Pool embeddings with masking
        if masks is not None:
            masked_embeddings = embeddings * masks.unsqueeze(-1)
            pooled = masked_embeddings.sum(dim=1) / masks.sum(dim=1, keepdim=True).clamp(min=1.0)
        else:
            pooled = embeddings.mean(dim=1)

        return pooled, embeddings

    def forward(self, tag_indices_dict, tag_masks_dict=None):
        """
        Process all tags in a unified embedding space
        Args:
            tag_indices_dict: dict of {category: tensor of indices}
            tag_masks_dict: dict of {category: tensor of masks}
        """
        if self.debug:
            print("\nOptimizedTagEmbedding Forward Pass")

        # Concatenate all indices and masks
        all_indices = []
        all_masks = []
        batch_size = None

        for category, indices in tag_indices_dict.items():
            if batch_size is None:
                batch_size = indices.size(0)
            all_indices.append(indices)
            if tag_masks_dict:
                all_masks.append(tag_masks_dict[category])

        # Stack along sequence dimension
        combined_indices = torch.cat(all_indices, dim=1)  # [B, total_seq_len]
        if tag_masks_dict:
            combined_masks = torch.cat(all_masks, dim=1)  # [B, total_seq_len]

        if self.debug:
            self._debug_tensor("Combined indices", combined_indices)
            if tag_masks_dict:
                self._debug_tensor("Combined masks", combined_masks)

        # Get embeddings for all tags using shared embedding
        embeddings = self.embedding(combined_indices)  # [B, total_seq_len, D]

        if self.debug:
            self._debug_tensor("Base embeddings", embeddings)

        # Apply unified importance weighting
        importance = torch.sigmoid(self.tag_importance) * self.importance_scale
        importance = torch.clamp(importance, min=0.01, max=10.0)
        importance_weights = importance[combined_indices].unsqueeze(-1)

        # Apply and normalize importance weights
        embeddings = embeddings * importance_weights
        embeddings = self.norm1(embeddings)

        if self.debug:
            self._debug_tensor("Weighted embeddings", embeddings)

        # Apply attention across all tags together
        if tag_masks_dict:
            attention_mask = torch.einsum('bi,bj->bij', combined_masks, combined_masks)
            attended = self.attention(embeddings, mask=attention_mask)
        else:
            attended = self.attention(embeddings)

        attended = self.norm2(attended)

        if self.debug:
            self._debug_tensor("Attended embeddings", attended)

        # Global pooling with masking
        if tag_masks_dict:
            masked_embeddings = attended * combined_masks.unsqueeze(-1)
            tag_context = masked_embeddings.sum(dim=1) / combined_masks.sum(dim=1, keepdim=True).clamp(min=1.0)
        else:
            tag_context = attended.mean(dim=1)

        # Project and scale context
        tag_context = self.context_proj(tag_context)
        context_scale = torch.clamp(self.context_scale, min=0.1, max=10.0)
        tag_context = tag_context * context_scale

        if self.debug:
            self._debug_tensor("Final tag context", tag_context)

        return tag_context, attended

class TagDataset:
    """Lightweight dataset wrapper for inference only"""
    def __init__(self, total_tags, idx_to_tag, tag_to_category):
        self.total_tags = total_tags
        self.idx_to_tag = idx_to_tag if isinstance(idx_to_tag, dict) else {int(k): v for k, v in idx_to_tag.items()}
        self.tag_to_category = tag_to_category

    def get_tag_info(self, idx):
        """Get tag name and category for a given index"""
        tag_name = self.idx_to_tag.get(idx, f"unknown-{idx}")
        category = self.tag_to_category.get(tag_name, "general")
        return tag_name, category

class ImageTagger(nn.Module):
    def __init__(self, total_tags, dataset, model_name='efficientnet_v2_l',
                 num_heads=16, dropout=0.1, pretrained=True,
                 tag_context_size=256):
        super().__init__()
        # Debug and stats flags
        self._flags = {
            'debug': False,
            'model_stats': False
        }

        # Core model config
        self.dataset = dataset
        self.tag_context_size = tag_context_size
        self.embedding_dim = 1280  # Fixed to EfficientNetV2-L output dimension

        # Initialize backbone
        if model_name == 'efficientnet_v2_l':
            weights = EfficientNet_V2_L_Weights.DEFAULT if pretrained else None
            self.backbone = efficientnet_v2_l(weights=weights)
            self.backbone.classifier = nn.Identity()

        # Spatial pooling only - no projection
        self.spatial_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Initial tag prediction with bottleneck
        self.initial_classifier = nn.Sequential(
            nn.Linear(self.embedding_dim, self.embedding_dim * 2),
            nn.LayerNorm(self.embedding_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(self.embedding_dim * 2, self.embedding_dim),
            nn.LayerNorm(self.embedding_dim),
            nn.GELU(),
            nn.Linear(self.embedding_dim, total_tags)
        )

        # Tag embeddings at full dimension
        self.tag_embedding = nn.Embedding(total_tags, self.embedding_dim)
        self.tag_attention = FlashAttention(self.embedding_dim, num_heads, dropout)
        self.tag_norm = nn.LayerNorm(self.embedding_dim)

        # Improved cross attention projection
        self.cross_proj = nn.Sequential(
            nn.Linear(self.embedding_dim, self.embedding_dim * 2),
            nn.LayerNorm(self.embedding_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(self.embedding_dim * 2, self.embedding_dim)
        )

        # Cross attention at full dimension
        self.cross_attention = FlashAttention(self.embedding_dim, num_heads, dropout)
        self.cross_norm = nn.LayerNorm(self.embedding_dim)

        # Refined classifier with improved bottleneck
        self.refined_classifier = nn.Sequential(
            nn.Linear(self.embedding_dim * 2, self.embedding_dim * 2),  # Doubled input size for residual
            nn.LayerNorm(self.embedding_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(self.embedding_dim * 2, self.embedding_dim),
            nn.LayerNorm(self.embedding_dim),
            nn.GELU(),
            nn.Linear(self.embedding_dim, total_tags)
        )

        # Temperature scaling
        self.temperature = nn.Parameter(torch.ones(1) * 1.5)

    def _get_selected_tags(self, logits):
        """Select top-K tags based on prediction confidence"""
        # Apply sigmoid to get probabilities
        probs = torch.sigmoid(logits)

        # Get top-K predictions for each image in batch
        batch_size = logits.size(0)
        topk_values, topk_indices = torch.topk(
            probs, k=self.tag_context_size, dim=1, largest=True, sorted=True
        )

        return topk_indices, topk_values

    @property
    def debug(self):
        return self._flags['debug']

    @debug.setter
    def debug(self, value):
        self._flags['debug'] = value

    @property
    def model_stats(self):
        return self._flags['model_stats']

    @model_stats.setter
    def model_stats(self, value):
        self._flags['model_stats'] = value

    def preprocess_image(self, image_path, image_size=512):
        """Process an image for inference using same preprocessing as training"""
        if not os.path.exists(image_path):
            raise ValueError(f"Image not found at path: {image_path}")

        # Initialize the same transform used during training
        transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        try:
            with Image.open(image_path) as img:
                # Convert RGBA or Palette images to RGB
                if img.mode in ('RGBA', 'P'):
                    img = img.convert('RGB')

                # Get original dimensions
                width, height = img.size
                aspect_ratio = width / height

                # Calculate new dimensions to maintain aspect ratio
                if aspect_ratio > 1:
                    new_width = image_size
                    new_height = int(new_width / aspect_ratio)
                else:
                    new_height = image_size
                    new_width = int(new_height * aspect_ratio)

                # Resize with LANCZOS filter
                img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

                # Create new image with padding
                new_image = Image.new('RGB', (image_size, image_size), (0, 0, 0))
                paste_x = (image_size - new_width) // 2
                paste_y = (image_size - new_height) // 2
                new_image.paste(img, (paste_x, paste_y))

                # Apply transforms (without normalization)
                img_tensor = transform(new_image)
                return img_tensor
        except Exception as e:
            raise Exception(f"Error processing {image_path}: {str(e)}")

    def forward(self, x):
        """Forward pass with simplified feature handling"""
        # Initialize tracking dicts
        model_stats = {} if self.model_stats else {}
        debug_tensors = {} if self.debug else None

        # 1. Image Feature Extraction
        features = self.backbone.features(x)
        features = self.spatial_pool(features).squeeze(-1).squeeze(-1)

        # 2. Initial Tag Predictions
        initial_logits = self.initial_classifier(features)
        initial_preds = torch.clamp(initial_logits / self.temperature, min=-15.0, max=15.0)

        # 3. Tag Selection & Embedding (simplified)
        pred_tag_indices, _ = self._get_selected_tags(initial_preds)
        tag_embeddings = self.tag_embedding(pred_tag_indices)

        # 4. Self-Attention on Tags
        attended_tags = self.tag_attention(tag_embeddings)
        attended_tags = self.tag_norm(attended_tags)

        # 5. Cross-Attention between Features and Tags
        features_proj = self.cross_proj(features)
        features_expanded = features_proj.unsqueeze(1).expand(-1, self.tag_context_size, -1)

        cross_attended = self.cross_attention(features_expanded, attended_tags)
        cross_attended = self.cross_norm(cross_attended)

        # 6. Feature Fusion with Residual Connection
        fused_features = cross_attended.mean(dim=1)  # Average across tag dimension
        # Concatenate original and attended features
        combined_features = torch.cat([features, fused_features], dim=-1)

        # 7. Refined Predictions
        refined_logits = self.refined_classifier(combined_features)
        refined_preds = torch.clamp(refined_logits / self.temperature, min=-15.0, max=15.0)

        # Return both prediction sets
        return initial_preds, refined_preds

    def predict(self, image_path, threshold=0.325, category_thresholds=None):
        """
        Run inference on an image with support for category-specific thresholds.
        """
        # Preprocess the image
        img_tensor = self.preprocess_image(image_path).unsqueeze(0)

        # Move to the same device as model and convert to half precision
        device = next(self.parameters()).device
        dtype = next(self.parameters()).dtype  # Match model's precision
        img_tensor = img_tensor.to(device, dtype=dtype)

        # Run inference
        with torch.no_grad():
            initial_preds, refined_preds = self.forward(img_tensor)

        # Apply sigmoid to get probabilities
        initial_probs = torch.sigmoid(initial_preds)
        refined_probs = torch.sigmoid(refined_preds)

        # Apply thresholds
        if category_thresholds:
            # Create binary prediction tensors
            refined_binary = torch.zeros_like(refined_probs)

            # Apply thresholds by category
            for category, cat_threshold in category_thresholds.items():
                # Create a mask for tags in this category
                category_mask = torch.zeros_like(refined_probs, dtype=torch.bool)

                # Find indices for this category
                for tag_idx in range(refined_probs.size(-1)):
                    try:
                        _, tag_category = self.dataset.get_tag_info(tag_idx)
                        if tag_category == category:
                            category_mask[:, tag_idx] = True
                    except:
                        continue

                # Apply threshold only to tags in this category - ensure dtype consistency
                cat_threshold_tensor = torch.tensor(cat_threshold, device=device, dtype=dtype)
                refined_binary[category_mask] = (refined_probs[category_mask] >= cat_threshold_tensor).to(dtype)

            predictions = refined_binary
        else:
            # Use the same threshold for all tags
            threshold_tensor = torch.tensor(threshold, device=device, dtype=dtype)
            predictions = (refined_probs >= threshold_tensor).to(dtype)

        # Return both probabilities and thresholded predictions
        return {
            'initial_probabilities': initial_probs,
            'refined_probabilities': refined_probs,
            'predictions': predictions
        }

    def get_tags_from_predictions(self, predictions, include_probabilities=True):
        """
        Convert model predictions to human-readable tags grouped by category.
        """
        # Get non-zero predictions
        if predictions.dim() > 1:
            predictions = predictions[0]  # Remove batch dimension

        # Get indices of positive predictions
        indices = torch.where(predictions > 0)[0].cpu().tolist()

        # Group by category
        result = {}
        for idx in indices:
            tag_name, category = self.dataset.get_tag_info(idx)

            if category not in result:
                result[category] = []

            if include_probabilities:
                prob = predictions[idx].item()
                result[category].append((tag_name, prob))
            else:
                result[category].append(tag_name)

        # Sort tags by probability within each category
        if include_probabilities:
            for category in result:
                result[category] = sorted(result[category], key=lambda x: x[1], reverse=True)

        return result

def load_model(model_dir, device='cuda'):
    """Load model with better error handling and warnings"""
    print(f"Loading model from {model_dir}")

    try:
        # Load metadata
        metadata_path = os.path.join(model_dir, "metadata.json")
        if not os.path.exists(metadata_path):
            raise FileNotFoundError(f"Metadata file not found at {metadata_path}")

        with open(metadata_path, 'r') as f:
            metadata = json.load(f)

        # Load model info
        model_info_path = os.path.join(model_dir, "model_info_initial_only.json")
        if os.path.exists(model_info_path):
            with open(model_info_path, 'r') as f:
                model_info = json.load(f)
        else:
            print("WARNING: Model info file not found, using default settings")
            model_info = {
                "tag_context_size": 256,
                "num_heads": 16,
                "precision": "float16"
            }

        # Create dataset wrapper
        dataset = TagDataset(
            total_tags=metadata['total_tags'],
            idx_to_tag=metadata['idx_to_tag'],
            tag_to_category=metadata['tag_to_category']
        )

        # Initialize model with exact settings from model_info
        model = ImageTagger(
            total_tags=metadata['total_tags'],
            dataset=dataset,
            num_heads=model_info.get('num_heads', 16),
            tag_context_size=model_info.get('tag_context_size', 256),
            pretrained=False
        )

        # Load weights
        state_dict_path = os.path.join(model_dir, "model.pt")
        if not os.path.exists(state_dict_path):
            raise FileNotFoundError(f"Model state dict not found at {state_dict_path}")

        state_dict = torch.load(state_dict_path, map_location=device)

        # First try strict loading
        try:
            model.load_state_dict(state_dict, strict=True)
            print("✓ Model state dict loaded with strict=True successfully")
        except Exception as e:
            print(f"! Strict loading failed: {str(e)}")
            print("Attempting non-strict loading...")

            # Try non-strict loading
            missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)

            print("Non-strict loading completed with:")
            print(f"- {len(missing_keys)} missing keys")
            print(f"- {len(unexpected_keys)} unexpected keys")

            if len(missing_keys) > 0:
                print(f"Sample missing keys: {missing_keys[:5]}")
            if len(unexpected_keys) > 0:
                print(f"Sample unexpected keys: {unexpected_keys[:5]}")

        # Move model to device
        model = model.to(device)

        # Set to half precision if needed
        if model_info.get('precision') == 'float16':
            model = model.half()
            print("✓ Model converted to half precision")

        # Set to eval mode
        model.eval()
        print("✓ Model set to evaluation mode")

        # Verify parameter dtype
        param_dtype = next(model.parameters()).dtype
        print(f"✓ Model loaded with precision: {param_dtype}")

        return model, dataset

    except Exception as e:
        print(f"ERROR loading model: {str(e)}")
        import traceback
        traceback.print_exc()
        raise

# Example usage
if __name__ == "__main__":
    import sys

    # Get model directory from command line or use default
    model_dir = sys.argv[1] if len(sys.argv) > 1 else "./exported_model"

    # Load model (load_model returns the model and its dataset wrapper)
    model, dataset = load_model(model_dir)

    # Display info
    print(f"\nModel information:")
    print(f"  Total tags: {dataset.total_tags}")
    print(f"  Device: {next(model.parameters()).device}")
    print(f"  Precision: {next(model.parameters()).dtype}")

    # Test on an image if provided
    if len(sys.argv) > 2:
+
image_path = sys.argv[2]
|
938 |
+
print(f"\nRunning inference on {image_path}")
|
939 |
+
|
940 |
+
# Use category thresholds if available
|
941 |
+
if thresholds and 'categories' in thresholds:
|
942 |
+
category_thresholds = {cat: opt['balanced']['threshold']
|
943 |
+
for cat, opt in thresholds['categories'].items()}
|
944 |
+
results = model.predict(image_path, category_thresholds=category_thresholds)
|
945 |
+
else:
|
946 |
+
results = model.predict(image_path)
|
947 |
+
|
948 |
+
# Get tags
|
949 |
+
tags = model.get_tags_from_predictions(results['predictions'])
|
950 |
+
|
951 |
+
# Print tags by category
|
952 |
+
print("\nPredicted tags:")
|
953 |
+
for category, category_tags in tags.items():
|
954 |
+
print(f"\n{category.capitalize()}:")
|
955 |
+
for tag, prob in category_tags:
|
956 |
+
print(f" {tag}: {prob:.3f}")
|
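Note on the example above: `get_tags_from_predictions` reads its scores from whatever tensor it is given, so feeding it the binary `predictions` tensor prints 0/1 values. A minimal sketch (assuming the `model`, `dataset`, and `results` objects from the `__main__` block above) that looks up the calibrated scores from `refined_probabilities` instead:

    import torch  # already imported by model_code.py

    probs = results['refined_probabilities'][0]   # [total_tags] sigmoid scores
    binary = results['predictions'][0]            # [total_tags] 0/1 after thresholding
    for idx in torch.where(binary > 0)[0].tolist():
        tag_name, category = dataset.get_tag_info(idx)
        print(f"{category}/{tag_name}: {probs[idx].item():.3f}")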
model_config.json
ADDED
@@ -0,0 +1,9 @@
+{
+  "class_name": "ImageTagger",
+  "args": {
+    "total_tags": 70527,
+    "num_heads": 16,
+    "dropout": 0.1,
+    "tag_context_size": 256
+  }
+}
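model_config.json records the constructor arguments of the full tagger. A hedged sketch of rebuilding the module from it (assuming `ImageTagger` from model_code.py accepts these keyword arguments; `dataset=None` and `pretrained=False` are assumptions for a weights-only reload):

    import json
    from model_code import ImageTagger  # assumes model_code.py is on the import path

    with open("model_config.json", "r") as f:
        cfg = json.load(f)

    # Unpack the recorded hyperparameters into the constructor
    model = ImageTagger(dataset=None, pretrained=False, **cfg["args"])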
model_info_initial_only.json
ADDED
@@ -0,0 +1,9 @@
+{
+  "precision": "float16",
+  "tag_context_size": 256,
+  "num_heads": 16,
+  "architecture": "ImageTagger",
+  "embedding_dim": 1280,
+  "backbone": "efficientnet_v2_l",
+  "model_type": "initial_only"
+}
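model_info_initial_only.json describes how the initial-only export was produced; `load_model()` above reads it to decide whether to cast to half precision. A small sketch of the same check in isolation (file names as in this upload):

    import json
    import torch

    with open("model_info_initial_only.json", "r") as f:
        info = json.load(f)

    # Convert to fp16 only when the export metadata says the weights are float16
    dtype = torch.float16 if info.get("precision") == "float16" else torch.float32
    print(info["architecture"], info["backbone"], info["embedding_dim"], dtype)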
model_no_flash.py
ADDED
@@ -0,0 +1,195 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision.models import efficientnet_v2_l, EfficientNet_V2_L_Weights
+
+class MultiheadAttentionNoFlash(nn.Module):
+    """Custom multi-head attention module (replaces FlashAttention) using ONNX-friendly ops."""
+    def __init__(self, dim, num_heads=8, dropout=0.0):
+        super().__init__()
+        assert dim % num_heads == 0, "Embedding dim must be divisible by num_heads"
+        self.dim = dim
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.scale = self.head_dim ** -0.5  # scaling factor for dot-product attention
+
+        # Define separate projections for query, key, value, and output (no biases to match FlashAttention)
+        self.q_proj = nn.Linear(dim, dim, bias=False)
+        self.k_proj = nn.Linear(dim, dim, bias=False)
+        self.v_proj = nn.Linear(dim, dim, bias=False)
+        self.out_proj = nn.Linear(dim, dim, bias=False)
+        # (Note: We omit dropout in attention computation for ONNX simplicity; model should be set to eval mode anyway.)
+
+    def forward(self, query, key=None, value=None):
+        # Allow usage as self-attention if key/value not provided
+        if key is None:
+            key = query
+        if value is None:
+            value = key
+
+        # Linear projections
+        Q = self.q_proj(query)   # [B, S_q, dim]
+        K = self.k_proj(key)     # [B, S_k, dim]
+        V = self.v_proj(value)   # [B, S_v, dim]
+
+        # Reshape into (B, num_heads, S, head_dim) for computing attention per head
+        B, S_q, _ = Q.shape
+        _, S_k, _ = K.shape
+        Q = Q.view(B, S_q, self.num_heads, self.head_dim).transpose(1, 2)  # [B, heads, S_q, head_dim]
+        K = K.view(B, S_k, self.num_heads, self.head_dim).transpose(1, 2)  # [B, heads, S_k, head_dim]
+        V = V.view(B, S_k, self.num_heads, self.head_dim).transpose(1, 2)  # [B, heads, S_k, head_dim]
+
+        # Scaled dot-product attention: compute attention weights
+        attn_weights = torch.matmul(Q, K.transpose(2, 3))  # [B, heads, S_q, S_k]
+        attn_weights = attn_weights * self.scale
+        attn_probs = F.softmax(attn_weights, dim=-1)  # softmax over S_k (key length)
+
+        # Apply attention weights to values
+        attn_output = torch.matmul(attn_probs, V)  # [B, heads, S_q, head_dim]
+
+        # Reshape back to [B, S_q, dim]
+        attn_output = attn_output.transpose(1, 2).contiguous().view(B, S_q, self.dim)
+        # Output projection
+        output = self.out_proj(attn_output)  # [B, S_q, dim]
+        return output
+
+class ImageTaggerRefinedONNX(nn.Module):
+    """
+    Refined CAMIE Image Tagger model without FlashAttention.
+    - EfficientNetV2 backbone
+    - Initial classifier for preliminary tag logits
+    - Multi-head self-attention on top predicted tag embeddings
+    - Multi-head cross-attention between image feature and tag embeddings
+    - Refined classifier for final tag logits
+    """
+    def __init__(self, total_tags, tag_context_size=256, num_heads=16, dropout=0.1):
+        super().__init__()
+        self.tag_context_size = tag_context_size
+        self.embedding_dim = 1280  # EfficientNetV2-L feature dimension
+
+        # Backbone feature extractor (EfficientNetV2-L)
+        backbone = efficientnet_v2_l(weights=EfficientNet_V2_L_Weights.DEFAULT)
+        backbone.classifier = nn.Identity()  # remove final classification head
+        self.backbone = backbone
+
+        # Spatial pooling to get a single feature vector per image (1x1 avg pool)
+        self.spatial_pool = nn.AdaptiveAvgPool2d((1, 1))
+
+        # Initial classifier (two-layer MLP) to predict tags from image feature
+        self.initial_classifier = nn.Sequential(
+            nn.Linear(self.embedding_dim, self.embedding_dim * 2),
+            nn.LayerNorm(self.embedding_dim * 2),
+            nn.GELU(),
+            nn.Dropout(dropout),
+            nn.Linear(self.embedding_dim * 2, self.embedding_dim),
+            nn.LayerNorm(self.embedding_dim),
+            nn.GELU(),
+            nn.Linear(self.embedding_dim, total_tags)  # outputs raw logits for all tags
+        )
+
+        # Embedding for tags (each tag gets an embedding vector, used for attention)
+        self.tag_embedding = nn.Embedding(total_tags, self.embedding_dim)
+
+        # Self-attention over the selected tag embeddings (replaces FlashAttention)
+        self.tag_attention = MultiheadAttentionNoFlash(self.embedding_dim, num_heads=num_heads, dropout=dropout)
+        self.tag_norm = nn.LayerNorm(self.embedding_dim)
+
+        # Projection from image feature to query vector for cross-attention
+        self.cross_proj = nn.Sequential(
+            nn.Linear(self.embedding_dim, self.embedding_dim * 2),
+            nn.LayerNorm(self.embedding_dim * 2),
+            nn.GELU(),
+            nn.Dropout(dropout),
+            nn.Linear(self.embedding_dim * 2, self.embedding_dim)
+        )
+        # Cross-attention between image feature (as query) and tag features (as key/value)
+        self.cross_attention = MultiheadAttentionNoFlash(self.embedding_dim, num_heads=num_heads, dropout=dropout)
+        self.cross_norm = nn.LayerNorm(self.embedding_dim)
+
+        # Refined classifier (takes concatenated original & attended features)
+        self.refined_classifier = nn.Sequential(
+            nn.Linear(self.embedding_dim * 2, self.embedding_dim * 2),
+            nn.LayerNorm(self.embedding_dim * 2),
+            nn.GELU(),
+            nn.Dropout(dropout),
+            nn.Linear(self.embedding_dim * 2, self.embedding_dim),
+            nn.LayerNorm(self.embedding_dim),
+            nn.GELU(),
+            nn.Linear(self.embedding_dim, total_tags)
+        )
+
+        # Temperature parameter for scaling logits (to calibrate confidence)
+        self.temperature = nn.Parameter(torch.ones(1) * 1.5)
+
+    def forward(self, images):
+        # 1. Feature extraction
+        feats = self.backbone.features(images)  # [B, 1280, H/32, W/32] features
+        feats = self.spatial_pool(feats).squeeze(-1).squeeze(-1)  # [B, 1280] global feature vector per image
+
+        # 2. Initial tag prediction
+        initial_logits = self.initial_classifier(feats)  # [B, total_tags]
+        # Scale by temperature and clamp (to stabilize extreme values, as in original)
+        initial_preds = torch.clamp(initial_logits / self.temperature, min=-15.0, max=15.0)
+
+        # 3. Select top-k predicted tags for context (tag_context_size)
+        probs = torch.sigmoid(initial_preds)  # convert logits to probabilities
+        # Get indices of top `tag_context_size` tags for each sample
+        _, topk_indices = torch.topk(probs, k=self.tag_context_size, dim=1)
+        # 4. Embed selected tags
+        tag_embeds = self.tag_embedding(topk_indices)  # [B, tag_context_size, embedding_dim]
+
+        # 5. Self-attention on tag embeddings (to refine tag representation)
+        attn_tags = self.tag_attention(tag_embeds)  # [B, tag_context_size, embedding_dim]
+        attn_tags = self.tag_norm(attn_tags)  # layer norm
+
+        # 6. Cross-attention between image feature and attended tags
+        # Expand image features to have one per tag position
+        feat_q = self.cross_proj(feats)  # [B, embedding_dim]
+        # Repeat each image feature vector tag_context_size times to form a sequence
+        feat_q = feat_q.unsqueeze(1).expand(-1, self.tag_context_size, -1)  # [B, tag_context_size, embedding_dim]
+        # Use image features as queries, tag embeddings as keys and values
+        cross_attn = self.cross_attention(feat_q, attn_tags, attn_tags)  # [B, tag_context_size, embedding_dim]
+        cross_attn = self.cross_norm(cross_attn)
+
+        # 7. Fuse features: average the cross-attended tag outputs, and combine with original features
+        fused_feature = cross_attn.mean(dim=1)  # [B, embedding_dim]
+        combined = torch.cat([feats, fused_feature], dim=1)  # [B, embedding_dim*2]
+
+        # 8. Refined tag prediction
+        refined_logits = self.refined_classifier(combined)  # [B, total_tags]
+        refined_preds = torch.clamp(refined_logits / self.temperature, min=-15.0, max=15.0)
+
+        return initial_preds, refined_preds
+
+# --- Load the pretrained refined model weights ---
+total_tags = 70527  # total number of tags in the dataset (Danbooru 2024)
+from safetensors.torch import load_file
+safetensors_path = 'model_refined.safetensors'
+state_dict = load_file(safetensors_path, device='cpu')  # Load the saved weights (should be an OrderedDict)
+#state_dict = torch.load("model_refined.pt", map_location="cpu")  # Load the saved weights (should be an OrderedDict)
+
+# Initialize our model and load weights
+model = ImageTaggerRefinedONNX(total_tags=total_tags)
+model.load_state_dict(state_dict)
+model.eval()  # set to evaluation mode (disable dropout)
+
+# (Optional) Cast to float32 if weights were in half precision
+# model = model.float()
+
+# --- Export to ONNX ---
+dummy_input = torch.randn(1, 3, 512, 512, requires_grad=False)  # dummy batch of 1 image (3x512x512)
+output_onnx_file = "camie_refined_no_flash_v15.onnx"
+torch.onnx.export(
+    model, dummy_input, output_onnx_file,
+    export_params=True,        # store trained parameter weights inside the model file
+    opset_version=17,          # ONNX opset version (ensure support for needed ops)
+    do_constant_folding=True,  # optimize constant expressions
+    input_names=["image"],
+    output_names=["initial_tags", "refined_tags"],
+    dynamic_axes={             # set batch dimension to be dynamic
+        "image": {0: "batch"},
+        "initial_tags": {0: "batch"},
+        "refined_tags": {0: "batch"}
+    }
+)
+print(f"ONNX model exported to {output_onnx_file}")
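Once camie_refined_no_flash_v15.onnx has been written, it can be run without PyTorch. A hedged sketch with onnxruntime (the random array stands in for a preprocessed 512x512 RGB image; the real preprocessing must match what the model was trained with):

    import json
    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession("camie_refined_no_flash_v15.onnx", providers=["CPUExecutionProvider"])
    x = np.random.rand(1, 3, 512, 512).astype(np.float32)  # placeholder for a preprocessed image
    initial_tags, refined_tags = sess.run(["initial_tags", "refined_tags"], {"image": x})

    # The graph outputs clamped logits; apply sigmoid and the overall "balanced" threshold
    probs = 1.0 / (1.0 + np.exp(-refined_tags))
    with open("thresholds.json", "r") as f:
        thr = json.load(f)["overall"]["balanced"]["threshold"]
    print("tags above threshold:", int((probs >= thr).sum()))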
thresholds.json
ADDED
@@ -0,0 +1,170 @@
+{
+  "overall": {
+    "balanced": {
+      "threshold": 0.3285714089870453,
+      "f1": 0.6128875755303665,
+      "precision": 0.6348684210526315,
+      "recall": 0.5923778668258164
+    },
+    "high_precision": {
+      "threshold": 0.48367345333099365,
+      "f1": 0.5073781135639239,
+      "precision": 0.8244772683675426,
+      "recall": 0.3664421519311109
+    },
+    "high_recall": {
+      "threshold": 0.20612245798110962,
+      "f1": 0.5140483341286104,
+      "precision": 0.38317013976064945,
+      "recall": 0.7807144684116293
+    }
+  },
+  "weighted": {
+    "f1": {
+      "threshold": 0.31224489212036133,
+      "value": 0.666115043816508
+    }
+  },
+  "categories": {
+    "copyright": {
+      "balanced": {
+        "threshold": 0.3857142925262451,
+        "f1": 0.7885196374622356,
+        "precision": 0.903114186851211,
+        "recall": 0.6997319034852547
+      },
+      "high_precision": {
+        "threshold": 0.5,
+        "f1": 0.7524429967426711,
+        "precision": 0.9585062240663901,
+        "recall": 0.6193029490616622
+      },
+      "high_recall": {
+        "threshold": 0.13265305757522583,
+        "f1": 0.5149136577708007,
+        "precision": 0.36403995560488345,
+        "recall": 0.8793565683646113
+      }
+    },
+    "character": {
+      "balanced": {
+        "threshold": 0.30408161878585815,
+        "f1": 0.769028871391076,
+        "precision": 0.8878787878787879,
+        "recall": 0.6782407407407407
+      },
+      "high_precision": {
+        "threshold": 0.47551020979881287,
+        "f1": 0.7128129602356407,
+        "precision": 0.979757085020243,
+        "recall": 0.5601851851851852
+      },
+      "high_recall": {
+        "threshold": 0.13265305757522583,
+        "f1": 0.5132616487455197,
+        "precision": 0.37175493250259606,
+        "recall": 0.8287037037037037
+      }
+    },
+    "general": {
+      "balanced": {
+        "threshold": 0.3285714089870453,
+        "f1": 0.6070014256296532,
+        "precision": 0.6206003023105161,
+        "recall": 0.5939857393820399
+      },
+      "high_precision": {
+        "threshold": 0.47551020979881287,
+        "f1": 0.5074963046385584,
+        "precision": 0.7958057395143487,
+        "recall": 0.3725328097550894
+      },
+      "high_recall": {
+        "threshold": 0.20612245798110962,
+        "f1": 0.5094889521485699,
+        "precision": 0.3790529978316777,
+        "recall": 0.7767903275808619
+      }
+    },
+    "meta": {
+      "balanced": {
+        "threshold": 0.31224489212036133,
+        "f1": 0.5943152454780362,
+        "precision": 0.5948275862068966,
+        "recall": 0.5938037865748709
+      },
+      "high_precision": {
+        "threshold": 0.41020408272743225,
+        "f1": 0.5087924970691676,
+        "precision": 0.7977941176470589,
+        "recall": 0.37349397590361444
+      },
+      "high_recall": {
+        "threshold": 0.22244898974895477,
+        "f1": 0.5037433155080214,
+        "precision": 0.365399534522886,
+        "recall": 0.810671256454389
+      }
+    },
+    "rating": {
+      "balanced": {
+        "threshold": 0.34489795565605164,
+        "f1": 0.7964912280701754,
+        "precision": 0.7229299363057324,
+        "recall": 0.88671875
+      },
+      "high_precision": {
+        "threshold": 0.5,
+        "f1": 0.6966824644549763,
+        "precision": 0.8855421686746988,
+        "recall": 0.57421875
+      },
+      "high_recall": {
+        "threshold": 0.10000000149011612,
+        "f1": 0.6538952745849297,
+        "precision": 0.4857685009487666,
+        "recall": 1.0
+      }
+    },
+    "artist": {
+      "balanced": {
+        "threshold": 0.22244898974895477,
+        "f1": 0.5017921146953405,
+        "precision": 0.56,
+        "recall": 0.45454545454545453
+      },
+      "high_precision": {
+        "threshold": 0.22244898974895477,
+        "f1": 0.5017921146953405,
+        "precision": 0.56,
+        "recall": 0.45454545454545453
+      },
+      "high_recall": {
+        "threshold": 0.22244898974895477,
+        "f1": 0.5017921146953405,
+        "precision": 0.56,
+        "recall": 0.45454545454545453
+      }
+    },
+    "year": {
+      "balanced": {
+        "threshold": 0.2877551317214966,
+        "f1": 0.32867132867132864,
+        "precision": 0.2974683544303797,
+        "recall": 0.3671875
+      },
+      "high_precision": {
+        "threshold": 0,
+        "f1": 0,
+        "precision": 0,
+        "recall": 0
+      },
+      "high_recall": {
+        "threshold": 0,
+        "f1": 0,
+        "precision": 0,
+        "recall": 0
+      }
+    }
+  }
+}
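The per-category blocks above are what the inference scripts consume. A short sketch of turning this file into the category_thresholds dict passed to predict() (swap "balanced" for "high_precision" or "high_recall" to shift the precision/recall trade-off):

    import json

    with open("thresholds.json", "r") as f:
        thresholds = json.load(f)

    # One threshold per category, using the "balanced" operating point
    category_thresholds = {cat: opts["balanced"]["threshold"]
                           for cat, opts in thresholds["categories"].items()}
    print(category_thresholds["general"], category_thresholds["rating"])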