AngelBottomless committed on
Commit a7ab59e · verified · 1 Parent(s): f7165bd

Upload 9 files

Files changed (8)
  1. export.py +30 -0
  2. infer-refined.py +89 -35
  3. infer.py +139 -97
  4. model_code.py +956 -0
  5. model_config.json +9 -0
  6. model_info_initial_only.json +9 -0
  7. model_no_flash.py +195 -0
  8. thresholds.json +170 -0
export.py ADDED
@@ -0,0 +1,30 @@
+ import torch
+ import torchvision.models as models
+ from model_code import InitialOnlyImageTagger  # Assume model_code.py classes are accessible
+ from safetensors.torch import load_file
+
+ # Load the trained weights (Initial-only model). Adjust path to your weights file.
+ #weights_path = "model_initial_only.pt"
+ safetensors_path = 'model_initial.safetensors'
+ state_dict = load_file(safetensors_path, device='cpu')
+ #state_dict = torch.load(weights_path, map_location="cpu")
+ # Instantiate the model with the same parameters as training
+ model = InitialOnlyImageTagger(total_tags=70527, dataset=None, pretrained=True)  # dataset not needed for forward
+ model.load_state_dict(state_dict)
+ model.eval()  # set to evaluation mode
+
+ # Define example input – a dummy image tensor of the expected input shape (1, 3, 512, 512)
+ dummy_input = torch.randn(1, 3, 512, 512, dtype=torch.float32)
+
+ # Export to ONNX
+ onnx_path = "camie_tagger_initial_v15.onnx"
+ torch.onnx.export(
+     model, dummy_input, onnx_path,
+     export_params=True,        # store the trained parameter weights in the model file
+     opset_version=13,          # ONNX opset version (13 is widely supported)
+     do_constant_folding=True,  # optimize constant expressions
+     input_names=["input"],
+     output_names=["initial_logits", "refined_logits"],  # model.forward returns two outputs (identical for InitialOnly)
+     dynamic_axes={"input": {0: "batch_size"}}           # allow variable batch size
+ )
+ print(f"ONNX model saved to: {onnx_path}")
infer-refined.py CHANGED
@@ -42,73 +42,120 @@ def preprocess_image(img_path, target_size=512, keep_aspect=True):
  arr = np.expand_dims(arr, axis=0)
  return arr

- def onnx_inference(img_paths,
-                    onnx_path="camie_refined_no_flash.onnx",
-                    threshold=0.325,
-                    metadata_file="metadata.json"):
+ # Example input
+ def load_thresholds(threshold_json_path, mode="balanced"):
+     """
+     Loads thresholds from the given JSON file, using a particular mode
+     (e.g. 'balanced', 'high_precision', 'high_recall') for each category.
+
+     Returns:
+         thresholds_by_category (dict): e.g. { "general": 0.328..., "character": 0.304..., ... }
+         fallback_threshold (float): The overall threshold if category not found
+     """
+     with open(threshold_json_path, "r", encoding="utf-8") as f:
+         data = json.load(f)
+
+     # The fallback threshold from the "overall" section for the chosen mode
+     fallback_threshold = data["overall"][mode]["threshold"]
+
+     # Build a dict of thresholds keyed by category
+     thresholds_by_category = {}
+     if "categories" in data:
+         for cat_name, cat_modes in data["categories"].items():
+             # If the chosen mode is present for that category, use it;
+             # otherwise fall back to the "overall" threshold.
+             if mode in cat_modes and "threshold" in cat_modes[mode]:
+                 thresholds_by_category[cat_name] = cat_modes[mode]["threshold"]
+             else:
+                 thresholds_by_category[cat_name] = fallback_threshold
+
+     return thresholds_by_category, fallback_threshold
+ def onnx_inference(
+     img_paths,
+     onnx_path="camie_refined_no_flash.onnx",
+     metadata_file="metadata.json",
+     threshold_json_path="thresholds.json",
+     mode="balanced",
+     target_size=512,
+     keep_aspect=True
+ ):
      """
      Loads the ONNX model, runs inference on a list of image paths,
-     and applies an optional threshold to produce final predictions.
-
+     and applies category-wise thresholds from thresholds.json (per the chosen mode).
+
      Args:
-         img_paths: List of paths to images.
-         onnx_path: Path to the exported ONNX model file.
-         threshold: Probability threshold for deciding if a tag is predicted.
-         metadata_file: Path to metadata.json that contains idx_to_tag etc.
-
+         img_paths           : List of paths to images.
+         onnx_path           : Path to the exported ONNX model file.
+         metadata_file       : Path to metadata.json that contains idx_to_tag, tag_to_category, etc.
+         threshold_json_path : Path to thresholds.json containing category-wise threshold info.
+         mode                : "balanced", "high_precision", or "high_recall".
+         target_size         : Final size of preprocessed images (512 by default).
+         keep_aspect         : If True, preserve aspect ratio when resizing, pad with black.
+
      Returns:
-         A list of dicts, each containing:
+         A list of dicts, one per input image, each containing:
          {
            "initial_logits": np.ndarray of shape (N_tags,),
            "refined_logits": np.ndarray of shape (N_tags,),
-           "predicted_tags": list of tag indices that exceeded threshold,
+           "predicted_indices": list of tag indices that exceeded threshold,
+           "predicted_tags": list of predicted tag strings,
            ...
          }
-         one dict per input image.
      """
      # 1) Initialize ONNX runtime session
      session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
-     # Optional: for GPU usage, see if "CUDAExecutionProvider" is available
+     # For GPU usage, you could do e.g.:
      # session = ort.InferenceSession(onnx_path, providers=["CUDAExecutionProvider"])

      # 2) Pre-load metadata
      with open(metadata_file, "r", encoding="utf-8") as f:
          metadata = json.load(f)
-     idx_to_tag = metadata["idx_to_tag"]  # e.g. { "0": "brown_hair", "1": "blue_eyes", ... }
+     idx_to_tag = metadata["idx_to_tag"]  # e.g. { "0": "brown_hair", "1": "blue_eyes", ... }
+     tag_to_category = metadata.get("tag_to_category", {})
+
+     # Load thresholds from thresholds.json using the specified mode
+     thresholds_by_category, fallback_threshold = load_thresholds(threshold_json_path, mode)

      # 3) Preprocess each image into a batch
      batch_tensors = []
      for img_path in img_paths:
-         x = preprocess_image(img_path, target_size=512, keep_aspect=True)
+         x = preprocess_image(img_path, target_size=target_size, keep_aspect=keep_aspect)
          batch_tensors.append(x)
-     # Concatenate along the batch dimension => shape (batch_size, 3, 512, 512)
+     # Concatenate along the batch dimension => shape (batch_size, 3, H, W)
      batch_input = np.concatenate(batch_tensors, axis=0)

      # 4) Run inference
-     input_name = session.get_inputs()[0].name  # typically "image"
+     input_name = session.get_inputs()[0].name  # typically "image" or "input"
      outputs = session.run(None, {input_name: batch_input})
      # Typically we get [initial_tags, refined_tags] as output
-     initial_preds, refined_preds = outputs  # shapes => (batch_size, 70527)
+     initial_preds, refined_preds = outputs  # shapes => (batch_size, N_tags)

-     # 5) For each image in batch, convert logits to predictions if desired
+     # 5) Convert logits -> probabilities -> apply category-specific thresholds
      batch_results = []
      for i in range(initial_preds.shape[0]):
-         # Extract one sample's logits
          init_logit = initial_preds[i, :]  # shape (N_tags,)
          ref_logit = refined_preds[i, :]   # shape (N_tags,)
-
-         # Convert to probabilities with sigmoid
-         ref_prob = 1.0 / (1.0 + np.exp(-ref_logit))
-
-         # Threshold
-         pred_indices = np.where(ref_prob >= threshold)[0]
+         ref_prob = 1.0 / (1.0 + np.exp(-ref_logit))  # shape (N_tags,)
+
+         predicted_indices = []
+         predicted_tags = []

+         # Check each tag against the category threshold
+         for idx in range(ref_logit.shape[0]):
+             tag_name = idx_to_tag[str(idx)]                      # Convert index->string->tag name
+             category = tag_to_category.get(tag_name, "general")  # fallback to "general" if missing
+             cat_threshold = thresholds_by_category.get(category, fallback_threshold)

+             if ref_prob[idx] >= cat_threshold:
+                 predicted_indices.append(idx)
+                 predicted_tags.append(tag_name)

          # Build result for this image
          result_dict = {
              "initial_logits": init_logit,
              "refined_logits": ref_logit,
-             "predicted_indices": pred_indices,
-             "predicted_tags": [idx_to_tag[str(idx)] for idx in pred_indices]  # map index->tag name
+             "predicted_indices": predicted_indices,
+             "predicted_tags": predicted_tags,
          }
          batch_results.append(result_dict)

@@ -116,14 +163,21 @@ def onnx_inference(img_paths,

  if __name__ == "__main__":
      # Example usage
-     images = ["image1.jpg", "image2.jpg", "image3.jpg"]
-     results = onnx_inference(images,
-                              onnx_path="camie_refined_no_flash.onnx",
-                              threshold=0.325,
-                              metadata_file="metadata.json")
+     images = ["images.png"]
+     results = onnx_inference(
+         img_paths=images,
+         onnx_path="camie_refined_no_flash_v15.onnx",
+         metadata_file="metadata.json",
+         threshold_json_path="thresholds.json",
+         mode="balanced",  # or "high_precision", "high_recall"
+         target_size=512,
+         keep_aspect=True
+     )

      for i, res in enumerate(results):
          print(f"Image: {images[i]}")
          print(f"  # of predicted tags above threshold: {len(res['predicted_indices'])}")
-         print(f"  Some predicted tags: {res['predicted_tags'][:10]} (Show up to 10)")
-         print()
+         # Show first 10 predicted tags (if available)
+         sample_tags = res['predicted_tags'][:10]
+         print("  Sample predicted tags:", sample_tags)
+         print()
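load_thresholds expects thresholds.json to carry an "overall" section plus optional per-category sections, each keyed by mode. A sketch of that assumed layout as a Python literal (values here are illustrative, not the shipped ones; 0.328/0.304 echo the docstring example):

example_thresholds = {
    "overall": {
        "balanced":       {"threshold": 0.325},
        "high_precision": {"threshold": 0.5},
        "high_recall":    {"threshold": 0.2},
    },
    "categories": {
        "general":   {"balanced": {"threshold": 0.328}},
        "character": {"balanced": {"threshold": 0.304}},
    },
}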
infer.py CHANGED
@@ -1,98 +1,140 @@
- import onnxruntime as ort
- import numpy as np
- import json
- from PIL import Image
-
- # 1) Load ONNX model
- session = ort.InferenceSession("camie_tagger_initial.onnx", providers=["CPUExecutionProvider"])
-
- # 2) Preprocess your image (512x512, etc.)
- def preprocess_image(img_path):
-     """
-     Loads and resizes an image to 512x512, converts it to float32 [0..1],
-     and returns a (1,3,512,512) NumPy array (NCHW format).
-     """
-     img = Image.open(img_path).convert("RGB").resize((512, 512))
-     x = np.array(img).astype(np.float32) / 255.0
-     x = np.transpose(x, (2, 0, 1))  # HWC -> CHW
-     x = np.expand_dims(x, 0)        # add batch dimension -> (1,3,512,512)
-     return x
-
- # Example input
-
- def inference(input_path, output_format="verbose"):
-     """
-     Returns either:
-       - A verbose category breakdown, or
-       - A comma-separated string of predicted tags (underscores replaced with spaces).
-     """
-     # 1) Preprocess
-     input_tensor = preprocess_image(input_path)
-
-     # 2) Run inference
-     input_name = session.get_inputs()[0].name
-     outputs = session.run(None, {input_name: input_tensor})
-     initial_logits, refined_logits = outputs  # shape: (1, 70527) each
-
-     # 3) Convert logits to probabilities
-     refined_probs = 1 / (1 + np.exp(-refined_logits))  # shape: (1, 70527)
-
-     # 4) Load metadata & retrieve threshold info
-     with open("metadata.json", "r", encoding="utf-8") as f:
-         metadata = json.load(f)
-
-     idx_to_tag = metadata["idx_to_tag"]  # e.g. { "0": "brown_hair", "1": "blue_eyes", ... }
-     tag_to_category = metadata.get("tag_to_category", {})
-     category_thresholds = metadata.get(
-         "category_thresholds",
-         {"artist": 0.1, "character": 0.2, "meta": 0.3, "style": 0.1}
-     )
-     default_threshold = 0.325
-
-     # 5) Collect predictions by category
-     results_by_category = {}
-     num_tags = refined_probs.shape[1]
-
-     for i in range(num_tags):
-         prob = float(refined_probs[0, i])
-         tag_name = idx_to_tag[str(i)]  # str(i) because metadata uses string keys
-         category = tag_to_category.get(tag_name, "unknown")
-         cat_threshold = category_thresholds.get(category, default_threshold)
-
-         if prob >= cat_threshold:
-             if category not in results_by_category:
-                 results_by_category[category] = []
-             results_by_category[category].append((tag_name, prob))
-
-     # 6) Depending on output_format, produce different return strings
-     if output_format == "as_prompt":
-         # Flatten all predicted tags across categories
-         all_predicted_tags = []
-         for cat, tags_list in results_by_category.items():
-             # We only need the tag name in as_prompt format
-             for tname, tprob in tags_list:
-                 # convert underscores to spaces
-                 tag_name_spaces = tname.replace("_", " ")
-                 all_predicted_tags.append(tag_name_spaces)
-
-         # Create a comma-separated string
-         prompt_string = ", ".join(all_predicted_tags)
-         return prompt_string
-
-     else:  # "verbose"
-         # We'll build a multiline string describing the predictions
-         lines = []
-         lines.append("Predicted Tags by Category:\n")
-         for cat, tags_list in results_by_category.items():
-             lines.append(f"Category: {cat} | Predicted {len(tags_list)} tags")
-             # Sort descending by probability
-             for tname, tprob in sorted(tags_list, key=lambda x: x[1], reverse=True):
-                 lines.append(f"  Tag: {tname:30s} Prob: {tprob:.4f}")
-             lines.append("")  # blank line after each category
-         # Join lines with newlines
-         verbose_output = "\n".join(lines)
-         return verbose_output
-
- if __name__ == "__main__":
-     result = inference("path/to/image", output_format="as_prompt")
+ import onnxruntime as ort
+ import numpy as np
+ import json
+ from PIL import Image
+
+ # 1) Load ONNX model
+ session = ort.InferenceSession("camie_tagger_initial_v15.onnx", providers=["CPUExecutionProvider"])
+
+ # 2) Preprocess your image (512x512, etc.)
+ def preprocess_image(img_path):
+     """
+     Loads and resizes an image to 512x512, converts it to float32 [0..1],
+     and returns a (1,3,512,512) NumPy array (NCHW format).
+     """
+     img = Image.open(img_path).convert("RGB").resize((512, 512))
+     x = np.array(img).astype(np.float32) / 255.0
+     x = np.transpose(x, (2, 0, 1))  # HWC -> CHW
+     x = np.expand_dims(x, 0)        # add batch dimension -> (1,3,512,512)
+     return x
+
+ # Example input
+ def load_thresholds(threshold_json_path, mode="balanced"):
+     """
+     Loads thresholds from the given JSON file, using a particular mode
+     (e.g. 'balanced', 'high_precision', 'high_recall') for each category.
+
+     Returns:
+         thresholds_by_category (dict): e.g. { "general": 0.328..., "character": 0.304..., ... }
+         fallback_threshold (float): The overall threshold if category not found
+     """
+     with open(threshold_json_path, "r", encoding="utf-8") as f:
+         data = json.load(f)
+
+     # The fallback threshold from the "overall" section for the chosen mode
+     fallback_threshold = data["overall"][mode]["threshold"]
+
+     # Build a dict of thresholds keyed by category
+     thresholds_by_category = {}
+     if "categories" in data:
+         for cat_name, cat_modes in data["categories"].items():
+             # If the chosen mode is present for that category, use it;
+             # otherwise fall back to the "overall" threshold.
+             if mode in cat_modes and "threshold" in cat_modes[mode]:
+                 thresholds_by_category[cat_name] = cat_modes[mode]["threshold"]
+             else:
+                 thresholds_by_category[cat_name] = fallback_threshold
+
+     return thresholds_by_category, fallback_threshold
+
+ def inference(
+     input_path,
+     output_format="verbose",
+     mode="balanced",
+     threshold_json_path="thresholds.json",
+     metadata_path="metadata.json"
+ ):
+     """
+     Run inference on an image using the loaded ONNX model, then apply
+     category-wise thresholds from `thresholds.json` for the chosen mode.
+
+     Arguments:
+         input_path (str)          : Path to the image file for inference.
+         output_format (str)       : Either "verbose" or "as_prompt".
+         mode (str)                : "balanced", "high_precision", or "high_recall"
+         threshold_json_path (str) : Path to the JSON file with category thresholds.
+         metadata_path (str)       : Path to the metadata JSON file with category info.
+
+     Returns:
+         str: The predicted tags in either verbose or comma-separated format.
+     """
+     # 1) Preprocess
+     input_tensor = preprocess_image(input_path)
+
+     # 2) Run inference
+     input_name = session.get_inputs()[0].name
+     outputs = session.run(None, {input_name: input_tensor})
+     initial_logits, refined_logits = outputs  # shape: (1, 70527) each
+
+     # 3) Convert logits to probabilities
+     refined_probs = 1 / (1 + np.exp(-refined_logits))  # shape: (1, 70527)
+
+     # 4) Load metadata & retrieve threshold info
+     with open(metadata_path, "r", encoding="utf-8") as f:
+         metadata = json.load(f)
+
+     idx_to_tag = metadata["idx_to_tag"]  # e.g. { "0": "brown_hair", "1": "blue_eyes", ... }
+     tag_to_category = metadata.get("tag_to_category", {})
+     # Load thresholds from thresholds.json using the specified mode
+     thresholds_by_category, fallback_threshold = load_thresholds(threshold_json_path, mode)
+
+     # 5) Collect predictions by category
+     results_by_category = {}
+     num_tags = refined_probs.shape[1]
+
+     for i in range(num_tags):
+         prob = float(refined_probs[0, i])
+         tag_name = idx_to_tag[str(i)]  # str(i) because metadata uses string keys
+         category = tag_to_category.get(tag_name, "general")
+
+         # Determine the threshold to use for this category
+         cat_threshold = thresholds_by_category.get(category, fallback_threshold)
+
+         if prob >= cat_threshold:
+             if category not in results_by_category:
+                 results_by_category[category] = []
+             results_by_category[category].append((tag_name, prob))
+
+     # 6) Depending on output_format, produce different return strings
+     if output_format == "as_prompt":
+         # Flatten all predicted tags across categories
+         all_predicted_tags = []
+         for cat, tags_list in results_by_category.items():
+             # We only need the tag name in as_prompt format
+             for tname, tprob in tags_list:
+                 # convert underscores to spaces
+                 tag_name_spaces = tname.replace("_", " ")
+                 all_predicted_tags.append(tag_name_spaces)
+
+         # Create a comma-separated string
+         prompt_string = ", ".join(all_predicted_tags)
+         return prompt_string
+
+     else:  # "verbose"
+         # We'll build a multiline string describing the predictions
+         lines = []
+         lines.append("Predicted Tags by Category:\n")
+         for cat, tags_list in results_by_category.items():
+             lines.append(f"Category: {cat} | Predicted {len(tags_list)} tags")
+             # Sort descending by probability
+             for tname, tprob in sorted(tags_list, key=lambda x: x[1], reverse=True):
+                 lines.append(f"  Tag: {tname:30s} Prob: {tprob:.4f}")
+             lines.append("")  # blank line after each category
+         # Join lines with newlines
+         verbose_output = "\n".join(lines)
+         return verbose_output
+
+ if __name__ == "__main__":
+     result = inference("", output_format="as_prompt")
      print(result)
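inference() handles one image per call; tagging a directory is just a loop over it. A minimal sketch, assuming the functions above are in scope (the images/ folder name and the tag_folder helper are hypothetical):

import glob
import os

def tag_folder(folder, mode="balanced"):
    # Collect an as_prompt string for every .jpg/.png in the folder (hypothetical helper)
    results = {}
    paths = sorted(glob.glob(os.path.join(folder, "*.jpg")) + glob.glob(os.path.join(folder, "*.png")))
    for path in paths:
        results[path] = inference(path, output_format="as_prompt", mode=mode)
    return results

for path, prompt in tag_folder("images").items():
    print(path, "->", prompt)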
model_code.py ADDED
@@ -0,0 +1,956 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torchvision.models import efficientnet_v2_l, EfficientNet_V2_L_Weights
+ from flash_attn import flash_attn_func  # used by FlashAttention below; import added here since the file calls it (requires the flash-attn package)
+ from PIL import Image
+ from typing import Optional
+ import torchvision.transforms as transforms
+ import os
+ import json
+
+ class InitialOnlyImageTagger(nn.Module):
+     """
+     A lightweight version of ImageTagger that only includes the backbone and initial classifier.
+     This model uses significantly less VRAM than the full model.
+     """
+     def __init__(self, total_tags, dataset, model_name='efficientnet_v2_l',
+                  dropout=0.1, pretrained=True):
+         super().__init__()
+         # Debug and stats flags
+         self._flags = {
+             'debug': False,
+             'model_stats': False
+         }
+
+         # Core model config
+         self.dataset = dataset
+         self.embedding_dim = 1280  # Fixed to EfficientNetV2-L output dimension
+
+         # Initialize backbone
+         if model_name == 'efficientnet_v2_l':
+             weights = EfficientNet_V2_L_Weights.DEFAULT if pretrained else None
+             self.backbone = efficientnet_v2_l(weights=weights)
+             self.backbone.classifier = nn.Identity()
+
+         # Spatial pooling only - no projection
+         self.spatial_pool = nn.AdaptiveAvgPool2d((1, 1))
+
+         # Initial tag prediction with bottleneck
+         self.initial_classifier = nn.Sequential(
+             nn.Linear(self.embedding_dim, self.embedding_dim * 2),
+             nn.LayerNorm(self.embedding_dim * 2),
+             nn.GELU(),
+             nn.Dropout(dropout),
+             nn.Linear(self.embedding_dim * 2, self.embedding_dim),
+             nn.LayerNorm(self.embedding_dim),
+             nn.GELU(),
+             nn.Linear(self.embedding_dim, total_tags)
+         )
+
+         # Temperature scaling
+         self.temperature = nn.Parameter(torch.ones(1) * 1.5)
+
+     @property
+     def debug(self):
+         return self._flags['debug']
+
+     @debug.setter
+     def debug(self, value):
+         self._flags['debug'] = value
+
+     @property
+     def model_stats(self):
+         return self._flags['model_stats']
+
+     @model_stats.setter
+     def model_stats(self, value):
+         self._flags['model_stats'] = value
+
+     def preprocess_image(self, image_path, image_size=512):
+         """Process an image for inference using same preprocessing as training"""
+         if not os.path.exists(image_path):
+             raise ValueError(f"Image not found at path: {image_path}")
+
+         # Initialize the same transform used during training
+         transform = transforms.Compose([
+             transforms.ToTensor(),
+         ])
+
+         try:
+             with Image.open(image_path) as img:
+                 # Convert RGBA or Palette images to RGB
+                 if img.mode in ('RGBA', 'P'):
+                     img = img.convert('RGB')
+
+                 # Get original dimensions
+                 width, height = img.size
+                 aspect_ratio = width / height
+
+                 # Calculate new dimensions to maintain aspect ratio
+                 if aspect_ratio > 1:
+                     new_width = image_size
+                     new_height = int(new_width / aspect_ratio)
+                 else:
+                     new_height = image_size
+                     new_width = int(new_height * aspect_ratio)
+
+                 # Resize with LANCZOS filter
+                 img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
+
+                 # Create new image with padding
+                 new_image = Image.new('RGB', (image_size, image_size), (0, 0, 0))
+                 paste_x = (image_size - new_width) // 2
+                 paste_y = (image_size - new_height) // 2
+                 new_image.paste(img, (paste_x, paste_y))
+
+                 # Apply transforms (without normalization)
+                 img_tensor = transform(new_image)
+                 return img_tensor
+         except Exception as e:
+             raise Exception(f"Error processing {image_path}: {str(e)}")
+
+     def forward(self, x):
+         """Forward pass with only the initial predictions"""
+         # Image Feature Extraction
+         features = self.backbone.features(x)
+         features = self.spatial_pool(features).squeeze(-1).squeeze(-1)
+
+         # Initial Tag Predictions
+         initial_logits = self.initial_classifier(features)
+         initial_preds = torch.clamp(initial_logits / self.temperature, min=-15.0, max=15.0)
+
+         # For API compatibility with the full model, return the same predictions twice
+         return initial_preds, initial_preds
+
+     def predict(self, image_path, threshold=0.325, category_thresholds=None):
+         """
+         Run inference on an image with support for category-specific thresholds.
+         """
+         # Preprocess the image
+         img_tensor = self.preprocess_image(image_path).unsqueeze(0)
+
+         # Move to the same device as model and convert to half precision
+         device = next(self.parameters()).device
+         dtype = next(self.parameters()).dtype  # Match model's precision
+         img_tensor = img_tensor.to(device, dtype=dtype)
+
+         # Run inference
+         with torch.no_grad():
+             initial_preds, _ = self.forward(img_tensor)
+
+         # Apply sigmoid to get probabilities
+         initial_probs = torch.sigmoid(initial_preds)
+
+         # Apply thresholds
+         if category_thresholds:
+             # Create binary prediction tensors
+             initial_binary = torch.zeros_like(initial_probs)
+
+             # Apply thresholds by category
+             for category, cat_threshold in category_thresholds.items():
+                 # Create a mask for tags in this category
+                 category_mask = torch.zeros_like(initial_probs, dtype=torch.bool)
+
+                 # Find indices for this category
+                 for tag_idx in range(initial_probs.size(-1)):
+                     try:
+                         _, tag_category = self.dataset.get_tag_info(tag_idx)
+                         if tag_category == category:
+                             category_mask[:, tag_idx] = True
+                     except:
+                         continue
+
+                 # Apply threshold only to tags in this category
+                 cat_threshold_tensor = torch.tensor(cat_threshold, device=device, dtype=dtype)
+                 initial_binary[category_mask] = (initial_probs[category_mask] >= cat_threshold_tensor).to(dtype)
+
+             predictions = initial_binary
+         else:
+             # Use the same threshold for all tags
+             threshold_tensor = torch.tensor(threshold, device=device, dtype=dtype)
+             predictions = (initial_probs >= threshold_tensor).to(dtype)
+
+         # Return the same probabilities for both initial and refined for API compatibility
+         return {
+             'initial_probabilities': initial_probs,
+             'refined_probabilities': initial_probs,  # Same as initial for compatibility
+             'predictions': predictions
+         }
+
+     def get_tags_from_predictions(self, predictions, include_probabilities=True):
+         """
+         Convert model predictions to human-readable tags grouped by category.
+         """
+         # Get non-zero predictions
+         if predictions.dim() > 1:
+             predictions = predictions[0]  # Remove batch dimension
+
+         # Get indices of positive predictions
+         indices = torch.where(predictions > 0)[0].cpu().tolist()
+
+         # Group by category
+         result = {}
+         for idx in indices:
+             tag_name, category = self.dataset.get_tag_info(idx)
+
+             if category not in result:
+                 result[category] = []
+
+             if include_probabilities:
+                 prob = predictions[idx].item()
+                 result[category].append((tag_name, prob))
+             else:
+                 result[category].append(tag_name)
+
+         # Sort tags by probability within each category
+         if include_probabilities:
+             for category in result:
+                 result[category] = sorted(result[category], key=lambda x: x[1], reverse=True)
+
+         return result
+
+ class FlashAttention(nn.Module):
+     def __init__(self, dim, num_heads=8, dropout=0.1, batch_first=True):
+         super().__init__()
+         self.dim = dim
+         self.num_heads = num_heads
+         self.dropout = dropout
+         self.batch_first = batch_first
+         self.head_dim = dim // num_heads
+         assert self.head_dim * num_heads == dim, "dim must be divisible by num_heads"
+
+         self.q_proj = nn.Linear(dim, dim, bias=False)
+         self.k_proj = nn.Linear(dim, dim, bias=False)
+         self.v_proj = nn.Linear(dim, dim, bias=False)
+         self.out_proj = nn.Linear(dim, dim, bias=False)
+
+         for proj in [self.q_proj, self.k_proj, self.v_proj, self.out_proj]:
+             nn.init.xavier_uniform_(proj.weight, gain=0.1)
+
+         self.scale = self.head_dim ** -0.5
+         self.debug = False
+
+     def _debug_print(self, name, tensor):
+         """Debug helper"""
+         if self.debug:
+             print(f"\n{name}:")
+             print(f"Shape: {tensor.shape}")
+             print(f"Device: {tensor.device}")
+             print(f"Dtype: {tensor.dtype}")
+             if tensor.is_floating_point():
+                 with torch.no_grad():
+                     print(f"Range: [{tensor.min().item():.3f}, {tensor.max().item():.3f}]")
+                     print(f"Mean: {tensor.mean().item():.3f}")
+                     print(f"Std: {tensor.std().item():.3f}")
+
+     def _reshape_for_flash(self, x: torch.Tensor) -> torch.Tensor:
+         """Reshape input tensor for flash attention format"""
+         batch_size, seq_len, _ = x.size()
+         x = x.view(batch_size, seq_len, self.num_heads, self.head_dim)
+         x = x.transpose(1, 2)  # [B, H, S, D]
+         return x.contiguous()
+
+     def forward(self, query: torch.Tensor, key: Optional[torch.Tensor] = None,
+                 value: Optional[torch.Tensor] = None,
+                 mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+         """Forward pass with flash attention"""
+         if self.debug:
+             print("\nFlashAttention Forward Pass")
+
+         batch_size = query.size(0)
+
+         # Use query as key/value if not provided
+         key = query if key is None else key
+         value = query if value is None else value
+
+         # Project inputs
+         q = self.q_proj(query)
+         k = self.k_proj(key)
+         v = self.v_proj(value)
+
+         if self.debug:
+             self._debug_print("Query before reshape", q)
+
+         # Reshape for attention [B, H, S, D]
+         q = self._reshape_for_flash(q)
+         k = self._reshape_for_flash(k)
+         v = self._reshape_for_flash(v)
+
+         if self.debug:
+             self._debug_print("Query after reshape", q)
+
+         # Handle masking
+         if mask is not None:
+             # First convert mask to proper shape based on input dimensionality
+             if mask.dim() == 2:  # [B, S]
+                 mask = mask.view(batch_size, 1, -1, 1)
+             elif mask.dim() == 3:  # [B, S, S]
+                 mask = mask.view(batch_size, 1, mask.size(1), mask.size(2))
+             elif mask.dim() == 5:  # [B, 1, S, S, S]
+                 mask = mask.squeeze(1).view(batch_size, 1, mask.size(2), mask.size(3))
+
+             # Ensure mask is float16 if we're using float16
+             mask = mask.to(q.dtype)
+
+             if self.debug:
+                 self._debug_print("Prepared mask", mask)
+                 print(f"q shape: {q.shape}, mask shape: {mask.shape}")
+
+             # Create attention mask that covers the full sequence length
+             seq_len = q.size(2)
+             if mask.size(-1) != seq_len:
+                 # Pad or trim mask to match sequence length
+                 new_mask = torch.zeros(batch_size, 1, seq_len, seq_len,
+                                        device=mask.device, dtype=mask.dtype)
+                 min_len = min(seq_len, mask.size(-1))
+                 new_mask[..., :min_len, :min_len] = mask[..., :min_len, :min_len]
+                 mask = new_mask
+
+             # Create key padding mask
+             key_padding_mask = mask.squeeze(1).sum(-1) > 0
+             key_padding_mask = key_padding_mask.view(batch_size, 1, -1, 1)
+
+             # Apply the key padding mask
+             k = k * key_padding_mask
+             v = v * key_padding_mask
+
+         if self.debug:
+             self._debug_print("Query before attention", q)
+             self._debug_print("Key before attention", k)
+             self._debug_print("Value before attention", v)
+
+         # Run flash attention
+         dropout_p = self.dropout if self.training else 0.0
+         output = flash_attn_func(
+             q, k, v,
+             dropout_p=dropout_p,
+             softmax_scale=self.scale,
+             causal=False
+         )
+
+         if self.debug:
+             self._debug_print("Output after attention", output)
+
+         # Reshape output [B, H, S, D] -> [B, S, H, D] -> [B, S, D]
+         output = output.transpose(1, 2).contiguous()
+         output = output.view(batch_size, -1, self.dim)
+
+         # Final projection
+         output = self.out_proj(output)
+
+         if self.debug:
+             self._debug_print("Final output", output)
+
+         return output
+
+ class OptimizedTagEmbedding(nn.Module):
+     def __init__(self, num_tags, embedding_dim, num_heads=8, dropout=0.1):
+         super().__init__()
+         # Single shared embedding for all tags
+         self.embedding = nn.Embedding(num_tags, embedding_dim)
+         self.attention = FlashAttention(embedding_dim, num_heads, dropout)
+         self.norm1 = nn.LayerNorm(embedding_dim)
+         self.norm2 = nn.LayerNorm(embedding_dim)
+
+         # Single importance weighting for all tags
+         self.tag_importance = nn.Parameter(torch.ones(num_tags) * 0.1)
+
+         # Projection layers for unified tag context
+         self.context_proj = nn.Sequential(
+             nn.Linear(embedding_dim, embedding_dim * 2),
+             nn.LayerNorm(embedding_dim * 2),
+             nn.GELU(),
+             nn.Dropout(dropout),
+             nn.Linear(embedding_dim * 2, embedding_dim),
+             nn.LayerNorm(embedding_dim)
+         )
+
+         self.importance_scale = nn.Parameter(torch.tensor(0.1))
+         self.context_scale = nn.Parameter(torch.tensor(1.0))
+         self.debug = False
+
+     def _debug_print(self, name, tensor, extra_info=None):
+         """Memory efficient debug printing with type handling"""
+         if self.debug:
+             print(f"\n{name}:")
+             print(f"- Shape: {tensor.shape}")
+             if isinstance(tensor, torch.Tensor):
+                 with torch.no_grad():
+                     print(f"- Device: {tensor.device}")
+                     print(f"- Dtype: {tensor.dtype}")
+
+                     # Convert to float32 for statistics if needed
+                     if tensor.dtype not in [torch.float16, torch.float32, torch.float64]:
+                         calc_tensor = tensor.float()
+                     else:
+                         calc_tensor = tensor
+
+                     try:
+                         min_val = calc_tensor.min().item()
+                         max_val = calc_tensor.max().item()
+                         mean_val = calc_tensor.mean().item()
+                         std_val = calc_tensor.std().item()
+                         norm_val = torch.norm(calc_tensor).item()
+
+                         print(f"- Value range: [{min_val:.3f}, {max_val:.3f}]")
+                         print(f"- Mean: {mean_val:.3f}")
+                         print(f"- Std: {std_val:.3f}")
+                         print(f"- L2 Norm: {norm_val:.3f}")
+
+                         if extra_info:
+                             print(f"- Additional info: {extra_info}")
+                     except Exception as e:
+                         print(f"- Could not compute statistics: {str(e)}")
+
+     def _debug_tensor(self, name, tensor):
+         """Debug helper with dtype-specific analysis"""
+         if self.debug and isinstance(tensor, torch.Tensor):
+             print(f"\n{name}:")
+             print(f"- Shape: {tensor.shape}")
+             print(f"- Device: {tensor.device}")
+             print(f"- Dtype: {tensor.dtype}")
+             with torch.no_grad():
+                 has_nan = torch.isnan(tensor).any().item() if tensor.is_floating_point() else False
+                 has_inf = torch.isinf(tensor).any().item() if tensor.is_floating_point() else False
+                 print(f"- Contains NaN: {has_nan}")
+                 print(f"- Contains Inf: {has_inf}")
+
+                 # Different stats for different dtypes
+                 if tensor.is_floating_point():
+                     print(f"- Range: [{tensor.min().item():.3f}, {tensor.max().item():.3f}]")
+                     print(f"- Mean: {tensor.mean().item():.3f}")
+                     print(f"- Std: {tensor.std().item():.3f}")
+                 else:
+                     # For integer tensors
+                     print(f"- Range: [{tensor.min().item()}, {tensor.max().item()}]")
+                     print(f"- Unique values: {tensor.unique().numel()}")
+
+     def _process_category(self, indices, masks):
+         """Process a single category of tags"""
+         # Get embeddings for this category
+         embeddings = self.embedding(indices)
+
+         if self.debug:
+             self._debug_tensor("Category embeddings", embeddings)
+
+         # Apply importance weights
+         importance = torch.sigmoid(self.tag_importance) * self.importance_scale
+         importance = torch.clamp(importance, min=0.01, max=10.0)
+         importance_weights = importance[indices].unsqueeze(-1)
+
+         # Apply and normalize
+         embeddings = embeddings * importance_weights
+         embeddings = self.norm1(embeddings)
+
+         # Apply attention if we have more than one tag
+         if embeddings.size(1) > 1:
+             if masks is not None:
+                 attention_mask = torch.einsum('bi,bj->bij', masks, masks)
+                 attended = self.attention(embeddings, mask=attention_mask)
+             else:
+                 attended = self.attention(embeddings)
+             embeddings = self.norm2(attended)
+
+         # Pool embeddings with masking
+         if masks is not None:
+             masked_embeddings = embeddings * masks.unsqueeze(-1)
+             pooled = masked_embeddings.sum(dim=1) / masks.sum(dim=1, keepdim=True).clamp(min=1.0)
+         else:
+             pooled = embeddings.mean(dim=1)
+
+         return pooled, embeddings
+
+     def forward(self, tag_indices_dict, tag_masks_dict=None):
+         """
+         Process all tags in a unified embedding space
+         Args:
+             tag_indices_dict: dict of {category: tensor of indices}
+             tag_masks_dict: dict of {category: tensor of masks}
+         """
+         if self.debug:
+             print("\nOptimizedTagEmbedding Forward Pass")
+
+         # Concatenate all indices and masks
+         all_indices = []
+         all_masks = []
+         batch_size = None
+
+         for category, indices in tag_indices_dict.items():
+             if batch_size is None:
+                 batch_size = indices.size(0)
+             all_indices.append(indices)
+             if tag_masks_dict:
+                 all_masks.append(tag_masks_dict[category])
+
+         # Stack along sequence dimension
+         combined_indices = torch.cat(all_indices, dim=1)  # [B, total_seq_len]
+         if tag_masks_dict:
+             combined_masks = torch.cat(all_masks, dim=1)  # [B, total_seq_len]
+
+         if self.debug:
+             self._debug_tensor("Combined indices", combined_indices)
+             if tag_masks_dict:
+                 self._debug_tensor("Combined masks", combined_masks)
+
+         # Get embeddings for all tags using shared embedding
+         embeddings = self.embedding(combined_indices)  # [B, total_seq_len, D]
+
+         if self.debug:
+             self._debug_tensor("Base embeddings", embeddings)
+
+         # Apply unified importance weighting
+         importance = torch.sigmoid(self.tag_importance) * self.importance_scale
+         importance = torch.clamp(importance, min=0.01, max=10.0)
+         importance_weights = importance[combined_indices].unsqueeze(-1)
+
+         # Apply and normalize importance weights
+         embeddings = embeddings * importance_weights
+         embeddings = self.norm1(embeddings)
+
+         if self.debug:
+             self._debug_tensor("Weighted embeddings", embeddings)
+
+         # Apply attention across all tags together
+         if tag_masks_dict:
+             attention_mask = torch.einsum('bi,bj->bij', combined_masks, combined_masks)
+             attended = self.attention(embeddings, mask=attention_mask)
+         else:
+             attended = self.attention(embeddings)
+
+         attended = self.norm2(attended)
+
+         if self.debug:
+             self._debug_tensor("Attended embeddings", attended)
+
+         # Global pooling with masking
+         if tag_masks_dict:
+             masked_embeddings = attended * combined_masks.unsqueeze(-1)
+             tag_context = masked_embeddings.sum(dim=1) / combined_masks.sum(dim=1, keepdim=True).clamp(min=1.0)
+         else:
+             tag_context = attended.mean(dim=1)
+
+         # Project and scale context
+         tag_context = self.context_proj(tag_context)
+         context_scale = torch.clamp(self.context_scale, min=0.1, max=10.0)
+         tag_context = tag_context * context_scale
+
+         if self.debug:
+             self._debug_tensor("Final tag context", tag_context)
+
+         return tag_context, attended
+
+ class TagDataset:
+     """Lightweight dataset wrapper for inference only"""
+     def __init__(self, total_tags, idx_to_tag, tag_to_category):
+         self.total_tags = total_tags
+         # Normalize mapping keys to int: metadata.json stores indices as strings,
+         # but get_tag_info() is called with integer indices.
+         self.idx_to_tag = {int(k): v for k, v in idx_to_tag.items()}
+         self.tag_to_category = tag_to_category
+
+     def get_tag_info(self, idx):
+         """Get tag name and category for a given index"""
+         tag_name = self.idx_to_tag.get(idx, f"unknown-{idx}")
+         category = self.tag_to_category.get(tag_name, "general")
+         return tag_name, category
+
+ class ImageTagger(nn.Module):
+     def __init__(self, total_tags, dataset, model_name='efficientnet_v2_l',
+                  num_heads=16, dropout=0.1, pretrained=True,
+                  tag_context_size=256):
+         super().__init__()
+         # Debug and stats flags
+         self._flags = {
+             'debug': False,
+             'model_stats': False
+         }
+
+         # Core model config
+         self.dataset = dataset
+         self.tag_context_size = tag_context_size
+         self.embedding_dim = 1280  # Fixed to EfficientNetV2-L output dimension
+
+         # Initialize backbone
+         if model_name == 'efficientnet_v2_l':
+             weights = EfficientNet_V2_L_Weights.DEFAULT if pretrained else None
+             self.backbone = efficientnet_v2_l(weights=weights)
+             self.backbone.classifier = nn.Identity()
+
+         # Spatial pooling only - no projection
+         self.spatial_pool = nn.AdaptiveAvgPool2d((1, 1))
+
+         # Initial tag prediction with bottleneck
+         self.initial_classifier = nn.Sequential(
+             nn.Linear(self.embedding_dim, self.embedding_dim * 2),
+             nn.LayerNorm(self.embedding_dim * 2),
+             nn.GELU(),
+             nn.Dropout(dropout),
+             nn.Linear(self.embedding_dim * 2, self.embedding_dim),
+             nn.LayerNorm(self.embedding_dim),
+             nn.GELU(),
+             nn.Linear(self.embedding_dim, total_tags)
+         )
+
+         # Tag embeddings at full dimension
+         self.tag_embedding = nn.Embedding(total_tags, self.embedding_dim)
+         self.tag_attention = FlashAttention(self.embedding_dim, num_heads, dropout)
+         self.tag_norm = nn.LayerNorm(self.embedding_dim)
+
+         # Improved cross attention projection
+         self.cross_proj = nn.Sequential(
+             nn.Linear(self.embedding_dim, self.embedding_dim * 2),
+             nn.LayerNorm(self.embedding_dim * 2),
+             nn.GELU(),
+             nn.Dropout(dropout),
+             nn.Linear(self.embedding_dim * 2, self.embedding_dim)
+         )
+
+         # Cross attention at full dimension
+         self.cross_attention = FlashAttention(self.embedding_dim, num_heads, dropout)
+         self.cross_norm = nn.LayerNorm(self.embedding_dim)
+
+         # Refined classifier with improved bottleneck
+         self.refined_classifier = nn.Sequential(
+             nn.Linear(self.embedding_dim * 2, self.embedding_dim * 2),  # Doubled input size for residual
+             nn.LayerNorm(self.embedding_dim * 2),
+             nn.GELU(),
+             nn.Dropout(dropout),
+             nn.Linear(self.embedding_dim * 2, self.embedding_dim),
+             nn.LayerNorm(self.embedding_dim),
+             nn.GELU(),
+             nn.Linear(self.embedding_dim, total_tags)
+         )
+
+         # Temperature scaling
+         self.temperature = nn.Parameter(torch.ones(1) * 1.5)
+
+     def _get_selected_tags(self, logits):
+         """Select top-K tags based on prediction confidence"""
+         # Apply sigmoid to get probabilities
+         probs = torch.sigmoid(logits)
+
+         # Get top-K predictions for each image in batch
+         batch_size = logits.size(0)
+         topk_values, topk_indices = torch.topk(
+             probs, k=self.tag_context_size, dim=1, largest=True, sorted=True
+         )
+
+         return topk_indices, topk_values
+
+     @property
+     def debug(self):
+         return self._flags['debug']
+
+     @debug.setter
+     def debug(self, value):
+         self._flags['debug'] = value
+
+     @property
+     def model_stats(self):
+         return self._flags['model_stats']
+
+     @model_stats.setter
+     def model_stats(self, value):
+         self._flags['model_stats'] = value
+
+     def preprocess_image(self, image_path, image_size=512):
+         """Process an image for inference using same preprocessing as training"""
+         if not os.path.exists(image_path):
+             raise ValueError(f"Image not found at path: {image_path}")
+
+         # Initialize the same transform used during training
+         transform = transforms.Compose([
+             transforms.ToTensor(),
+         ])
+
+         try:
+             with Image.open(image_path) as img:
+                 # Convert RGBA or Palette images to RGB
+                 if img.mode in ('RGBA', 'P'):
+                     img = img.convert('RGB')
+
+                 # Get original dimensions
+                 width, height = img.size
+                 aspect_ratio = width / height
+
+                 # Calculate new dimensions to maintain aspect ratio
+                 if aspect_ratio > 1:
+                     new_width = image_size
+                     new_height = int(new_width / aspect_ratio)
+                 else:
+                     new_height = image_size
+                     new_width = int(new_height * aspect_ratio)
+
+                 # Resize with LANCZOS filter
+                 img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
+
+                 # Create new image with padding
+                 new_image = Image.new('RGB', (image_size, image_size), (0, 0, 0))
+                 paste_x = (image_size - new_width) // 2
+                 paste_y = (image_size - new_height) // 2
+                 new_image.paste(img, (paste_x, paste_y))
+
+                 # Apply transforms (without normalization)
+                 img_tensor = transform(new_image)
+                 return img_tensor
+         except Exception as e:
+             raise Exception(f"Error processing {image_path}: {str(e)}")
+
+     def forward(self, x):
+         """Forward pass with simplified feature handling"""
+         # Initialize tracking dicts
+         model_stats = {} if self.model_stats else {}
+         debug_tensors = {} if self.debug else None
+
+         # 1. Image Feature Extraction
+         features = self.backbone.features(x)
+         features = self.spatial_pool(features).squeeze(-1).squeeze(-1)
+
+         # 2. Initial Tag Predictions
+         initial_logits = self.initial_classifier(features)
+         initial_preds = torch.clamp(initial_logits / self.temperature, min=-15.0, max=15.0)
+
+         # 3. Tag Selection & Embedding (simplified)
+         pred_tag_indices, _ = self._get_selected_tags(initial_preds)
+         tag_embeddings = self.tag_embedding(pred_tag_indices)
+
+         # 4. Self-Attention on Tags
+         attended_tags = self.tag_attention(tag_embeddings)
+         attended_tags = self.tag_norm(attended_tags)
+
+         # 5. Cross-Attention between Features and Tags
+         features_proj = self.cross_proj(features)
+         features_expanded = features_proj.unsqueeze(1).expand(-1, self.tag_context_size, -1)
+
+         cross_attended = self.cross_attention(features_expanded, attended_tags)
+         cross_attended = self.cross_norm(cross_attended)
+
+         # 6. Feature Fusion with Residual Connection
+         fused_features = cross_attended.mean(dim=1)  # Average across tag dimension
+         # Concatenate original and attended features
+         combined_features = torch.cat([features, fused_features], dim=-1)
+
+         # 7. Refined Predictions
+         refined_logits = self.refined_classifier(combined_features)
+         refined_preds = torch.clamp(refined_logits / self.temperature, min=-15.0, max=15.0)
+
+         # Return both prediction sets
+         return initial_preds, refined_preds
+
+     def predict(self, image_path, threshold=0.325, category_thresholds=None):
+         """
+         Run inference on an image with support for category-specific thresholds.
+         """
+         # Preprocess the image
+         img_tensor = self.preprocess_image(image_path).unsqueeze(0)
+
+         # Move to the same device as model and convert to half precision
+         device = next(self.parameters()).device
+         dtype = next(self.parameters()).dtype  # Match model's precision
+         img_tensor = img_tensor.to(device, dtype=dtype)
+
+         # Run inference
+         with torch.no_grad():
+             initial_preds, refined_preds = self.forward(img_tensor)
+
+         # Apply sigmoid to get probabilities
+         initial_probs = torch.sigmoid(initial_preds)
+         refined_probs = torch.sigmoid(refined_preds)
+
+         # Apply thresholds
+         if category_thresholds:
+             # Create binary prediction tensors
+             refined_binary = torch.zeros_like(refined_probs)
+
+             # Apply thresholds by category
+             for category, cat_threshold in category_thresholds.items():
+                 # Create a mask for tags in this category
+                 category_mask = torch.zeros_like(refined_probs, dtype=torch.bool)
+
+                 # Find indices for this category
+                 for tag_idx in range(refined_probs.size(-1)):
+                     try:
+                         _, tag_category = self.dataset.get_tag_info(tag_idx)
+                         if tag_category == category:
+                             category_mask[:, tag_idx] = True
+                     except:
+                         continue
+
+                 # Apply threshold only to tags in this category - ensure dtype consistency
+                 cat_threshold_tensor = torch.tensor(cat_threshold, device=device, dtype=dtype)
+                 refined_binary[category_mask] = (refined_probs[category_mask] >= cat_threshold_tensor).to(dtype)
+
+             predictions = refined_binary
+         else:
+             # Use the same threshold for all tags
+             threshold_tensor = torch.tensor(threshold, device=device, dtype=dtype)
+             predictions = (refined_probs >= threshold_tensor).to(dtype)
+
+         # Return both probabilities and thresholded predictions
+         return {
+             'initial_probabilities': initial_probs,
+             'refined_probabilities': refined_probs,
+             'predictions': predictions
+         }
+
+     def get_tags_from_predictions(self, predictions, include_probabilities=True):
+         """
+         Convert model predictions to human-readable tags grouped by category.
+         """
+         # Get non-zero predictions
+         if predictions.dim() > 1:
+             predictions = predictions[0]  # Remove batch dimension
+
+         # Get indices of positive predictions
+         indices = torch.where(predictions > 0)[0].cpu().tolist()
+
+         # Group by category
+         result = {}
+         for idx in indices:
+             tag_name, category = self.dataset.get_tag_info(idx)
+
+             if category not in result:
+                 result[category] = []
+
+             if include_probabilities:
+                 prob = predictions[idx].item()
+                 result[category].append((tag_name, prob))
+             else:
+                 result[category].append(tag_name)
+
+         # Sort tags by probability within each category
+         if include_probabilities:
+             for category in result:
+                 result[category] = sorted(result[category], key=lambda x: x[1], reverse=True)
+
+         return result
+
+ def load_model(model_dir, device='cuda'):
+     """Load model with better error handling and warnings"""
+     print(f"Loading model from {model_dir}")
+
+     try:
+         # Load metadata
+         metadata_path = os.path.join(model_dir, "metadata.json")
+         if not os.path.exists(metadata_path):
+             raise FileNotFoundError(f"Metadata file not found at {metadata_path}")
+
+         with open(metadata_path, 'r') as f:
+             metadata = json.load(f)
+
+         # Load model info
+         model_info_path = os.path.join(model_dir, "model_info_initial_only.json")
+         if os.path.exists(model_info_path):
+             with open(model_info_path, 'r') as f:
+                 model_info = json.load(f)
+         else:
+             print("WARNING: Model info file not found, using default settings")
+             model_info = {
+                 "tag_context_size": 256,
+                 "num_heads": 16,
+                 "precision": "float16"
+             }
+
+         # Create dataset wrapper
+         dataset = TagDataset(
+             total_tags=metadata['total_tags'],
+             idx_to_tag=metadata['idx_to_tag'],
+             tag_to_category=metadata['tag_to_category']
+         )
+
+         # Initialize model with exact settings from model_info
+         model = ImageTagger(
+             total_tags=metadata['total_tags'],
+             dataset=dataset,
+             num_heads=model_info.get('num_heads', 16),
+             tag_context_size=model_info.get('tag_context_size', 256),
+             pretrained=False
+         )
+
+         # Load weights
+         state_dict_path = os.path.join(model_dir, "model.pt")
+         if not os.path.exists(state_dict_path):
+             raise FileNotFoundError(f"Model state dict not found at {state_dict_path}")
+
+         state_dict = torch.load(state_dict_path, map_location=device)
+
+         # First try strict loading
+         try:
+             model.load_state_dict(state_dict, strict=True)
+             print("✓ Model state dict loaded with strict=True successfully")
+         except Exception as e:
+             print(f"! Strict loading failed: {str(e)}")
+             print("Attempting non-strict loading...")
+
+             # Try non-strict loading
+             missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
+
+             print(f"Non-strict loading completed with:")
+             print(f"- {len(missing_keys)} missing keys")
+             print(f"- {len(unexpected_keys)} unexpected keys")
+
+             if len(missing_keys) > 0:
+                 print(f"Sample missing keys: {missing_keys[:5]}")
+             if len(unexpected_keys) > 0:
+                 print(f"Sample unexpected keys: {unexpected_keys[:5]}")
+
+         # Move model to device
+         model = model.to(device)
+
+         # Set to half precision if needed
+         if model_info.get('precision') == 'float16':
+             model = model.half()
+             print("✓ Model converted to half precision")
+
+         # Set to eval mode
+         model.eval()
+         print("✓ Model set to evaluation mode")
+
+         # Verify parameter dtype
+         param_dtype = next(model.parameters()).dtype
+         print(f"✓ Model loaded with precision: {param_dtype}")
+
+         return model, dataset
+
+     except Exception as e:
+         print(f"ERROR loading model: {str(e)}")
+         import traceback
+         traceback.print_exc()
+         raise
+
920
+ if __name__ == "__main__":
921
+ import sys
922
+
923
+ # Get model directory from command line or use default
924
+ model_dir = sys.argv[1] if len(sys.argv) > 1 else "./exported_model"
925
+
926
+ # Load model
927
+ model, dataset, thresholds = load_model(model_dir)
928
+
929
+ # Display info
930
+ print(f"\nModel information:")
931
+ print(f" Total tags: {dataset.total_tags}")
932
+ print(f" Device: {next(model.parameters()).device}")
933
+ print(f" Precision: {next(model.parameters()).dtype}")
934
+
935
+ # Test on an image if provided
936
+ if len(sys.argv) > 2:
937
+ image_path = sys.argv[2]
938
+ print(f"\nRunning inference on {image_path}")
939
+
940
+ # Use category thresholds if available
941
+ if thresholds and 'categories' in thresholds:
942
+ category_thresholds = {cat: opt['balanced']['threshold']
943
+ for cat, opt in thresholds['categories'].items()}
944
+ results = model.predict(image_path, category_thresholds=category_thresholds)
945
+ else:
946
+ results = model.predict(image_path)
947
+
948
+ # Get tags
949
+ tags = model.get_tags_from_predictions(results['predictions'])
950
+
951
+ # Print tags by category
952
+ print("\nPredicted tags:")
953
+ for category, category_tags in tags.items():
954
+ print(f"\n{category.capitalize()}:")
955
+ for tag, prob in category_tags:
956
+ print(f" {tag}: {prob:.3f}")
model_config.json ADDED
@@ -0,0 +1,9 @@
+ {
+     "class_name": "ImageTagger",
+     "args": {
+         "total_tags": 70527,
+         "num_heads": 16,
+         "dropout": 0.1,
+         "tag_context_size": 256
+     }
+ }
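A sketch of consuming this config (assumes the ImageTagger class from model_code.py and a `dataset` wrapper built from metadata.json, as in infer.py; confirm against model_code.py that ImageTagger accepts every key in "args", including "dropout"):

    import json
    from model_code import ImageTagger

    with open("model_config.json", "r") as f:
        cfg = json.load(f)
    assert cfg["class_name"] == "ImageTagger"
    model = ImageTagger(dataset=dataset, pretrained=False, **cfg["args"])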
model_info_initial_only.json ADDED
@@ -0,0 +1,9 @@
+ {
+     "precision": "float16",
+     "tag_context_size": 256,
+     "num_heads": 16,
+     "architecture": "ImageTagger",
+     "embedding_dim": 1280,
+     "backbone": "efficientnet_v2_l",
+     "model_type": "initial_only"
+ }
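This file records how the checkpoint was exported; a short sketch of honoring it at load time (mirrors what load_model() in infer.py does; assumes `model` was already constructed and its weights loaded):

    import json

    with open("model_info_initial_only.json") as f:
        info = json.load(f)
    if info.get("precision") == "float16":
        model = model.half()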
model_no_flash.py ADDED
@@ -0,0 +1,195 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torchvision.models import efficientnet_v2_l, EfficientNet_V2_L_Weights
+
+ class MultiheadAttentionNoFlash(nn.Module):
+     """Custom multi-head attention module (replaces FlashAttention) using ONNX-friendly ops."""
+     def __init__(self, dim, num_heads=8, dropout=0.0):
+         super().__init__()
+         assert dim % num_heads == 0, "Embedding dim must be divisible by num_heads"
+         self.dim = dim
+         self.num_heads = num_heads
+         self.head_dim = dim // num_heads
+         self.scale = self.head_dim ** -0.5  # scaling factor for dot-product attention
+
+         # Separate projections for query, key, value, and output (no biases, to match FlashAttention)
+         self.q_proj = nn.Linear(dim, dim, bias=False)
+         self.k_proj = nn.Linear(dim, dim, bias=False)
+         self.v_proj = nn.Linear(dim, dim, bias=False)
+         self.out_proj = nn.Linear(dim, dim, bias=False)
+         # (Note: dropout is omitted in the attention computation for ONNX simplicity;
+         # the model should be in eval mode for export anyway.)
+
+     def forward(self, query, key=None, value=None):
+         # Allow usage as self-attention if key/value are not provided
+         if key is None:
+             key = query
+         if value is None:
+             value = key
+
+         # Linear projections
+         Q = self.q_proj(query)  # [B, S_q, dim]
+         K = self.k_proj(key)    # [B, S_k, dim]
+         V = self.v_proj(value)  # [B, S_v, dim]
+
+         # Reshape into (B, num_heads, S, head_dim) to compute attention per head
+         B, S_q, _ = Q.shape
+         _, S_k, _ = K.shape
+         Q = Q.view(B, S_q, self.num_heads, self.head_dim).transpose(1, 2)  # [B, heads, S_q, head_dim]
+         K = K.view(B, S_k, self.num_heads, self.head_dim).transpose(1, 2)  # [B, heads, S_k, head_dim]
+         V = V.view(B, S_k, self.num_heads, self.head_dim).transpose(1, 2)  # [B, heads, S_k, head_dim]
+
+         # Scaled dot-product attention: compute attention weights
+         attn_weights = torch.matmul(Q, K.transpose(2, 3))  # [B, heads, S_q, S_k]
+         attn_weights = attn_weights * self.scale
+         attn_probs = F.softmax(attn_weights, dim=-1)  # softmax over S_k (key length)
+
+         # Apply attention weights to values
+         attn_output = torch.matmul(attn_probs, V)  # [B, heads, S_q, head_dim]
+
+         # Reshape back to [B, S_q, dim]
+         attn_output = attn_output.transpose(1, 2).contiguous().view(B, S_q, self.dim)
+         # Output projection
+         output = self.out_proj(attn_output)  # [B, S_q, dim]
+         return output
+
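# Optional sanity check (a sketch, not in the uploaded file): verifies that the
# attention port above matches torch.nn.functional.scaled_dot_product_attention
# (PyTorch >= 2.0; SDPA's default scale is head_dim**-0.5, same as self.scale).
# Call _check_attention_equivalence() manually if you want to verify the port.
def _check_attention_equivalence(dim=64, num_heads=8, B=2, S=5):
    torch.manual_seed(0)
    attn = MultiheadAttentionNoFlash(dim, num_heads=num_heads).eval()
    x = torch.randn(B, S, dim)
    with torch.no_grad():
        out = attn(x)
        # Reference path: identical projections, fused SDPA kernel
        head_dim = dim // num_heads
        Q = attn.q_proj(x).view(B, S, num_heads, head_dim).transpose(1, 2)
        K = attn.k_proj(x).view(B, S, num_heads, head_dim).transpose(1, 2)
        V = attn.v_proj(x).view(B, S, num_heads, head_dim).transpose(1, 2)
        ref = F.scaled_dot_product_attention(Q, K, V)
        ref = attn.out_proj(ref.transpose(1, 2).reshape(B, S, dim))
    print("max |out - ref|:", (out - ref).abs().max().item())  # expect ~1e-6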
+ class ImageTaggerRefinedONNX(nn.Module):
+     """
+     Refined CAMIE Image Tagger model without FlashAttention.
+     - EfficientNetV2 backbone
+     - Initial classifier for preliminary tag logits
+     - Multi-head self-attention on top predicted tag embeddings
+     - Multi-head cross-attention between image feature and tag embeddings
+     - Refined classifier for final tag logits
+     """
+     def __init__(self, total_tags, tag_context_size=256, num_heads=16, dropout=0.1):
+         super().__init__()
+         self.tag_context_size = tag_context_size
+         self.embedding_dim = 1280  # EfficientNetV2-L feature dimension
+
+         # Backbone feature extractor (EfficientNetV2-L)
+         backbone = efficientnet_v2_l(weights=EfficientNet_V2_L_Weights.DEFAULT)
+         backbone.classifier = nn.Identity()  # remove final classification head
+         self.backbone = backbone
+
+         # Spatial pooling to get a single feature vector per image (1x1 avg pool)
+         self.spatial_pool = nn.AdaptiveAvgPool2d((1, 1))
+
+         # Initial classifier (two-layer MLP) to predict tags from the image feature
+         self.initial_classifier = nn.Sequential(
+             nn.Linear(self.embedding_dim, self.embedding_dim * 2),
+             nn.LayerNorm(self.embedding_dim * 2),
+             nn.GELU(),
+             nn.Dropout(dropout),
+             nn.Linear(self.embedding_dim * 2, self.embedding_dim),
+             nn.LayerNorm(self.embedding_dim),
+             nn.GELU(),
+             nn.Linear(self.embedding_dim, total_tags)  # outputs raw logits for all tags
+         )
+
+         # Embedding for tags (each tag gets an embedding vector, used for attention)
+         self.tag_embedding = nn.Embedding(total_tags, self.embedding_dim)
+
+         # Self-attention over the selected tag embeddings (replaces FlashAttention)
+         self.tag_attention = MultiheadAttentionNoFlash(self.embedding_dim, num_heads=num_heads, dropout=dropout)
+         self.tag_norm = nn.LayerNorm(self.embedding_dim)
+
+         # Projection from image feature to query vector for cross-attention
+         self.cross_proj = nn.Sequential(
+             nn.Linear(self.embedding_dim, self.embedding_dim * 2),
+             nn.LayerNorm(self.embedding_dim * 2),
+             nn.GELU(),
+             nn.Dropout(dropout),
+             nn.Linear(self.embedding_dim * 2, self.embedding_dim)
+         )
+         # Cross-attention between image feature (as query) and tag features (as key/value)
+         self.cross_attention = MultiheadAttentionNoFlash(self.embedding_dim, num_heads=num_heads, dropout=dropout)
+         self.cross_norm = nn.LayerNorm(self.embedding_dim)
+
+         # Refined classifier (takes concatenated original & attended features)
+         self.refined_classifier = nn.Sequential(
+             nn.Linear(self.embedding_dim * 2, self.embedding_dim * 2),
+             nn.LayerNorm(self.embedding_dim * 2),
+             nn.GELU(),
+             nn.Dropout(dropout),
+             nn.Linear(self.embedding_dim * 2, self.embedding_dim),
+             nn.LayerNorm(self.embedding_dim),
+             nn.GELU(),
+             nn.Linear(self.embedding_dim, total_tags)
+         )
+
+         # Temperature parameter for scaling logits (to calibrate confidence)
+         self.temperature = nn.Parameter(torch.ones(1) * 1.5)
+
+     def forward(self, images):
+         # 1. Feature extraction
+         feats = self.backbone.features(images)                    # [B, 1280, H/32, W/32] features
+         feats = self.spatial_pool(feats).squeeze(-1).squeeze(-1)  # [B, 1280] global feature vector per image
+
+         # 2. Initial tag prediction
+         initial_logits = self.initial_classifier(feats)  # [B, total_tags]
+         # Scale by temperature and clamp (to stabilize extreme values, as in the original)
+         initial_preds = torch.clamp(initial_logits / self.temperature, min=-15.0, max=15.0)
+
+         # 3. Select the top-k predicted tags for context (tag_context_size)
+         probs = torch.sigmoid(initial_preds)  # convert logits to probabilities
+         # Get indices of the top `tag_context_size` tags for each sample
+         _, topk_indices = torch.topk(probs, k=self.tag_context_size, dim=1)
+
+         # 4. Embed the selected tags
+         tag_embeds = self.tag_embedding(topk_indices)  # [B, tag_context_size, embedding_dim]
+
+         # 5. Self-attention on tag embeddings (to refine the tag representation)
+         attn_tags = self.tag_attention(tag_embeds)  # [B, tag_context_size, embedding_dim]
+         attn_tags = self.tag_norm(attn_tags)        # layer norm
+
+         # 6. Cross-attention between image feature and attended tags
+         feat_q = self.cross_proj(feats)  # [B, embedding_dim]
+         # Repeat each image feature vector tag_context_size times to form a sequence
+         feat_q = feat_q.unsqueeze(1).expand(-1, self.tag_context_size, -1)  # [B, tag_context_size, embedding_dim]
+         # Use image features as queries, tag embeddings as keys and values
+         cross_attn = self.cross_attention(feat_q, attn_tags, attn_tags)  # [B, tag_context_size, embedding_dim]
+         cross_attn = self.cross_norm(cross_attn)
+
+         # 7. Fuse features: average the cross-attended tag outputs and combine with the original features
+         fused_feature = cross_attn.mean(dim=1)               # [B, embedding_dim]
+         combined = torch.cat([feats, fused_feature], dim=1)  # [B, embedding_dim*2]
+
+         # 8. Refined tag prediction
+         refined_logits = self.refined_classifier(combined)  # [B, total_tags]
+         refined_preds = torch.clamp(refined_logits / self.temperature, min=-15.0, max=15.0)
+
+         return initial_preds, refined_preds
+
+ # --- Load the pretrained refined model weights ---
+ total_tags = 70527  # total number of tags in the dataset (Danbooru 2024)
+ from safetensors.torch import load_file
+ safetensors_path = 'model_refined.safetensors'
+ state_dict = load_file(safetensors_path, device='cpu')  # load the saved weights (should be an OrderedDict)
+ #state_dict = torch.load("model_refined.pt", map_location="cpu")
+
+ # Initialize our model and load weights
+ model = ImageTaggerRefinedONNX(total_tags=total_tags)
+ model.load_state_dict(state_dict)
+ model.eval()  # set to evaluation mode (disables dropout)
+
+ # (Optional) Cast to float32 if the weights were in half precision
+ # model = model.float()
+
+ # --- Export to ONNX ---
+ dummy_input = torch.randn(1, 3, 512, 512, requires_grad=False)  # dummy batch of 1 image (3x512x512)
+ output_onnx_file = "camie_refined_no_flash_v15.onnx"
+ torch.onnx.export(
+     model, dummy_input, output_onnx_file,
+     export_params=True,        # store trained parameter weights inside the model file
+     opset_version=17,          # ONNX opset version (ensure support for needed ops)
+     do_constant_folding=True,  # optimize constant expressions
+     input_names=["image"],
+     output_names=["initial_tags", "refined_tags"],
+     dynamic_axes={             # make the batch dimension dynamic
+         "image": {0: "batch"},
+         "initial_tags": {0: "batch"},
+         "refined_tags": {0: "batch"}
+     }
+ )
+ print(f"ONNX model exported to {output_onnx_file}")
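A quick parity check after export (a sketch; assumes onnxruntime is installed): compare the PyTorch and ONNX outputs on the same input, using the "image" input name declared in the export call above, as infer-refined.py does at inference time:

    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession(output_onnx_file, providers=["CPUExecutionProvider"])
    with torch.no_grad():
        pt_initial, pt_refined = model(dummy_input)
    ox_initial, ox_refined = session.run(None, {"image": dummy_input.numpy()})
    print("max |initial| diff:", np.abs(pt_initial.numpy() - ox_initial).max())
    print("max |refined| diff:", np.abs(pt_refined.numpy() - ox_refined).max())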
thresholds.json ADDED
@@ -0,0 +1,170 @@
+ {
+     "overall": {
+         "balanced": {
+             "threshold": 0.3285714089870453,
+             "f1": 0.6128875755303665,
+             "precision": 0.6348684210526315,
+             "recall": 0.5923778668258164
+         },
+         "high_precision": {
+             "threshold": 0.48367345333099365,
+             "f1": 0.5073781135639239,
+             "precision": 0.8244772683675426,
+             "recall": 0.3664421519311109
+         },
+         "high_recall": {
+             "threshold": 0.20612245798110962,
+             "f1": 0.5140483341286104,
+             "precision": 0.38317013976064945,
+             "recall": 0.7807144684116293
+         }
+     },
+     "weighted": {
+         "f1": {
+             "threshold": 0.31224489212036133,
+             "value": 0.666115043816508
+         }
+     },
+     "categories": {
+         "copyright": {
+             "balanced": {
+                 "threshold": 0.3857142925262451,
+                 "f1": 0.7885196374622356,
+                 "precision": 0.903114186851211,
+                 "recall": 0.6997319034852547
+             },
+             "high_precision": {
+                 "threshold": 0.5,
+                 "f1": 0.7524429967426711,
+                 "precision": 0.9585062240663901,
+                 "recall": 0.6193029490616622
+             },
+             "high_recall": {
+                 "threshold": 0.13265305757522583,
+                 "f1": 0.5149136577708007,
+                 "precision": 0.36403995560488345,
+                 "recall": 0.8793565683646113
+             }
+         },
+         "character": {
+             "balanced": {
+                 "threshold": 0.30408161878585815,
+                 "f1": 0.769028871391076,
+                 "precision": 0.8878787878787879,
+                 "recall": 0.6782407407407407
+             },
+             "high_precision": {
+                 "threshold": 0.47551020979881287,
+                 "f1": 0.7128129602356407,
+                 "precision": 0.979757085020243,
+                 "recall": 0.5601851851851852
+             },
+             "high_recall": {
+                 "threshold": 0.13265305757522583,
+                 "f1": 0.5132616487455197,
+                 "precision": 0.37175493250259606,
+                 "recall": 0.8287037037037037
+             }
+         },
+         "general": {
+             "balanced": {
+                 "threshold": 0.3285714089870453,
+                 "f1": 0.6070014256296532,
+                 "precision": 0.6206003023105161,
+                 "recall": 0.5939857393820399
+             },
+             "high_precision": {
+                 "threshold": 0.47551020979881287,
+                 "f1": 0.5074963046385584,
+                 "precision": 0.7958057395143487,
+                 "recall": 0.3725328097550894
+             },
+             "high_recall": {
+                 "threshold": 0.20612245798110962,
+                 "f1": 0.5094889521485699,
+                 "precision": 0.3790529978316777,
+                 "recall": 0.7767903275808619
+             }
+         },
+         "meta": {
+             "balanced": {
+                 "threshold": 0.31224489212036133,
+                 "f1": 0.5943152454780362,
+                 "precision": 0.5948275862068966,
+                 "recall": 0.5938037865748709
+             },
+             "high_precision": {
+                 "threshold": 0.41020408272743225,
+                 "f1": 0.5087924970691676,
+                 "precision": 0.7977941176470589,
+                 "recall": 0.37349397590361444
+             },
+             "high_recall": {
+                 "threshold": 0.22244898974895477,
+                 "f1": 0.5037433155080214,
+                 "precision": 0.365399534522886,
+                 "recall": 0.810671256454389
+             }
+         },
+         "rating": {
+             "balanced": {
+                 "threshold": 0.34489795565605164,
+                 "f1": 0.7964912280701754,
+                 "precision": 0.7229299363057324,
+                 "recall": 0.88671875
+             },
+             "high_precision": {
+                 "threshold": 0.5,
+                 "f1": 0.6966824644549763,
+                 "precision": 0.8855421686746988,
+                 "recall": 0.57421875
+             },
+             "high_recall": {
+                 "threshold": 0.10000000149011612,
+                 "f1": 0.6538952745849297,
+                 "precision": 0.4857685009487666,
+                 "recall": 1.0
+             }
+         },
+         "artist": {
+             "balanced": {
+                 "threshold": 0.22244898974895477,
+                 "f1": 0.5017921146953405,
+                 "precision": 0.56,
+                 "recall": 0.45454545454545453
+             },
+             "high_precision": {
+                 "threshold": 0.22244898974895477,
+                 "f1": 0.5017921146953405,
+                 "precision": 0.56,
+                 "recall": 0.45454545454545453
+             },
+             "high_recall": {
+                 "threshold": 0.22244898974895477,
+                 "f1": 0.5017921146953405,
+                 "precision": 0.56,
+                 "recall": 0.45454545454545453
+             }
+         },
+         "year": {
+             "balanced": {
+                 "threshold": 0.2877551317214966,
+                 "f1": 0.32867132867132864,
+                 "precision": 0.2974683544303797,
+                 "recall": 0.3671875
+             },
+             "high_precision": {
+                 "threshold": 0,
+                 "f1": 0,
+                 "precision": 0,
+                 "recall": 0
+             },
+             "high_recall": {
+                 "threshold": 0,
+                 "f1": 0,
+                 "precision": 0,
+                 "recall": 0
+             }
+         }
+     }
+ }
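A minimal sketch of applying these thresholds to raw model outputs (assumes a 1-D `refined_logits` array from the ONNX model and the idx_to_tag / tag_to_category mappings from metadata.json; categories missing from this file fall back to the overall threshold):

    import json
    import numpy as np

    with open("thresholds.json") as f:
        th = json.load(f)
    with open("metadata.json") as f:
        meta = json.load(f)

    probs = 1.0 / (1.0 + np.exp(-refined_logits))  # sigmoid over tag logits
    predicted = []
    for idx, p in enumerate(probs):
        tag = meta["idx_to_tag"][str(idx)]
        cat = meta["tag_to_category"].get(tag, "general")
        cutoff = th["categories"].get(cat, th["overall"])["balanced"]["threshold"]
        if p >= cutoff:
            predicted.append((tag, float(p)))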