diff --git a/docs/source/installation.md b/docs/source/installation.md index 5123bc3e6b..42163c8d6e 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -254,10 +254,10 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is - The options are ``` -[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr, lpips, pynvml, huggingface_hub] +[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr, lpips, pynvml, huggingface_hub, rankseg] ``` which correspond to `nibabel`, `scikit-image`,`scipy`, `pillow`, `tensorboard`, -`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, `zarr`, `lpips`, `nvidia-ml-py`, `huggingface_hub` and `pyamg` respectively. +`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, `zarr`, `lpips`, `nvidia-ml-py`, `huggingface_hub`, `pyamg` and `rankseg` respectively. - `pip install 'monai[all]'` installs all the optional dependencies. diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 47623b748d..90d3474f86 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -39,15 +39,19 @@ ) from monai.transforms.utils_pytorch_numpy_unification import unravel_index from monai.utils import ( + OptionalImportError, TransformBackends, convert_data_type, convert_to_tensor, ensure_tuple, get_equivalent_dtype, look_up_option, + optional_import, ) from monai.utils.type_conversion import convert_to_dst_type +rankseg_fn, has_rankseg = optional_import("rankseg.functional", name="rankseg") + __all__ = [ "Activations", "AsDiscrete", @@ -142,6 +146,7 @@ class AsDiscrete(Transform): Convert the input tensor/array into discrete values, possible operations are: - `argmax`. + - `rankseg`. - threshold input value to binary values. - convert input value to One-Hot format (set ``to_one_hot=N``, `N` is the number of classes). - round the value to the closest integer. @@ -155,9 +160,17 @@ class AsDiscrete(Transform): Defaults to ``None``. rounding: if not None, round the data according to the specified option, available options: ["torchrounding"]. + rankseg: whether to apply RankSEG decoding. Requires installing the optional ``rankseg`` package. + RankSEG is applied to a channel-first probability map for one image; ``dim`` identifies the + class/channel dimension and is moved to the front before decoding. For the common MONAI + post-processing input shape ``(C, *spatial)``, use the default ``dim=0``. + The output is a label map. With the default ``keepdim=True``, the output shape is ``(1, *spatial)``; + with ``keepdim=False``, it is ``(*spatial)``. The ``dim`` and ``keepdim`` shape handling is aligned + with ``argmax``. This option is incompatible with ``argmax=True``. + Defaults to ``False``. kwargs: additional parameters to `torch.argmax`, `monai.networks.one_hot`. - currently ``dim``, ``keepdim``, ``dtype`` are supported, unrecognized parameters will be ignored. - These default to ``0``, ``True``, ``torch.float`` respectively. + currently ``dim``, ``keepdim``, ``dtype``, and RankSEG ``metric`` are supported, unrecognized parameters + will be ignored. These default to ``0``, ``True``, ``torch.float``, and ``"dice"`` respectively. Example: @@ -173,6 +186,12 @@ class AsDiscrete(Transform): >>> print(transform(np.array([[[0.0, 1.0]], [[2.0, 3.0]]]))) # [[[0.0, 0.0]], [[1.0, 1.0]]] + RankSEG decoding requires the optional ``rankseg`` package: + + >>> transform = AsDiscrete(rankseg=True) + >>> print(transform(np.array([[[0.3, 0.6]], [[0.7, 0.4]]]))) + # [[[1.0, 1.0]]] + """ backend = [TransformBackends.TORCH] @@ -183,9 +202,13 @@ def __init__( to_onehot: int | None = None, threshold: float | None = None, rounding: str | None = None, + rankseg: bool = False, **kwargs, ) -> None: + if argmax and rankseg: + raise ValueError("`rankseg=True` is incompatible with `argmax=True`.") self.argmax = argmax + self.rankseg = rankseg if isinstance(to_onehot, bool): # for backward compatibility raise ValueError("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.") self.to_onehot = to_onehot @@ -200,6 +223,7 @@ def __call__( to_onehot: int | None = None, threshold: float | None = None, rounding: str | None = None, + rankseg: bool | None = None, ) -> NdarrayOrTensor: """ Args: @@ -211,6 +235,11 @@ def __call__( Defaults to ``self.to_onehot``. threshold: if not None, threshold the float values to int number 0 or 1 with specified threshold value. Defaults to ``self.threshold``. + rankseg: whether to apply RankSEG decoding. Requires installing the optional ``rankseg`` package. + Applies RankSEG to a channel-first probability map by default and uses the same ``dim`` and + ``keepdim`` shape handling as ``argmax``. The RankSEG ``metric`` can be specified in ``kwargs``. + This option is incompatible with ``argmax=True``. + Defaults to ``self.rankseg``. rounding: if not None, round the data according to the specified option, available options: ["torchrounding"]. @@ -220,9 +249,26 @@ def __call__( img = convert_to_tensor(img, track_meta=get_track_meta()) img_t, *_ = convert_data_type(img, torch.Tensor) argmax = self.argmax if argmax is None else argmax + rankseg = self.rankseg if rankseg is None else rankseg + + if argmax and rankseg: + raise ValueError("`rankseg=True` is incompatible with `argmax=True`.") + if argmax: img_t = torch.argmax(img_t, dim=self.kwargs.get("dim", 0), keepdim=self.kwargs.get("keepdim", True)) + if rankseg: + if not has_rankseg: + raise OptionalImportError("`rankseg=True` requires the `rankseg` package, but it is not installed.") + # Adjust shape to meet RankSEG's [B, C, *spatial] input requirement. + channel_dim = self.kwargs.get("dim", 0) % img_t.ndim + keepdim = self.kwargs.get("keepdim", True) + img_t = rankseg_fn( + img_t.movedim(channel_dim, 0).unsqueeze(0), metric=self.kwargs.get("metric", "dice") + ).squeeze(0) + if keepdim: + img_t = img_t.unsqueeze(channel_dim) + to_onehot = self.to_onehot if to_onehot is None else to_onehot if to_onehot is not None: if not isinstance(to_onehot, int): diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 65fdd22b22..4bf3a53fb9 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -166,6 +166,7 @@ def __init__( to_onehot: Sequence[int | None] | int | None = None, threshold: Sequence[float | None] | float | None = None, rounding: Sequence[str | None] | str | None = None, + rankseg: Sequence[bool] | bool = False, allow_missing_keys: bool = False, **kwargs, ) -> None: @@ -182,14 +183,21 @@ def __init__( rounding: if not None, round the data according to the specified option, available options: ["torchrounding"]. it also can be a sequence of str or None, each element corresponds to a key in ``keys``. + rankseg: whether to apply RankSEG decoding. Requires installing the optional ``rankseg`` package. + RankSEG expects channel-first probability maps for one image. It also can be a sequence of bool, + each element corresponds to a key in ``keys``. Uses the same ``dim`` and ``keepdim`` shape handling + as ``argmax``. This option is incompatible with ``argmax=True``. allow_missing_keys: don't raise exception if key is missing. kwargs: additional parameters to ``AsDiscrete``. - ``dim``, ``keepdim``, ``dtype`` are supported, unrecognized parameters will be ignored. - These default to ``0``, ``True``, ``torch.float`` respectively. + ``dim``, ``keepdim``, ``dtype``, and RankSEG ``metric`` are supported, unrecognized parameters will + be ignored. These default to ``0``, ``True``, ``torch.float``, and ``"dice"`` respectively. """ super().__init__(keys, allow_missing_keys) self.argmax = ensure_tuple_rep(argmax, len(self.keys)) + self.rankseg = ensure_tuple_rep(rankseg, len(self.keys)) + if any(argmax_ and rankseg_ for argmax_, rankseg_ in zip(self.argmax, self.rankseg, strict=True)): + raise ValueError("`rankseg=True` is incompatible with `argmax=True`.") self.to_onehot = [] for flag in ensure_tuple_rep(to_onehot, len(self.keys)): if isinstance(flag, bool): @@ -208,10 +216,12 @@ def __init__( def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) - for key, argmax, to_onehot, threshold, rounding in self.key_iterator( - d, self.argmax, self.to_onehot, self.threshold, self.rounding + for key, argmax, to_onehot, threshold, rounding, rankseg in self.key_iterator( + d, self.argmax, self.to_onehot, self.threshold, self.rounding, self.rankseg ): - d[key] = self.converter(d[key], argmax, to_onehot, threshold, rounding) + d[key] = self.converter( + d[key], argmax=argmax, to_onehot=to_onehot, threshold=threshold, rounding=rounding, rankseg=rankseg + ) return d diff --git a/requirements-dev.txt b/requirements-dev.txt index eb4429cce7..19e26298f3 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -59,6 +59,7 @@ lpips==0.1.4 nvidia-ml-py huggingface_hub pyamg>=5.0.0, <5.3.0 +rankseg>=0.0.5 git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588 onnx_graphsurgeon polygraphy diff --git a/setup.cfg b/setup.cfg index 724d1eceb3..1d16af8728 100644 --- a/setup.cfg +++ b/setup.cfg @@ -90,6 +90,7 @@ all = nvidia-ml-py huggingface_hub pyamg>=5.0.0, <5.3.0 + rankseg>=0.0.5 nibabel = nibabel ninja = @@ -179,6 +180,8 @@ huggingface_hub = huggingface_hub pyamg = pyamg>=5.0.0, <5.3.0 +rankseg = + rankseg>=0.0.5 # segment-anything = # segment-anything @ git+https://github.com/facebookresearch/segment-anything@6fdee8f2727f4506cfbbe553e23b895e27956588#egg=segment-anything diff --git a/tests/transforms/test_as_discrete.py b/tests/transforms/test_as_discrete.py index a83870e514..92721642dd 100644 --- a/tests/transforms/test_as_discrete.py +++ b/tests/transforms/test_as_discrete.py @@ -12,10 +12,13 @@ from __future__ import annotations import unittest +from unittest import mock from parameterized import parameterized from monai.transforms import AsDiscrete +from monai.transforms.post import array as post_array +from monai.utils import OptionalImportError from tests.test_utils import TEST_NDARRAYS, assert_allclose TEST_CASES = [] @@ -63,6 +66,17 @@ [{"rounding": "torchrounding"}, p([[[0.123, 1.345], [2.567, 3.789]]]), p([[[0.0, 1.0], [3.0, 4.0]]]), (1, 2, 2)] ) + TEST_CASES.append( + [{"rankseg": False, "argmax": True}, p([[[0.3, 0.6]], [[0.7, 0.4]]]), p([[[1.0, 0.0]]]), (1, 1, 2)] + ) + + if post_array.has_rankseg: + TEST_CASES.append([{"rankseg": True}, p([[[0.3, 0.6]], [[0.7, 0.4]]]), p([[[1.0, 1.0]]]), (1, 1, 2)]) + TEST_CASES.append( + [{"rankseg": True, "metric": "iou"}, p([[[0.3, 0.6]], [[0.7, 0.4]]]), p([[[1.0, 1.0]]]), (1, 1, 2)] + ) + TEST_CASES.append([{"rankseg": True}, p([[[[0.3, 0.6]]], [[[0.7, 0.4]]]]), p([[[[1.0, 1.0]]]]), (1, 1, 1, 2)]) + class TestAsDiscrete(unittest.TestCase): @parameterized.expand(TEST_CASES) @@ -76,6 +90,18 @@ def test_additional(self): out = AsDiscrete(argmax=True, dim=1, keepdim=False)(p([[[0.0, 1.0]], [[2.0, 3.0]]])) assert_allclose(out, p([[0.0, 0.0], [0.0, 0.0]]), type_test=False) + def test_rankseg_argmax_incompatible(self): + with self.assertRaises(ValueError): + AsDiscrete(argmax=True, rankseg=True) + + with self.assertRaises(ValueError): + AsDiscrete(argmax=True)([[[0.3, 0.6]], [[0.7, 0.4]]], rankseg=True) + + def test_rankseg_missing_dependency(self): + with mock.patch("monai.transforms.post.array.has_rankseg", False): + with self.assertRaises(OptionalImportError): + AsDiscrete(rankseg=True)([[[0.3, 0.6]], [[0.7, 0.4]]]) + if __name__ == "__main__": unittest.main() diff --git a/tests/transforms/test_as_discreted.py b/tests/transforms/test_as_discreted.py index 3c29e820d0..e1ad06c17d 100644 --- a/tests/transforms/test_as_discreted.py +++ b/tests/transforms/test_as_discreted.py @@ -12,10 +12,13 @@ from __future__ import annotations import unittest +from unittest import mock from parameterized import parameterized from monai.transforms import AsDiscreted +from monai.transforms.post import array as post_array +from monai.utils import OptionalImportError from tests.test_utils import TEST_NDARRAYS, assert_allclose TEST_CASES = [] @@ -66,6 +69,52 @@ ] ) + TEST_CASES.append( + [ + {"keys": "pred", "rankseg": False, "argmax": True}, + {"pred": p([[[0.3, 0.6]], [[0.7, 0.4]]])}, + {"pred": p([[[1.0, 0.0]]])}, + (1, 1, 2), + ] + ) + + if post_array.has_rankseg: + TEST_CASES.append( + [ + {"keys": "pred", "rankseg": True}, + {"pred": p([[[0.3, 0.6]], [[0.7, 0.4]]])}, + {"pred": p([[[1.0, 1.0]]])}, + (1, 1, 2), + ] + ) + + TEST_CASES.append( + [ + {"keys": ["pred", "label"], "rankseg": [True, False]}, + {"pred": p([[[0.3, 0.6]], [[0.7, 0.4]]]), "label": p([[[0.0, 1.0]]])}, + {"pred": p([[[1.0, 1.0]]]), "label": p([[[0.0, 1.0]]])}, + (1, 1, 2), + ] + ) + + TEST_CASES.append( + [ + {"keys": "pred", "rankseg": True, "metric": "iou"}, + {"pred": p([[[0.3, 0.6]], [[0.7, 0.4]]])}, + {"pred": p([[[1.0, 1.0]]])}, + (1, 1, 2), + ] + ) + + TEST_CASES.append( + [ + {"keys": "pred", "rankseg": True}, + {"pred": p([[[[0.3, 0.6]]], [[[0.7, 0.4]]]])}, + {"pred": p([[[[1.0, 1.0]]]])}, + (1, 1, 1, 2), + ] + ) + class TestAsDiscreted(unittest.TestCase): @parameterized.expand(TEST_CASES) @@ -77,6 +126,15 @@ def test_value_shape(self, input_param, test_input, output, expected_shape): assert_allclose(result["label"], output["label"], rtol=1e-3, type_test="tensor") self.assertTupleEqual(result["label"].shape, expected_shape) + def test_rankseg_argmax_incompatible(self): + with self.assertRaises(ValueError): + AsDiscreted(keys="pred", argmax=True, rankseg=True)({"pred": [[[0.3, 0.6]], [[0.7, 0.4]]]}) + + def test_rankseg_missing_dependency(self): + with mock.patch("monai.transforms.post.array.has_rankseg", False): + with self.assertRaises(OptionalImportError): + AsDiscreted(keys="pred", rankseg=True)({"pred": [[[0.3, 0.6]], [[0.7, 0.4]]]}) + if __name__ == "__main__": unittest.main()