Commit
•
3551260
1
Parent(s):
e6c79f4
Upload Code&Model for NHWC format (#2)
Browse files- Upload 10 files (8a142b795d68087dd3b650e755a5fe10ecbd1ff9)
Co-authored-by: fangyuan wang <[email protected]>
- onnx_eval.py +4 -2
- onnx_inference.py +4 -2
- yolov8m_qat.onnx +2 -2
onnx_eval.py
CHANGED
@@ -78,8 +78,10 @@ class DetectionValidator:
|
|
78 |
batch = self.preprocess(batch)
|
79 |
|
80 |
# inference
|
81 |
-
outputs = onnx_model.run(None, {onnx_model.get_inputs()[0].name: batch["img"].cpu().numpy()})
|
82 |
-
outputs =
|
|
|
|
|
83 |
preds = post_process(outputs)
|
84 |
|
85 |
# pre-process predictions
|
|
|
78 |
batch = self.preprocess(batch)
|
79 |
|
80 |
# inference
|
81 |
+
# outputs = onnx_model.run(None, {onnx_model.get_inputs()[0].name: batch["img"].cpu().numpy()})
|
82 |
+
outputs = onnx_model.run(None, {onnx_model.get_inputs()[0].name: batch["img"].permute(0, 2, 3, 1).cpu().numpy()})
|
83 |
+
# outputs = [torch.tensor(item).to(self.device) for item in outputs]
|
84 |
+
outputs = [torch.tensor(item).permute(0, 3, 1, 2).to(self.device) for item in outputs]
|
85 |
preds = post_process(outputs)
|
86 |
|
87 |
# pre-process predictions
|
onnx_inference.py
CHANGED
@@ -133,8 +133,10 @@ if __name__ == '__main__':
|
|
133 |
im = preprocess(im)
|
134 |
if len(im.shape) == 3:
|
135 |
im = im[None]
|
136 |
-
outputs = onnx_model.run(None, {onnx_model.get_inputs()[0].name: im.cpu().numpy()})
|
137 |
-
outputs = [torch.tensor(item) for item in outputs]
|
|
|
|
|
138 |
preds = post_process(outputs)
|
139 |
preds = non_max_suppression(
|
140 |
preds, 0.25, 0.7, agnostic=False, max_det=300, classes=None
|
|
|
133 |
im = preprocess(im)
|
134 |
if len(im.shape) == 3:
|
135 |
im = im[None]
|
136 |
+
# outputs = onnx_model.run(None, {onnx_model.get_inputs()[0].name: im.cpu().numpy()})
|
137 |
+
# outputs = [torch.tensor(item) for item in outputs]
|
138 |
+
outputs = onnx_model.run(None, {onnx_model.get_inputs()[0].name: im.permute(0, 2, 3, 1).cpu().numpy()})
|
139 |
+
outputs = [torch.tensor(item).permute(0, 3, 1, 2) for item in outputs]
|
140 |
preds = post_process(outputs)
|
141 |
preds = non_max_suppression(
|
142 |
preds, 0.25, 0.7, agnostic=False, max_det=300, classes=None
|
yolov8m_qat.onnx
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:119038397368b01fee9ad8adcc62061babcf2e2dd417be1946d5bfccb07eb65f
|
3 |
+
size 103874987
|