Update groundingdino/util/inference.py

groundingdino/util/inference.py  (+21 -58)
--- a/groundingdino/util/inference.py
+++ b/groundingdino/util/inference.py
@@ -5,7 +5,6 @@ import numpy as np
 import supervision as sv
 import torch
 from PIL import Image
-from torchvision.ops import box_convert
 import bisect
 
 import groundingdino.datasets.transforms as T
@@ -14,6 +13,19 @@ from groundingdino.util.misc import clean_state_dict
 from groundingdino.util.slconfig import SLConfig
 from groundingdino.util.utils import get_phrases_from_posmap
 
+
+def cxcywh_to_xyxy(boxes: torch.Tensor) -> torch.Tensor:
+    """
+    Convert bounding boxes from [cx, cy, w, h] format to [x1, y1, x2, y2] format.
+    """
+    cx, cy, w, h = boxes.unbind(-1)
+    x1 = cx - 0.5 * w
+    y1 = cy - 0.5 * h
+    x2 = cx + 0.5 * w
+    y2 = cy + 0.5 * h
+    return torch.stack((x1, y1, x2, y2), dim=-1)
+
+
 # ----------------------------------------------------------------------------------------------------------------------
 # OLD API
 # ----------------------------------------------------------------------------------------------------------------------
@@ -67,16 +79,16 @@ def predict(
     with torch.no_grad():
         outputs = model(image[None], captions=[caption])
 
-    prediction_logits = outputs["pred_logits"].cpu().sigmoid()[0]
-    prediction_boxes = outputs["pred_boxes"].cpu()[0]
+    prediction_logits = outputs["pred_logits"].cpu().sigmoid()[0]
+    prediction_boxes = outputs["pred_boxes"].cpu()[0]
 
     mask = prediction_logits.max(dim=1)[0] > box_threshold
-    logits = prediction_logits[mask]
-    boxes = prediction_boxes[mask]
+    logits = prediction_logits[mask]
+    boxes = prediction_boxes[mask]
 
     tokenizer = model.tokenizer
     tokenized = tokenizer(caption)
-
+
     if remove_combined:
         sep_idx = [i for i in range(len(tokenized['input_ids'])) if tokenized['input_ids'][i] in [101, 102, 1012]]
 
@@ -98,21 +110,9 @@ def predict(
 
 
 def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str]) -> np.ndarray:
-    """
-    This function annotates an image with bounding boxes and labels.
-
-    Parameters:
-    image_source (np.ndarray): The source image to be annotated.
-    boxes (torch.Tensor): A tensor containing bounding box coordinates.
-    logits (torch.Tensor): A tensor containing confidence scores for each bounding box.
-    phrases (List[str]): A list of labels for each bounding box.
-
-    Returns:
-    np.ndarray: The annotated image.
-    """
     h, w, _ = image_source.shape
     boxes = boxes * torch.Tensor([w, h, w, h])
-    xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
+    xyxy = cxcywh_to_xyxy(boxes).numpy()
     detections = sv.Detections(xyxy=xyxy)
 
     labels = [
@@ -156,24 +156,6 @@ class Model:
         box_threshold: float = 0.35,
         text_threshold: float = 0.25
     ) -> Tuple[sv.Detections, List[str]]:
-        """
-        import cv2
-
-        image = cv2.imread(IMAGE_PATH)
-
-        model = Model(model_config_path=CONFIG_PATH, model_checkpoint_path=WEIGHTS_PATH)
-        detections, labels = model.predict_with_caption(
-            image=image,
-            caption=caption,
-            box_threshold=BOX_THRESHOLD,
-            text_threshold=TEXT_THRESHOLD
-        )
-
-        import supervision as sv
-
-        box_annotator = sv.BoxAnnotator()
-        annotated_image = box_annotator.annotate(scene=image, detections=detections, labels=labels)
-        """
         processed_image = Model.preprocess_image(image_bgr=image).to(self.device)
         boxes, logits, phrases = predict(
             model=self.model,
@@ -197,25 +179,6 @@ class Model:
         box_threshold: float,
         text_threshold: float
     ) -> sv.Detections:
-        """
-        import cv2
-
-        image = cv2.imread(IMAGE_PATH)
-
-        model = Model(model_config_path=CONFIG_PATH, model_checkpoint_path=WEIGHTS_PATH)
-        detections = model.predict_with_classes(
-            image=image,
-            classes=CLASSES,
-            box_threshold=BOX_THRESHOLD,
-            text_threshold=TEXT_THRESHOLD
-        )
-
-
-        import supervision as sv
-
-        box_annotator = sv.BoxAnnotator()
-        annotated_image = box_annotator.annotate(scene=image, detections=detections)
-        """
         caption = ". ".join(classes)
         processed_image = Model.preprocess_image(image_bgr=image).to(self.device)
         boxes, logits, phrases = predict(
@@ -256,7 +219,7 @@ class Model:
         logits: torch.Tensor
     ) -> sv.Detections:
         boxes = boxes * torch.Tensor([source_w, source_h, source_w, source_h])
-        xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
+        xyxy = cxcywh_to_xyxy(boxes).numpy()
         confidence = logits.numpy()
         return sv.Detections(xyxy=xyxy, confidence=confidence)
 
@@ -270,4 +233,4 @@ class Model:
                     break
             else:
                 class_ids.append(None)
-        return np.array(class_ids)
+        return np.array(class_ids)
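For context, the new cxcywh_to_xyxy helper is used as a drop-in replacement for the torchvision.ops.box_convert call this commit removes, and it feeds the same scaled pixel boxes into supervision's Detections in annotate() and Model.post_process_result(). A minimal sanity-check sketch (not part of the commit; it assumes torchvision is still installed for the comparison, that the patched groundingdino.util.inference module is importable, and the 640x480 image size is hypothetical):

import torch
from torchvision.ops import box_convert  # only needed for the comparison below

from groundingdino.util.inference import cxcywh_to_xyxy

# Normalized [cx, cy, w, h] boxes, as returned by predict().
boxes = torch.tensor([[0.50, 0.50, 0.20, 0.40],
                      [0.30, 0.70, 0.10, 0.10]])

# 1. The helper should reproduce torchvision's cxcywh -> xyxy conversion.
reference = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy")
assert torch.allclose(cxcywh_to_xyxy(boxes), reference)

# 2. The path used in annotate()/post_process_result(): scale to pixel
#    coordinates first, then convert. For a hypothetical 640x480 image,
#    the first box becomes [256., 144., 384., 336.].
h, w = 480, 640
xyxy = cxcywh_to_xyxy(boxes * torch.Tensor([w, h, w, h])).numpy()

Keeping the conversion local is what lets the import of torchvision.ops.box_convert be dropped at the top of the diff; the downstream supervision calls are unchanged.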