diff --git a/monai/data/box_utils.py b/monai/data/box_utils.py index b0c41b5d7d..2f4a1426a9 100644 --- a/monai/data/box_utils.py +++ b/monai/data/box_utils.py @@ -1200,7 +1200,7 @@ def batched_nms( # from different classes do not overlap max_coordinate = boxes_t.max() offsets = labels_t.to(boxes_t) * (max_coordinate + 1) - boxes_for_nms = boxes + offsets[:, None] + boxes_for_nms = boxes_t + offsets[:, None] keep = non_max_suppression(boxes_for_nms, scores_t, nms_thresh, max_proposals, box_overlap_metric) # convert tensor back to numpy if needed diff --git a/tests/data/test_box_utils.py b/tests/data/test_box_utils.py index 05778f691b..30136d4f1b 100644 --- a/tests/data/test_box_utils.py +++ b/tests/data/test_box_utils.py @@ -23,6 +23,7 @@ CornerCornerModeTypeB, CornerCornerModeTypeC, CornerSizeMode, + batched_nms, box_area, box_centers, box_giou, @@ -269,5 +270,15 @@ def test_integer_truncation_bug(self): self.assertGreater(iou[0, 0], 0.0, "IoU should not be truncated to 0") +class TestBatchedNms(unittest.TestCase): + @parameterized.expand(TEST_NDARRAYS) + def test_batched_nms_backend(self, p): + boxes = p(np.array([[0, 0, 10, 10], [1, 1, 11, 11], [100, 100, 110, 110]], dtype=np.float32)) + scores = p(np.array([0.9, 0.8, 0.7], dtype=np.float32)) + labels = p(np.array([0, 0, 1])) + keep = batched_nms(boxes, scores, labels, nms_thresh=0.5) + assert_allclose(keep, [0, 2], type_test=False) + + if __name__ == "__main__": unittest.main()