diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index 3070764e06..0f1ccb6ec3 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -38,6 +38,7 @@ binary_erosion, _ = optional_import("scipy.ndimage", name="binary_erosion") distance_transform_edt, _ = optional_import("scipy.ndimage", name="distance_transform_edt") distance_transform_cdt, _ = optional_import("scipy.ndimage", name="distance_transform_cdt") +KDTree, _ = optional_import("scipy.spatial", name="KDTree") scipy_ndimage, has_scipy_ndimage = optional_import("scipy.ndimage") cupy, has_cupy = optional_import("cupy") @@ -269,7 +270,8 @@ def get_surface_distance( distance_metric: : [``"euclidean"``, ``"chessboard"``, ``"taxicab"``] the metric used to compute surface distance. Defaults to ``"euclidean"``. - - ``"euclidean"``, uses Exact Euclidean distance transform. + - ``"euclidean"``, the exact Euclidean distance (a KD-tree over the edge voxels on + CPU, or the cuCIM distance transform when the inputs are on a CUDA device). - ``"chessboard"``, uses `chessboard` metric in chamfer type of transform. - ``"taxicab"``, uses `taxicab` metric in chamfer type of transform. spacing: spacing of pixel (or voxel). This parameter is relevant only if ``distance_metric`` is set to ``"euclidean"``. @@ -291,6 +293,26 @@ def get_surface_distance( dis = dis[seg_gt] return convert_to_dst_type(dis, seg_pred, dtype=dis.dtype)[0] if distance_metric == "euclidean": + # The euclidean surface distance only needs the distance from each `seg_pred` + # edge voxel to the nearest `seg_gt` edge voxel. CPU and GPU favour different + # algorithms for this: + # * On CPU, a KD-tree over the (sparse) edge-voxel coordinates avoids the dense + # full-volume distance transform, and handles outlier points that expand the + # bounding box. + # * On GPU, the dense EDT is embarrassingly parallel and significantly faster than + # cupy's KDTree (as of this writing anyway) + on_gpu = isinstance(seg_gt, torch.Tensor) and seg_gt.device.type == "cuda" + if not on_gpu: + gt_coords = np.argwhere(convert_to_numpy(seg_gt)).astype(np.float64) + pred_coords = np.argwhere(convert_to_numpy(seg_pred)).astype(np.float64) + if spacing is not None: + scale = np.asarray(spacing, dtype=np.float64) + gt_coords *= scale + pred_coords *= scale + # leafsize larger than the default (16) is faster here: we build the tree + # for a single batched query rather than amortizing it over many queries. + surface_distance = KDTree(gt_coords, leafsize=32).query(pred_coords, k=1)[0] + return convert_to_dst_type(surface_distance, seg_pred, dtype=lib.float32)[0] dis = monai_distance_transform_edt((~seg_gt)[None, ...], sampling=spacing)[0] # type: ignore elif distance_metric in {"chessboard", "taxicab"}: dis = distance_transform_cdt(convert_to_numpy(~seg_gt), metric=distance_metric) diff --git a/tests/metrics/test_surface_distance.py b/tests/metrics/test_surface_distance.py index 85db389f80..e43d7d5955 100644 --- a/tests/metrics/test_surface_distance.py +++ b/tests/metrics/test_surface_distance.py @@ -16,8 +16,10 @@ import numpy as np import torch from parameterized import parameterized +from scipy.ndimage import distance_transform_edt from monai.metrics import SurfaceDistanceMetric +from monai.metrics.utils import get_mask_edges, get_surface_distance _device = "cuda:0" if torch.cuda.is_available() else "cpu" @@ -182,5 +184,43 @@ def test_nans(self, input_data): np.testing.assert_allclose(0, not_nans, rtol=1e-5) +KDTREE_SPACINGS = [["isotropic_default", None], ["isotropic", (1.0, 1.0, 1.0)], ["anisotropic", (1.0, 2.5, 0.5)]] + + +def _edge_masks(seed=0): + # two offset spheres plus a few scattered false positives in the prediction, so the + # surfaces are non-trivially apart and an outlier expands the cropped bounding box. + gt = create_spherical_seg_3d(radius=20, centre=(30, 30, 30)) + pred = create_spherical_seg_3d(radius=20, centre=(32, 31, 30)) + rng = np.random.RandomState(seed) + for _ in range(5): + pred[tuple(rng.randint(0, s) for s in pred.shape)] = 1 + edges_pred, edges_gt = get_mask_edges(pred, gt) + return np.asarray(edges_pred, dtype=bool), np.asarray(edges_gt, dtype=bool) + + +class TestSurfaceDistanceKDTreeMatchesEDT(unittest.TestCase): + @parameterized.expand(KDTREE_SPACINGS) + def test_cpu_kdtree_euclidean_distances_match_dense_edt(self, _name, spacing): + edges_pred, edges_gt = _edge_masks() + result = np.asarray(get_surface_distance(edges_pred, edges_gt, distance_metric="euclidean", spacing=spacing)) + reference = distance_transform_edt(~edges_gt, sampling=spacing)[edges_pred] + # same multiset of distances (downstream metrics only use max/percentile/mean) + np.testing.assert_allclose(np.sort(result), np.sort(reference), rtol=1e-5, atol=1e-5) + self.assertEqual(result.dtype, np.float32) + self.assertEqual(result.shape, reference.shape) + + def test_torch_input_preserves_type_device_and_matches_dense_edt(self): + edges_pred, edges_gt = _edge_masks() + spacing = (1.0, 2.5, 0.5) + seg_pred, seg_gt = torch.as_tensor(edges_pred), torch.as_tensor(edges_gt) + result = get_surface_distance(seg_pred, seg_gt, distance_metric="euclidean", spacing=spacing) + self.assertIsInstance(result, torch.Tensor) + self.assertEqual(result.dtype, torch.float32) + self.assertEqual(result.device, seg_pred.device) + reference = distance_transform_edt(~edges_gt, sampling=spacing)[edges_pred] + np.testing.assert_allclose(np.sort(result.cpu().numpy()), np.sort(reference), rtol=1e-5, atol=1e-5) + + if __name__ == "__main__": unittest.main()