Mountchicken committed
Commit f6a0151 · verified · 1 parent: 54888ff

Update groundingdino/util/inference.py

Files changed (1):
  groundingdino/util/inference.py  +21 -58
groundingdino/util/inference.py CHANGED
@@ -5,7 +5,6 @@ import numpy as np
 import supervision as sv
 import torch
 from PIL import Image
-from torchvision.ops import box_convert
 import bisect
 
 import groundingdino.datasets.transforms as T
@@ -14,6 +13,19 @@ from groundingdino.util.misc import clean_state_dict
 from groundingdino.util.slconfig import SLConfig
 from groundingdino.util.utils import get_phrases_from_posmap
 
+
+def cxcywh_to_xyxy(boxes: torch.Tensor) -> torch.Tensor:
+    """
+    Convert bounding boxes from [cx, cy, w, h] format to [x1, y1, x2, y2] format.
+    """
+    cx, cy, w, h = boxes.unbind(-1)
+    x1 = cx - 0.5 * w
+    y1 = cy - 0.5 * h
+    x2 = cx + 0.5 * w
+    y2 = cy + 0.5 * h
+    return torch.stack((x1, y1, x2, y2), dim=-1)
+
+
 # ----------------------------------------------------------------------------------------------------------------------
 # OLD API
 # ----------------------------------------------------------------------------------------------------------------------
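
The new helper replaces torchvision.ops.box_convert(boxes, in_fmt="cxcywh", out_fmt="xyxy") everywhere it was used, removing the torchvision import from this module. A quick sanity check of the conversion (a minimal sketch; the box values are made up):

    import torch
    from groundingdino.util.inference import cxcywh_to_xyxy

    # One box centered at (50, 60) with width 100 and height 40.
    boxes = torch.tensor([[50.0, 60.0, 100.0, 40.0]])
    print(cxcywh_to_xyxy(boxes))
    # tensor([[  0.,  40., 100.,  80.]])
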
@@ -67,16 +79,16 @@ def predict(
     with torch.no_grad():
         outputs = model(image[None], captions=[caption])
 
-    prediction_logits = outputs["pred_logits"].cpu().sigmoid()[0]  # prediction_logits.shape = (nq, 256)
-    prediction_boxes = outputs["pred_boxes"].cpu()[0]  # prediction_boxes.shape = (nq, 4)
+    prediction_logits = outputs["pred_logits"].cpu().sigmoid()[0]
+    prediction_boxes = outputs["pred_boxes"].cpu()[0]
 
     mask = prediction_logits.max(dim=1)[0] > box_threshold
-    logits = prediction_logits[mask]  # logits.shape = (n, 256)
-    boxes = prediction_boxes[mask]  # boxes.shape = (n, 4)
+    logits = prediction_logits[mask]
+    boxes = prediction_boxes[mask]
 
     tokenizer = model.tokenizer
     tokenized = tokenizer(caption)
-
+
     if remove_combined:
         sep_idx = [i for i in range(len(tokenized['input_ids'])) if tokenized['input_ids'][i] in [101, 102, 1012]]
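
For reference, the deleted inline comments recorded the tensor shapes: prediction_logits is (nq, 256) and prediction_boxes is (nq, 4); after thresholding, logits is (n, 256) and boxes is (n, 4). A sketch of calling predict() end to end — load_model and load_image are assumed to be the loader helpers this module provides outside this diff, and all paths and the caption are placeholders:

    from groundingdino.util.inference import load_model, load_image, predict

    model = load_model("path/to/config.py", "path/to/weights.pth")  # placeholder paths
    image_source, image = load_image("path/to/image.jpg")           # placeholder path

    boxes, logits, phrases = predict(
        model=model,
        image=image,
        caption="chair . person . dog",
        box_threshold=0.35,
        text_threshold=0.25,
    )
    # boxes holds normalized cxcywh coordinates, ready for annotate() below.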
82
 
@@ -98,21 +110,9 @@ def predict(
 
 
 def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str]) -> np.ndarray:
-    """
-    This function annotates an image with bounding boxes and labels.
-
-    Parameters:
-    image_source (np.ndarray): The source image to be annotated.
-    boxes (torch.Tensor): A tensor containing bounding box coordinates.
-    logits (torch.Tensor): A tensor containing confidence scores for each bounding box.
-    phrases (List[str]): A list of labels for each bounding box.
-
-    Returns:
-    np.ndarray: The annotated image.
-    """
     h, w, _ = image_source.shape
     boxes = boxes * torch.Tensor([w, h, w, h])
-    xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
+    xyxy = cxcywh_to_xyxy(boxes).numpy()
     detections = sv.Detections(xyxy=xyxy)
 
     labels = [
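
annotate() now scales the normalized boxes to pixels and converts them through cxcywh_to_xyxy() before drawing. A short usage sketch, continuing the placeholders above:

    import cv2
    from groundingdino.util.inference import annotate

    # image_source, boxes, logits, phrases come from the predict() sketch above.
    annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)
    cv2.imwrite("annotated_image.jpg", annotated_frame)
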
@@ -156,24 +156,6 @@ class Model:
         box_threshold: float = 0.35,
         text_threshold: float = 0.25
     ) -> Tuple[sv.Detections, List[str]]:
-        """
-        import cv2
-
-        image = cv2.imread(IMAGE_PATH)
-
-        model = Model(model_config_path=CONFIG_PATH, model_checkpoint_path=WEIGHTS_PATH)
-        detections, labels = model.predict_with_caption(
-            image=image,
-            caption=caption,
-            box_threshold=BOX_THRESHOLD,
-            text_threshold=TEXT_THRESHOLD
-        )
-
-        import supervision as sv
-
-        box_annotator = sv.BoxAnnotator()
-        annotated_image = box_annotator.annotate(scene=image, detections=detections, labels=labels)
-        """
         processed_image = Model.preprocess_image(image_bgr=image).to(self.device)
         boxes, logits, phrases = predict(
             model=self.model,
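
The usage example deleted from the docstring above still applies; it is kept here for reference. IMAGE_PATH, CONFIG_PATH, WEIGHTS_PATH, BOX_THRESHOLD, TEXT_THRESHOLD and caption are placeholders, as in the original:

    import cv2
    import supervision as sv
    from groundingdino.util.inference import Model

    image = cv2.imread(IMAGE_PATH)

    model = Model(model_config_path=CONFIG_PATH, model_checkpoint_path=WEIGHTS_PATH)
    detections, labels = model.predict_with_caption(
        image=image,
        caption=caption,
        box_threshold=BOX_THRESHOLD,
        text_threshold=TEXT_THRESHOLD
    )

    box_annotator = sv.BoxAnnotator()
    annotated_image = box_annotator.annotate(scene=image, detections=detections, labels=labels)
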
@@ -197,25 +179,6 @@ class Model:
         box_threshold: float,
         text_threshold: float
     ) -> sv.Detections:
-        """
-        import cv2
-
-        image = cv2.imread(IMAGE_PATH)
-
-        model = Model(model_config_path=CONFIG_PATH, model_checkpoint_path=WEIGHTS_PATH)
-        detections = model.predict_with_classes(
-            image=image,
-            classes=CLASSES,
-            box_threshold=BOX_THRESHOLD,
-            text_threshold=TEXT_THRESHOLD
-        )
-
-
-        import supervision as sv
-
-        box_annotator = sv.BoxAnnotator()
-        annotated_image = box_annotator.annotate(scene=image, detections=detections)
-        """
         caption = ". ".join(classes)
         processed_image = Model.preprocess_image(image_bgr=image).to(self.device)
         boxes, logits, phrases = predict(
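
Likewise, the example deleted from the predict_with_classes docstring (CLASSES and the other uppercase names are placeholders):

    import cv2
    import supervision as sv
    from groundingdino.util.inference import Model

    image = cv2.imread(IMAGE_PATH)

    model = Model(model_config_path=CONFIG_PATH, model_checkpoint_path=WEIGHTS_PATH)
    detections = model.predict_with_classes(
        image=image,
        classes=CLASSES,
        box_threshold=BOX_THRESHOLD,
        text_threshold=TEXT_THRESHOLD
    )

    box_annotator = sv.BoxAnnotator()
    annotated_image = box_annotator.annotate(scene=image, detections=detections)
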
@@ -256,7 +219,7 @@ class Model:
         logits: torch.Tensor
     ) -> sv.Detections:
         boxes = boxes * torch.Tensor([source_w, source_h, source_w, source_h])
-        xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
+        xyxy = cxcywh_to_xyxy(boxes).numpy()
         confidence = logits.numpy()
         return sv.Detections(xyxy=xyxy, confidence=confidence)
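
post_process_result() applies the same conversion at pixel scale. A minimal sketch, assuming the method can be called statically as in the upstream repository; the values are made up:

    import torch
    from groundingdino.util.inference import Model

    boxes = torch.tensor([[0.5, 0.5, 0.2, 0.2]])  # normalized cxcywh
    logits = torch.tensor([0.9])
    detections = Model.post_process_result(source_h=480, source_w=640, boxes=boxes, logits=logits)
    print(detections.xyxy)  # [[256. 192. 384. 288.]]
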
@@ -270,4 +233,4 @@ class Model:
                     break
             else:
                 class_ids.append(None)
-        return np.array(class_ids)
+        return np.array(class_ids)
 