Skip to content

Commit 341300d

Browse files
committed
fix OOB bugs in yolov5/7 and rfdetr
1 parent 1ededd4 commit 341300d

File tree

11 files changed

+57
-49
lines changed

11 files changed

+57
-49
lines changed

inference_models/inference_models/models/common/roboflow/post_processing.py

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,11 @@ def run_nms_for_object_detection(
2525
boxes = output[:, :4, :]
2626
scores = output[:, 4:, :]
2727
results = []
28-
per_class_thresh = (
29-
conf_thresh.to(output.device) if isinstance(conf_thresh, torch.Tensor) else None
30-
)
3128
for b in range(bs):
3229
class_scores = scores[b]
3330
class_conf, class_ids = class_scores.max(0)
34-
if per_class_thresh is not None:
35-
mask = class_conf > per_class_thresh[class_ids]
31+
if isinstance(conf_thresh, torch.Tensor):
32+
mask = class_conf > conf_thresh.to(output.device)[class_ids]
3633
else:
3734
mask = class_conf > conf_thresh
3835
if not torch.any(mask):
@@ -74,15 +71,12 @@ def post_process_nms_fused_model_output(
7471
) -> List[torch.Tensor]:
7572
bs = output.shape[0]
7673
nms_results = []
77-
per_class_thresh = (
78-
conf_thresh.to(output.device) if isinstance(conf_thresh, torch.Tensor) else None
79-
)
8074
for batch_element_id in range(bs):
8175
batch_element_result = output[batch_element_id]
82-
if per_class_thresh is not None:
76+
if isinstance(conf_thresh, torch.Tensor):
8377
class_ids = batch_element_result[:, 5].long()
8478
batch_element_result = batch_element_result[
85-
batch_element_result[:, 4] >= per_class_thresh[class_ids]
79+
batch_element_result[:, 4] >= conf_thresh.to(output.device)[class_ids]
8680
]
8781
else:
8882
batch_element_result = batch_element_result[
@@ -105,17 +99,14 @@ def run_nms_for_instance_segmentation(
10599
scores = output[:, 4:-32, :] # (N, 80, 8400)
106100
masks = output[:, -32:, :]
107101
results = []
108-
per_class_thresh = (
109-
conf_thresh.to(output.device) if isinstance(conf_thresh, torch.Tensor) else None
110-
)
111102

112103
for b in range(bs):
113104
bboxes = boxes[b].T # (8400, 4)
114105
class_scores = scores[b].T # (8400, 80)
115106
box_masks = masks[b].T
116107
class_conf, class_ids = class_scores.max(1) # (8400,), (8400,)
117-
if per_class_thresh is not None:
118-
mask = class_conf > per_class_thresh[class_ids]
108+
if isinstance(conf_thresh, torch.Tensor):
109+
mask = class_conf > conf_thresh.to(output.device)[class_ids]
119110
else:
120111
mask = class_conf > conf_thresh
121112
if mask.sum() == 0:
@@ -164,14 +155,11 @@ def run_nms_for_key_points_detection(
164155
scores = output[:, 4 : 4 + num_classes, :]
165156
key_points = output[:, 4 + num_classes :, :]
166157
results = []
167-
per_class_thresh = (
168-
conf_thresh.to(output.device) if isinstance(conf_thresh, torch.Tensor) else None
169-
)
170158
for b in range(bs):
171159
class_scores = scores[b]
172160
class_conf, class_ids = class_scores.max(0)
173-
if per_class_thresh is not None:
174-
mask = class_conf > per_class_thresh[class_ids]
161+
if isinstance(conf_thresh, torch.Tensor):
162+
mask = class_conf > conf_thresh.to(output.device)[class_ids]
175163
else:
176164
mask = class_conf > conf_thresh
177165
if not torch.any(mask):

inference_models/inference_models/models/rfdetr/common.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,13 @@ def post_process_instance_segmentation_results(
6666
confidence = confidence[remapping_mask]
6767
image_bboxes = image_bboxes[remapping_mask]
6868
image_masks = image_masks[remapping_mask]
69+
else:
70+
# drop DETR no-object rows
71+
named = top_classes < threshold.shape[0]
72+
confidence = confidence[named]
73+
top_classes = top_classes[named]
74+
image_bboxes = image_bboxes[named]
75+
image_masks = image_masks[named]
6976
confidence_mask = confidence > threshold[top_classes.long()]
7077
confidence = confidence[confidence_mask]
7178
top_classes = top_classes[confidence_mask]

inference_models/inference_models/models/rfdetr/rfdetr_object_detection_onnx.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,12 @@ def post_process(
238238
]
239239
predicted_confidence = predicted_confidence[remapping_mask]
240240
image_bboxes = image_bboxes[remapping_mask]
241+
else:
242+
# drop DETR no-object rows
243+
named = top_classes < len(self.class_names)
244+
predicted_confidence = predicted_confidence[named]
245+
top_classes = top_classes[named]
246+
image_bboxes = image_bboxes[named]
241247
confidence_mask = predicted_confidence > thresholds[top_classes.long()]
242248
predicted_confidence = predicted_confidence[confidence_mask]
243249
top_classes = top_classes[confidence_mask]

inference_models/inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,12 @@ def post_process(
493493
labels = self._classes_re_mapping.class_mapping[labels[remapping_mask]]
494494
boxes = boxes[remapping_mask]
495495
score_thresholds = thresholds.to(dtype=scores.dtype)
496+
if self._classes_re_mapping is None:
497+
# drop DETR no-object rows
498+
named = labels < score_thresholds.shape[0]
499+
scores = scores[named]
500+
labels = labels[named]
501+
boxes = boxes[named]
496502
keep = scores > score_thresholds[labels.long()]
497503
scores = scores[keep]
498504
labels = labels[keep]

inference_models/inference_models/models/rfdetr/rfdetr_object_detection_trt.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,12 @@ def post_process(
295295
]
296296
predicted_confidence = predicted_confidence[remapping_mask]
297297
image_bboxes = image_bboxes[remapping_mask]
298+
else:
299+
# drop DETR no-object rows
300+
named = top_classes < len(self.class_names)
301+
predicted_confidence = predicted_confidence[named]
302+
top_classes = top_classes[named]
303+
image_bboxes = image_bboxes[named]
298304
confidence_mask = predicted_confidence > thresholds[top_classes.long()]
299305
predicted_confidence = predicted_confidence[confidence_mask]
300306
top_classes = top_classes[confidence_mask]

inference_models/inference_models/models/yolact/yolact_instance_segmentation_onnx.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -338,16 +338,13 @@ def run_nms_for_instance_segmentation(
338338
scores = output[:, :, 4:-32] # (N, 19248, num_classes)
339339
masks = output[:, :, -32:]
340340
results = []
341-
per_class_thresh = (
342-
conf_thresh.to(output.device) if isinstance(conf_thresh, torch.Tensor) else None
343-
)
344341
for b in range(bs):
345342
bboxes = boxes[b] # (19248, 4)
346343
class_scores = scores[b] # (19248, 80)
347344
box_masks = masks[b]
348345
class_conf, class_ids = class_scores.max(1) # (8400,), (8400,)
349-
if per_class_thresh is not None:
350-
mask = class_conf > per_class_thresh[class_ids]
346+
if isinstance(conf_thresh, torch.Tensor):
347+
mask = class_conf > conf_thresh.to(output.device)[class_ids]
351348
else:
352349
mask = class_conf > conf_thresh
353350
if mask.sum() == 0:

inference_models/inference_models/models/yolact/yolact_instance_segmentation_trt.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -422,16 +422,13 @@ def run_nms_for_instance_segmentation(
422422
scores = output[:, :, 4:-32] # (N, 19248, num_classes)
423423
masks = output[:, :, -32:]
424424
results = []
425-
per_class_thresh = (
426-
conf_thresh.to(output.device) if isinstance(conf_thresh, torch.Tensor) else None
427-
)
428425
for b in range(bs):
429426
bboxes = boxes[b] # (19248, 4)
430427
class_scores = scores[b] # (19248, 80)
431428
box_masks = masks[b]
432429
class_conf, class_ids = class_scores.max(1) # (8400,), (8400,)
433-
if per_class_thresh is not None:
434-
mask = class_conf > per_class_thresh[class_ids]
430+
if isinstance(conf_thresh, torch.Tensor):
431+
mask = class_conf > conf_thresh.to(output.device)[class_ids]
435432
else:
436433
mask = class_conf > conf_thresh
437434
if mask.sum() == 0:

inference_models/inference_models/models/yolonas/nms.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,11 @@ def run_yolonas_nms_for_object_detection(
1515
boxes = output[:, :, :4]
1616
scores = output[:, :, 4:]
1717
results = []
18-
per_class_thresh = (
19-
conf_thresh.to(output.device) if isinstance(conf_thresh, torch.Tensor) else None
20-
)
2118
for b in range(bs):
2219
class_scores = scores[b] # (8400, cls_num)
2320
class_conf, class_ids = torch.max(class_scores, dim=-1)
24-
if per_class_thresh is not None:
25-
mask = class_conf > per_class_thresh[class_ids]
21+
if isinstance(conf_thresh, torch.Tensor):
22+
mask = class_conf > conf_thresh.to(output.device)[class_ids]
2623
else:
2724
mask = class_conf > conf_thresh
2825
if not torch.any(mask):

inference_models/inference_models/models/yolov5/nms.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,11 @@ def run_nms_yolov5(
1616
top_classes_conf = output[:, 4, :]
1717
scores = output[:, 5:, :]
1818
results = []
19-
per_class_thresh = (
20-
conf_thresh.to(output.device) if isinstance(conf_thresh, torch.Tensor) else None
21-
)
2219
for b in range(bs):
2320
class_scores = scores[b]
2421
class_conf, class_ids = class_scores.max(0)
25-
if per_class_thresh is not None:
26-
mask = top_classes_conf[b] > per_class_thresh[class_ids]
22+
if isinstance(conf_thresh, torch.Tensor):
23+
mask = top_classes_conf[b] > conf_thresh.to(output.device)[class_ids]
2724
else:
2825
mask = top_classes_conf[b] > conf_thresh
2926
if not torch.any(mask):
@@ -68,17 +65,20 @@ def run_yolov5_nms_for_instance_segmentation(
6865
scores = output[:, 4:-32, :]
6966
masks = output[:, -32:, :]
7067
results = []
71-
per_class_thresh = (
72-
conf_thresh.to(output.device) if isinstance(conf_thresh, torch.Tensor) else None
73-
)
7468

7569
for b in range(bs):
7670
bboxes = boxes[b].T
7771
class_scores = scores[b].T
7872
box_masks = masks[b].T
7973
class_conf, class_ids = class_scores.max(1)
80-
if per_class_thresh is not None:
81-
mask = top_classes_conf[b] > per_class_thresh[class_ids]
74+
if isinstance(conf_thresh, torch.Tensor):
75+
# class_ids are slice-indexed: 0 is objectness, k>=1 is class_{k-1}.
76+
# Keep objectness-dominated rows (no real class to threshold) and
77+
# filter class-dominated rows against the real class's threshold.
78+
thresh = conf_thresh.to(output.device)
79+
real_class = (class_ids - 1).clamp(min=0, max=thresh.shape[0] - 1)
80+
is_obj = class_ids == 0
81+
mask = is_obj | (top_classes_conf[b] > thresh[real_class])
8282
else:
8383
mask = top_classes_conf[b] > conf_thresh
8484
if mask.sum() == 0:

inference_models/inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@
3838
align_instance_segmentation_results,
3939
crop_masks_to_boxes,
4040
preprocess_segmentation_masks,
41-
run_nms_for_instance_segmentation,
41+
)
42+
from inference_models.models.yolov5.nms import (
43+
run_yolov5_nms_for_instance_segmentation,
4244
)
4345
from inference_models.models.common.roboflow.pre_processing import (
4446
pre_process_network_input,
@@ -214,7 +216,7 @@ def post_process(
214216
)
215217
confidence = confidence_filter.per_class_thresholds(self.class_names)
216218
instances, protos = model_results
217-
nms_results = run_nms_for_instance_segmentation(
219+
nms_results = run_yolov5_nms_for_instance_segmentation(
218220
output=instances.permute(0, 2, 1),
219221
conf_thresh=confidence,
220222
iou_thresh=iou_threshold,

0 commit comments

Comments (0)