zhengrongzhang wangfangyuan commited on
Commit
3551260
1 Parent(s): e6c79f4

Upload Code&Model for NHWC format (#2)

Browse files

- Upload 10 files (8a142b795d68087dd3b650e755a5fe10ecbd1ff9)


Co-authored-by: fangyuan wang <[email protected]>

Files changed (3) hide show
  1. onnx_eval.py +4 -2
  2. onnx_inference.py +4 -2
  3. yolov8m_qat.onnx +2 -2
onnx_eval.py CHANGED
@@ -78,8 +78,10 @@ class DetectionValidator:
78
  batch = self.preprocess(batch)
79
 
80
  # inference
81
- outputs = onnx_model.run(None, {onnx_model.get_inputs()[0].name: batch["img"].cpu().numpy()})
82
- outputs = [torch.tensor(item).to(self.device) for item in outputs]
 
 
83
  preds = post_process(outputs)
84
 
85
  # pre-process predictions
 
78
  batch = self.preprocess(batch)
79
 
80
  # inference
81
+ # outputs = onnx_model.run(None, {onnx_model.get_inputs()[0].name: batch["img"].cpu().numpy()})
82
+ outputs = onnx_model.run(None, {onnx_model.get_inputs()[0].name: batch["img"].permute(0, 2, 3, 1).cpu().numpy()})
83
+ # outputs = [torch.tensor(item).to(self.device) for item in outputs]
84
+ outputs = [torch.tensor(item).permute(0, 3, 1, 2).to(self.device) for item in outputs]
85
  preds = post_process(outputs)
86
 
87
  # pre-process predictions
onnx_inference.py CHANGED
@@ -133,8 +133,10 @@ if __name__ == '__main__':
133
  im = preprocess(im)
134
  if len(im.shape) == 3:
135
  im = im[None]
136
- outputs = onnx_model.run(None, {onnx_model.get_inputs()[0].name: im.cpu().numpy()})
137
- outputs = [torch.tensor(item) for item in outputs]
 
 
138
  preds = post_process(outputs)
139
  preds = non_max_suppression(
140
  preds, 0.25, 0.7, agnostic=False, max_det=300, classes=None
 
133
  im = preprocess(im)
134
  if len(im.shape) == 3:
135
  im = im[None]
136
+ # outputs = onnx_model.run(None, {onnx_model.get_inputs()[0].name: im.cpu().numpy()})
137
+ # outputs = [torch.tensor(item) for item in outputs]
138
+ outputs = onnx_model.run(None, {onnx_model.get_inputs()[0].name: im.permute(0, 2, 3, 1).cpu().numpy()})
139
+ outputs = [torch.tensor(item).permute(0, 3, 1, 2) for item in outputs]
140
  preds = post_process(outputs)
141
  preds = non_max_suppression(
142
  preds, 0.25, 0.7, agnostic=False, max_det=300, classes=None
yolov8m_qat.onnx CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:3b770e88b358ad24cc60e7b8bbc00b09bb1e0308f65f45cdcea2a1dfc1301077
3
- size 103874610
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:119038397368b01fee9ad8adcc62061babcf2e2dd417be1946d5bfccb07eb65f
3
+ size 103874987