AngelBottomless committed
Commit bc492d2 · verified · 1 Parent(s): ca9b012

add REFINED-version export without flash attention

Files changed (2)
  1. camie_refined_no_flash.onnx +3 -0
  2. infer-refined.py +129 -0
camie_refined_no_flash.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:415ced374b9387cd438b05438f55a352416b307d8c6160972284f8ea240f9410
+ size 1696444276
infer-refined.py ADDED
@@ -0,0 +1,129 @@
+ import onnxruntime as ort
+ import numpy as np
+ import json
+ from PIL import Image
+
+ def preprocess_image(img_path, target_size=512, keep_aspect=True):
+     """
+     Load an image from img_path, convert it to RGB, and resize/pad it to
+     (target_size, target_size). Scales pixel values to [0,1] and returns
+     a (1, 3, target_size, target_size) float32 array.
+     """
+     img = Image.open(img_path).convert("RGB")
+
+     if keep_aspect:
+         # Preserve aspect ratio; pad the remainder with black
+         w, h = img.size
+         aspect = w / h
+         if aspect > 1:
+             new_w = target_size
+             new_h = int(new_w / aspect)
+         else:
+             new_h = target_size
+             new_w = int(new_h * aspect)
+
+         # Resize with Lanczos resampling
+         img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
+         # Paste the resized image onto the center of a black square
+         background = Image.new("RGB", (target_size, target_size), (0, 0, 0))
+         paste_x = (target_size - new_w) // 2
+         paste_y = (target_size - new_h) // 2
+         background.paste(img, (paste_x, paste_y))
+         img = background
+     else:
+         # Simple direct resize to target_size x target_size
+         img = img.resize((target_size, target_size), Image.Resampling.LANCZOS)
+
+     # Convert to a numpy array scaled to [0,1]
+     arr = np.array(img).astype("float32") / 255.0
+     # Transpose from HWC -> CHW
+     arr = np.transpose(arr, (2, 0, 1))
+     # Add batch dimension: (1, 3, target_size, target_size)
+     arr = np.expand_dims(arr, axis=0)
+     return arr
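+ # Note: beyond the /255 scaling above, no mean/std normalization is applied;
+ # the exported model is assumed to expect raw [0,1] RGB input in CHW order.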
+
+ def onnx_inference(img_paths,
+                    onnx_path="camie_refined_no_flash.onnx",
+                    threshold=0.325,
+                    metadata_file="metadata.json"):
+     """
+     Loads the ONNX model, runs inference on a list of image paths,
+     and applies a probability threshold to produce final predictions.
+
+     Args:
+         img_paths: List of paths to images.
+         onnx_path: Path to the exported ONNX model file.
+         threshold: Probability threshold for deciding whether a tag is predicted.
+         metadata_file: Path to metadata.json containing idx_to_tag etc.
+
+     Returns:
+         A list with one dict per input image, each containing:
+         {
+             "initial_logits": np.ndarray of shape (N_tags,),
+             "refined_logits": np.ndarray of shape (N_tags,),
+             "predicted_indices": indices of tags whose probability met the threshold,
+             "predicted_tags": the corresponding tag names,
+         }
+     """
+     # 1) Initialize the ONNX Runtime session
+     session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
+     # Optional: for GPU usage, check whether "CUDAExecutionProvider" is available
+     # session = ort.InferenceSession(onnx_path, providers=["CUDAExecutionProvider"])
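+     # Tip: ort.get_available_providers() lists the execution providers this
+     # onnxruntime build actually supports; worth checking before requesting CUDA.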
+
+     # 2) Pre-load metadata
+     with open(metadata_file, "r", encoding="utf-8") as f:
+         metadata = json.load(f)
+     idx_to_tag = metadata["idx_to_tag"]  # e.g. { "0": "brown_hair", "1": "blue_eyes", ... }
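+     # JSON object keys are always strings, which is why the lookup below uses
+     # idx_to_tag[str(idx)] rather than an integer index.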
+
+     # 3) Preprocess each image into a batch
+     batch_tensors = []
+     for img_path in img_paths:
+         x = preprocess_image(img_path, target_size=512, keep_aspect=True)
+         batch_tensors.append(x)
+     # Concatenate along the batch dimension => shape (batch_size, 3, 512, 512)
+     batch_input = np.concatenate(batch_tensors, axis=0)
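+     # For large image sets it may be worth calling this function in fixed-size
+     # chunks: each (1, 3, 512, 512) float32 tensor costs ~3 MB of RAM, and the
+     # output logits add roughly 280 KB per image per head on top of that.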
+
+     # 4) Run inference
+     input_name = session.get_inputs()[0].name  # typically "image"
+     outputs = session.run(None, {input_name: batch_input})
+     # Typically we get [initial_tags, refined_tags] as outputs
+     initial_preds, refined_preds = outputs  # shapes => (batch_size, 70527)
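+     # The unpacking above assumes the export emits exactly two output heads in
+     # this order; session.get_outputs() returns their names if you need to verify.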
+
+     # 5) For each image in the batch, convert logits to predictions
+     batch_results = []
+     for i in range(initial_preds.shape[0]):
+         # Extract one sample's logits
+         init_logit = initial_preds[i, :]  # shape (N_tags,)
+         ref_logit = refined_preds[i, :]   # shape (N_tags,)
+
+         # Convert to probabilities with a sigmoid
+         ref_prob = 1.0 / (1.0 + np.exp(-ref_logit))
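+         # np.exp may emit an overflow RuntimeWarning for very negative logits;
+         # the result still saturates to 0.0, so predictions are unaffected.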
+
+         # Threshold the refined probabilities
+         pred_indices = np.where(ref_prob >= threshold)[0]
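+         # The 0.325 default is a single global cutoff; raising it trades recall
+         # for precision, and per-tag thresholds could be substituted here.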
+
+         # Build the result for this image
+         result_dict = {
+             "initial_logits": init_logit,
+             "refined_logits": ref_logit,
+             "predicted_indices": pred_indices,
+             "predicted_tags": [idx_to_tag[str(idx)] for idx in pred_indices],  # map index -> tag name
+         }
+         batch_results.append(result_dict)
+
+     return batch_results
+
+ if __name__ == "__main__":
+     # Example usage
+     images = ["image1.jpg", "image2.jpg", "image3.jpg"]
+     results = onnx_inference(images,
+                              onnx_path="camie_refined_no_flash.onnx",
+                              threshold=0.325,
+                              metadata_file="metadata.json")
+
+     for i, res in enumerate(results):
+         print(f"Image: {images[i]}")
+         print(f"  # of predicted tags above threshold: {len(res['predicted_indices'])}")
+         print(f"  Some predicted tags: {res['predicted_tags'][:10]} (showing up to 10)")
+         print()