From 45510fd5a072a9c6df996197427c7dfffbaa4365 Mon Sep 17 00:00:00 2001 From: Zixun Wang Date: Sat, 30 May 2026 15:58:00 +0800 Subject: [PATCH 1/8] Add optional RankSEG decoding to AsDiscrete Signed-off-by: Zixun Wang --- docs/source/installation.md | 4 ++-- monai/transforms/post/array.py | 30 +++++++++++++++++++++++++++++ monai/transforms/post/dictionary.py | 13 ++++++++++--- requirements-dev.txt | 1 + setup.cfg | 3 +++ 5 files changed, 46 insertions(+), 5 deletions(-) diff --git a/docs/source/installation.md b/docs/source/installation.md index 5123bc3e6b..42163c8d6e 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -254,10 +254,10 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is - The options are ``` -[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr, lpips, pynvml, huggingface_hub] +[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr, lpips, pynvml, huggingface_hub, rankseg] ``` which correspond to `nibabel`, `scikit-image`,`scipy`, `pillow`, `tensorboard`, -`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, `zarr`, `lpips`, `nvidia-ml-py`, `huggingface_hub` and `pyamg` respectively. +`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, `zarr`, `lpips`, `nvidia-ml-py`, `huggingface_hub`, `pyamg` and `rankseg` respectively. - `pip install 'monai[all]'` installs all the optional dependencies. diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 47623b748d..8f605eb535 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -39,15 +39,19 @@ ) from monai.transforms.utils_pytorch_numpy_unification import unravel_index from monai.utils import ( + OptionalImportError, TransformBackends, convert_data_type, convert_to_tensor, ensure_tuple, get_equivalent_dtype, look_up_option, + optional_import, ) from monai.utils.type_conversion import convert_to_dst_type +rankseg_module, has_rankseg = optional_import("rankseg") + __all__ = [ "Activations", "AsDiscrete", @@ -142,6 +146,7 @@ class AsDiscrete(Transform): Convert the input tensor/array into discrete values, possible operations are: - `argmax`. + - `rankseg`. - threshold input value to binary values. - convert input value to One-Hot format (set ``to_one_hot=N``, `N` is the number of classes). - round the value to the closest integer. @@ -155,6 +160,9 @@ class AsDiscrete(Transform): Defaults to ``None``. rounding: if not None, round the data according to the specified option, available options: ["torchrounding"]. + rankseg: whether to apply RankSEG decoding. Requires the optional ``rankseg`` package. + RankSEG expects channel-first probability maps and returns a label map. + Defaults to ``False``. kwargs: additional parameters to `torch.argmax`, `monai.networks.one_hot`. currently ``dim``, ``keepdim``, ``dtype`` are supported, unrecognized parameters will be ignored. These default to ``0``, ``True``, ``torch.float`` respectively. @@ -183,9 +191,14 @@ def __init__( to_onehot: int | None = None, threshold: float | None = None, rounding: str | None = None, + rankseg: bool = False, **kwargs, ) -> None: + if argmax and rankseg: + raise ValueError("`rankseg=True` is incompatible with `argmax=True`.") self.argmax = argmax + self.rankseg = rankseg + self._rankseg_decoder = None if isinstance(to_onehot, bool): # for backward compatibility raise ValueError("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.") self.to_onehot = to_onehot @@ -200,6 +213,7 @@ def __call__( to_onehot: int | None = None, threshold: float | None = None, rounding: str | None = None, + rankseg: bool | None = None, ) -> NdarrayOrTensor: """ Args: @@ -211,6 +225,9 @@ def __call__( Defaults to ``self.to_onehot``. threshold: if not None, threshold the float values to int number 0 or 1 with specified threshold value. Defaults to ``self.threshold``. + rankseg: whether to apply RankSEG decoding. Requires the optional ``rankseg`` package. + RankSEG expects channel-first probability maps and returns a label map. + Defaults to ``self.rankseg``. rounding: if not None, round the data according to the specified option, available options: ["torchrounding"]. @@ -220,9 +237,22 @@ def __call__( img = convert_to_tensor(img, track_meta=get_track_meta()) img_t, *_ = convert_data_type(img, torch.Tensor) argmax = self.argmax if argmax is None else argmax + rankseg = self.rankseg if rankseg is None else rankseg + + if argmax and rankseg: + raise ValueError("`rankseg=True` is incompatible with `argmax=True`.") + if argmax: img_t = torch.argmax(img_t, dim=self.kwargs.get("dim", 0), keepdim=self.kwargs.get("keepdim", True)) + if rankseg: + if not has_rankseg: + raise OptionalImportError("`rankseg=True` requires the `rankseg` package, but it is not installed.") + if self._rankseg_decoder is None: + self._rankseg_decoder = rankseg_module.RankSEG() + # RankSEG expects a batch dimension. + img_t = self._rankseg_decoder.predict(img_t.unsqueeze(0)).squeeze(0) + to_onehot = self.to_onehot if to_onehot is None else to_onehot if to_onehot is not None: if not isinstance(to_onehot, int): diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 65fdd22b22..669a60bff5 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -166,6 +166,7 @@ def __init__( to_onehot: Sequence[int | None] | int | None = None, threshold: Sequence[float | None] | float | None = None, rounding: Sequence[str | None] | str | None = None, + rankseg: Sequence[bool] | bool = False, allow_missing_keys: bool = False, **kwargs, ) -> None: @@ -182,6 +183,9 @@ def __init__( rounding: if not None, round the data according to the specified option, available options: ["torchrounding"]. it also can be a sequence of str or None, each element corresponds to a key in ``keys``. + rankseg: whether to apply RankSEG decoding. Requires the optional ``rankseg`` package. + RankSEG expects channel-first probability maps and returns a label map. It also can be + a sequence of bool, each element corresponds to a key in ``keys``. allow_missing_keys: don't raise exception if key is missing. kwargs: additional parameters to ``AsDiscrete``. ``dim``, ``keepdim``, ``dtype`` are supported, unrecognized parameters will be ignored. @@ -203,15 +207,18 @@ def __init__( self.threshold.append(flag) self.rounding = ensure_tuple_rep(rounding, len(self.keys)) + self.rankseg = ensure_tuple_rep(rankseg, len(self.keys)) self.converter = AsDiscrete() self.converter.kwargs = kwargs def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) - for key, argmax, to_onehot, threshold, rounding in self.key_iterator( - d, self.argmax, self.to_onehot, self.threshold, self.rounding + for key, argmax, to_onehot, threshold, rounding, rankseg in self.key_iterator( + d, self.argmax, self.to_onehot, self.threshold, self.rounding, self.rankseg ): - d[key] = self.converter(d[key], argmax, to_onehot, threshold, rounding) + d[key] = self.converter( + d[key], argmax=argmax, to_onehot=to_onehot, threshold=threshold, rounding=rounding, rankseg=rankseg + ) return d diff --git a/requirements-dev.txt b/requirements-dev.txt index eb4429cce7..4aebc49398 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -59,6 +59,7 @@ lpips==0.1.4 nvidia-ml-py huggingface_hub pyamg>=5.0.0, <5.3.0 +rankseg>=0.0.4 git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588 onnx_graphsurgeon polygraphy diff --git a/setup.cfg b/setup.cfg index 724d1eceb3..42d8cbe9e8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -90,6 +90,7 @@ all = nvidia-ml-py huggingface_hub pyamg>=5.0.0, <5.3.0 + rankseg>=0.0.4 nibabel = nibabel ninja = @@ -179,6 +180,8 @@ huggingface_hub = huggingface_hub pyamg = pyamg>=5.0.0, <5.3.0 +rankseg = + rankseg>=0.0.4 # segment-anything = # segment-anything @ git+https://github.com/facebookresearch/segment-anything@6fdee8f2727f4506cfbbe553e23b895e27956588#egg=segment-anything From 5b5dfff74abe2100e7433603357e658023c66892 Mon Sep 17 00:00:00 2001 From: Zixun Wang Date: Thu, 4 Jun 2026 17:40:43 +0800 Subject: [PATCH 2/8] Switch RankSEG class decoding to functional API Signed-off-by: Zixun Wang --- monai/transforms/post/array.py | 7 ++----- requirements-dev.txt | 2 +- setup.cfg | 4 ++-- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 8f605eb535..8e548f60e8 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -50,7 +50,7 @@ ) from monai.utils.type_conversion import convert_to_dst_type -rankseg_module, has_rankseg = optional_import("rankseg") +rankseg_fn, has_rankseg = optional_import("rankseg.functional", name="rankseg") __all__ = [ "Activations", @@ -198,7 +198,6 @@ def __init__( raise ValueError("`rankseg=True` is incompatible with `argmax=True`.") self.argmax = argmax self.rankseg = rankseg - self._rankseg_decoder = None if isinstance(to_onehot, bool): # for backward compatibility raise ValueError("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.") self.to_onehot = to_onehot @@ -248,10 +247,8 @@ def __call__( if rankseg: if not has_rankseg: raise OptionalImportError("`rankseg=True` requires the `rankseg` package, but it is not installed.") - if self._rankseg_decoder is None: - self._rankseg_decoder = rankseg_module.RankSEG() # RankSEG expects a batch dimension. - img_t = self._rankseg_decoder.predict(img_t.unsqueeze(0)).squeeze(0) + img_t = rankseg_fn(img_t.unsqueeze(0)).squeeze(0) to_onehot = self.to_onehot if to_onehot is None else to_onehot if to_onehot is not None: diff --git a/requirements-dev.txt b/requirements-dev.txt index 4aebc49398..19e26298f3 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -59,7 +59,7 @@ lpips==0.1.4 nvidia-ml-py huggingface_hub pyamg>=5.0.0, <5.3.0 -rankseg>=0.0.4 +rankseg>=0.0.5 git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588 onnx_graphsurgeon polygraphy diff --git a/setup.cfg b/setup.cfg index 42d8cbe9e8..1d16af8728 100644 --- a/setup.cfg +++ b/setup.cfg @@ -90,7 +90,7 @@ all = nvidia-ml-py huggingface_hub pyamg>=5.0.0, <5.3.0 - rankseg>=0.0.4 + rankseg>=0.0.5 nibabel = nibabel ninja = @@ -181,7 +181,7 @@ huggingface_hub = pyamg = pyamg>=5.0.0, <5.3.0 rankseg = - rankseg>=0.0.4 + rankseg>=0.0.5 # segment-anything = # segment-anything @ git+https://github.com/facebookresearch/segment-anything@6fdee8f2727f4506cfbbe553e23b895e27956588#egg=segment-anything From c15c03a20cde65f383f138fc774ab18b5424cf90 Mon Sep 17 00:00:00 2001 From: LI Junxing <83260380+Leev1s@users.noreply.github.com> Date: Tue, 9 Jun 2026 00:15:41 +0800 Subject: [PATCH 3/8] Add RankSEG post-transform tests Signed-off-by: LI Junxing <83260380+Leev1s@users.noreply.github.com> --- tests/transforms/test_as_discrete.py | 34 +++++++++++++++++++ tests/transforms/test_as_discreted.py | 48 +++++++++++++++++++++++++++ 2 files changed, 82 insertions(+) diff --git a/tests/transforms/test_as_discrete.py b/tests/transforms/test_as_discrete.py index a83870e514..10c124f290 100644 --- a/tests/transforms/test_as_discrete.py +++ b/tests/transforms/test_as_discrete.py @@ -12,10 +12,13 @@ from __future__ import annotations import unittest +from unittest import mock from parameterized import parameterized from monai.transforms import AsDiscrete +from monai.transforms.post import array as post_array +from monai.utils import OptionalImportError from tests.test_utils import TEST_NDARRAYS, assert_allclose TEST_CASES = [] @@ -63,6 +66,25 @@ [{"rounding": "torchrounding"}, p([[[0.123, 1.345], [2.567, 3.789]]]), p([[[0.0, 1.0], [3.0, 4.0]]]), (1, 2, 2)] ) + TEST_CASES.append( + [ + {"rankseg": False, "argmax": True}, + p([[[0.45, 0.55]], [[0.55, 0.45]]]), + p([[[1.0, 0.0]]]), + (1, 1, 2), + ] + ) + + if post_array.has_rankseg: + TEST_CASES.append( + [ + {"rankseg": True}, + p([[[0.45, 0.55]], [[0.55, 0.45]]]), + p([[1.0, 1.0]]), + (1, 2), + ] + ) + class TestAsDiscrete(unittest.TestCase): @parameterized.expand(TEST_CASES) @@ -76,6 +98,18 @@ def test_additional(self): out = AsDiscrete(argmax=True, dim=1, keepdim=False)(p([[[0.0, 1.0]], [[2.0, 3.0]]])) assert_allclose(out, p([[0.0, 0.0], [0.0, 0.0]]), type_test=False) + def test_rankseg_argmax_incompatible(self): + with self.assertRaises(ValueError): + AsDiscrete(argmax=True, rankseg=True) + + with self.assertRaises(ValueError): + AsDiscrete(argmax=True)([[[0.45, 0.55]], [[0.55, 0.45]]], rankseg=True) + + def test_rankseg_missing_dependency(self): + with mock.patch("monai.transforms.post.array.has_rankseg", False): + with self.assertRaises(OptionalImportError): + AsDiscrete(rankseg=True)([[[0.45, 0.55]], [[0.55, 0.45]]]) + if __name__ == "__main__": unittest.main() diff --git a/tests/transforms/test_as_discreted.py b/tests/transforms/test_as_discreted.py index 3c29e820d0..a3f6ef3689 100644 --- a/tests/transforms/test_as_discreted.py +++ b/tests/transforms/test_as_discreted.py @@ -12,10 +12,13 @@ from __future__ import annotations import unittest +from unittest import mock from parameterized import parameterized from monai.transforms import AsDiscreted +from monai.transforms.post import array as post_array +from monai.utils import OptionalImportError from tests.test_utils import TEST_NDARRAYS, assert_allclose TEST_CASES = [] @@ -66,6 +69,40 @@ ] ) + TEST_CASES.append( + [ + {"keys": "pred", "rankseg": False, "argmax": True}, + {"pred": p([[[0.45, 0.55]], [[0.55, 0.45]]])}, + {"pred": p([[[1.0, 0.0]]])}, + (1, 1, 2), + ] + ) + + if post_array.has_rankseg: + TEST_CASES.append( + [ + {"keys": "pred", "rankseg": True}, + {"pred": p([[[0.45, 0.55]], [[0.55, 0.45]]])}, + {"pred": p([[1.0, 1.0]])}, + (1, 2), + ] + ) + + TEST_CASES.append( + [ + {"keys": ["pred", "label"], "rankseg": [True, False]}, + { + "pred": p([[[0.45, 0.55]], [[0.55, 0.45]]]), + "label": p([[0.0, 1.0]]), + }, + { + "pred": p([[1.0, 1.0]]), + "label": p([[0.0, 1.0]]), + }, + (1, 2), + ] + ) + class TestAsDiscreted(unittest.TestCase): @parameterized.expand(TEST_CASES) @@ -77,6 +114,17 @@ def test_value_shape(self, input_param, test_input, output, expected_shape): assert_allclose(result["label"], output["label"], rtol=1e-3, type_test="tensor") self.assertTupleEqual(result["label"].shape, expected_shape) + def test_rankseg_argmax_incompatible(self): + with self.assertRaises(ValueError): + AsDiscreted(keys="pred", argmax=True, rankseg=True)( + {"pred": [[[0.45, 0.55]], [[0.55, 0.45]]]} + ) + + def test_rankseg_missing_dependency(self): + with mock.patch("monai.transforms.post.array.has_rankseg", False): + with self.assertRaises(OptionalImportError): + AsDiscreted(keys="pred", rankseg=True)({"pred": [[[0.45, 0.55]], [[0.55, 0.45]]]}) + if __name__ == "__main__": unittest.main() From 3bc536f094d40897bba0739959a6715affd967e4 Mon Sep 17 00:00:00 2001 From: Zixun Wang Date: Mon, 8 Jun 2026 22:05:52 +0800 Subject: [PATCH 4/8] Align RankSEG output shape with argmax Signed-off-by: Zixun Wang --- monai/transforms/post/array.py | 8 ++++++-- tests/transforms/test_as_discrete.py | 4 ++-- tests/transforms/test_as_discreted.py | 12 ++++++------ 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 8e548f60e8..3c474e7b8c 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -247,8 +247,12 @@ def __call__( if rankseg: if not has_rankseg: raise OptionalImportError("`rankseg=True` requires the `rankseg` package, but it is not installed.") - # RankSEG expects a batch dimension. - img_t = rankseg_fn(img_t.unsqueeze(0)).squeeze(0) + # Adjust shape to meet RankSEG's [B, C, *spatial] input requirement. + channel_dim = self.kwargs.get("dim", 0) % img_t.ndim + keepdim = self.kwargs.get("keepdim", True) + img_t = rankseg_fn(img_t.movedim(channel_dim, 0).unsqueeze(0)).squeeze(0) + if keepdim: + img_t = img_t.unsqueeze(channel_dim) to_onehot = self.to_onehot if to_onehot is None else to_onehot if to_onehot is not None: diff --git a/tests/transforms/test_as_discrete.py b/tests/transforms/test_as_discrete.py index 10c124f290..98b4df86e5 100644 --- a/tests/transforms/test_as_discrete.py +++ b/tests/transforms/test_as_discrete.py @@ -80,8 +80,8 @@ [ {"rankseg": True}, p([[[0.45, 0.55]], [[0.55, 0.45]]]), - p([[1.0, 1.0]]), - (1, 2), + p([[[1.0, 1.0]]]), + (1, 1, 2), ] ) diff --git a/tests/transforms/test_as_discreted.py b/tests/transforms/test_as_discreted.py index a3f6ef3689..b78be92124 100644 --- a/tests/transforms/test_as_discreted.py +++ b/tests/transforms/test_as_discreted.py @@ -83,8 +83,8 @@ [ {"keys": "pred", "rankseg": True}, {"pred": p([[[0.45, 0.55]], [[0.55, 0.45]]])}, - {"pred": p([[1.0, 1.0]])}, - (1, 2), + {"pred": p([[[1.0, 1.0]]])}, + (1, 1, 2), ] ) @@ -93,13 +93,13 @@ {"keys": ["pred", "label"], "rankseg": [True, False]}, { "pred": p([[[0.45, 0.55]], [[0.55, 0.45]]]), - "label": p([[0.0, 1.0]]), + "label": p([[[0.0, 1.0]]]), }, { - "pred": p([[1.0, 1.0]]), - "label": p([[0.0, 1.0]]), + "pred": p([[[1.0, 1.0]]]), + "label": p([[[0.0, 1.0]]]), }, - (1, 2), + (1, 1, 2), ] ) From 5b2e38707753359d487c079a9558897bee6f3205 Mon Sep 17 00:00:00 2001 From: statmlben Date: Tue, 9 Jun 2026 11:24:06 +0800 Subject: [PATCH 5/8] Update RankSEG docs for post transforms Signed-off-by: statmlben --- monai/transforms/post/array.py | 20 ++++++++++++++++---- monai/transforms/post/dictionary.py | 7 ++++--- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 3c474e7b8c..58168192f5 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -160,8 +160,13 @@ class AsDiscrete(Transform): Defaults to ``None``. rounding: if not None, round the data according to the specified option, available options: ["torchrounding"]. - rankseg: whether to apply RankSEG decoding. Requires the optional ``rankseg`` package. - RankSEG expects channel-first probability maps and returns a label map. + rankseg: whether to apply RankSEG decoding. Requires installing the optional ``rankseg`` package. + RankSEG is applied to a channel-first probability map for one image; ``dim`` identifies the + class/channel dimension and is moved to the front before decoding. For the common MONAI + post-processing input shape ``(C, *spatial)``, use the default ``dim=0``. + The output is a label map. With the default ``keepdim=True``, the output shape is ``(1, *spatial)``; + with ``keepdim=False``, it is ``(*spatial)``. The ``dim`` and ``keepdim`` shape handling is aligned + with ``argmax``. This option is incompatible with ``argmax=True``. Defaults to ``False``. kwargs: additional parameters to `torch.argmax`, `monai.networks.one_hot`. currently ``dim``, ``keepdim``, ``dtype`` are supported, unrecognized parameters will be ignored. @@ -181,6 +186,12 @@ class AsDiscrete(Transform): >>> print(transform(np.array([[[0.0, 1.0]], [[2.0, 3.0]]]))) # [[[0.0, 0.0]], [[1.0, 1.0]]] + RankSEG decoding requires the optional ``rankseg`` package: + + >>> transform = AsDiscrete(rankseg=True) + >>> print(transform(np.array([[[0.3, 0.6]], [[0.7, 0.4]]]))) + # [[[1.0, 1.0]]] + """ backend = [TransformBackends.TORCH] @@ -224,8 +235,9 @@ def __call__( Defaults to ``self.to_onehot``. threshold: if not None, threshold the float values to int number 0 or 1 with specified threshold value. Defaults to ``self.threshold``. - rankseg: whether to apply RankSEG decoding. Requires the optional ``rankseg`` package. - RankSEG expects channel-first probability maps and returns a label map. + rankseg: whether to apply RankSEG decoding. Requires installing the optional ``rankseg`` package. + Applies RankSEG to a channel-first probability map by default and uses the same ``dim`` and + ``keepdim`` shape handling as ``argmax``. This option is incompatible with ``argmax=True``. Defaults to ``self.rankseg``. rounding: if not None, round the data according to the specified option, available options: ["torchrounding"]. diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 669a60bff5..ff5e049279 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -183,9 +183,10 @@ def __init__( rounding: if not None, round the data according to the specified option, available options: ["torchrounding"]. it also can be a sequence of str or None, each element corresponds to a key in ``keys``. - rankseg: whether to apply RankSEG decoding. Requires the optional ``rankseg`` package. - RankSEG expects channel-first probability maps and returns a label map. It also can be - a sequence of bool, each element corresponds to a key in ``keys``. + rankseg: whether to apply RankSEG decoding. Requires installing the optional ``rankseg`` package. + RankSEG expects channel-first probability maps for one image. It also can be a sequence of bool, + each element corresponds to a key in ``keys``. Uses the same ``dim`` and ``keepdim`` shape handling + as ``argmax``. This option is incompatible with ``argmax=True``. allow_missing_keys: don't raise exception if key is missing. kwargs: additional parameters to ``AsDiscrete``. ``dim``, ``keepdim``, ``dtype`` are supported, unrecognized parameters will be ignored. From b10b7e17db45db6adc503ad782eee889d2d0fe36 Mon Sep 17 00:00:00 2001 From: Zixun Wang Date: Tue, 9 Jun 2026 13:33:26 +0800 Subject: [PATCH 6/8] Update RankSEG test cases Signed-off-by: Zixun Wang --- monai/transforms/post/dictionary.py | 4 +++- tests/transforms/test_as_discrete.py | 8 ++++---- tests/transforms/test_as_discreted.py | 10 +++++----- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index ff5e049279..a551449436 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -195,6 +195,9 @@ def __init__( """ super().__init__(keys, allow_missing_keys) self.argmax = ensure_tuple_rep(argmax, len(self.keys)) + self.rankseg = ensure_tuple_rep(rankseg, len(self.keys)) + if any(argmax_ and rankseg_ for argmax_, rankseg_ in zip(self.argmax, self.rankseg)): + raise ValueError("`rankseg=True` is incompatible with `argmax=True`.") self.to_onehot = [] for flag in ensure_tuple_rep(to_onehot, len(self.keys)): if isinstance(flag, bool): @@ -208,7 +211,6 @@ def __init__( self.threshold.append(flag) self.rounding = ensure_tuple_rep(rounding, len(self.keys)) - self.rankseg = ensure_tuple_rep(rankseg, len(self.keys)) self.converter = AsDiscrete() self.converter.kwargs = kwargs diff --git a/tests/transforms/test_as_discrete.py b/tests/transforms/test_as_discrete.py index 98b4df86e5..f55bbc9589 100644 --- a/tests/transforms/test_as_discrete.py +++ b/tests/transforms/test_as_discrete.py @@ -69,7 +69,7 @@ TEST_CASES.append( [ {"rankseg": False, "argmax": True}, - p([[[0.45, 0.55]], [[0.55, 0.45]]]), + p([[[0.3, 0.6]], [[0.7, 0.4]]]), p([[[1.0, 0.0]]]), (1, 1, 2), ] @@ -79,7 +79,7 @@ TEST_CASES.append( [ {"rankseg": True}, - p([[[0.45, 0.55]], [[0.55, 0.45]]]), + p([[[0.3, 0.6]], [[0.7, 0.4]]]), p([[[1.0, 1.0]]]), (1, 1, 2), ] @@ -103,12 +103,12 @@ def test_rankseg_argmax_incompatible(self): AsDiscrete(argmax=True, rankseg=True) with self.assertRaises(ValueError): - AsDiscrete(argmax=True)([[[0.45, 0.55]], [[0.55, 0.45]]], rankseg=True) + AsDiscrete(argmax=True)([[[0.3, 0.6]], [[0.7, 0.4]]], rankseg=True) def test_rankseg_missing_dependency(self): with mock.patch("monai.transforms.post.array.has_rankseg", False): with self.assertRaises(OptionalImportError): - AsDiscrete(rankseg=True)([[[0.45, 0.55]], [[0.55, 0.45]]]) + AsDiscrete(rankseg=True)([[[0.3, 0.6]], [[0.7, 0.4]]]) if __name__ == "__main__": diff --git a/tests/transforms/test_as_discreted.py b/tests/transforms/test_as_discreted.py index b78be92124..4365fde01d 100644 --- a/tests/transforms/test_as_discreted.py +++ b/tests/transforms/test_as_discreted.py @@ -72,7 +72,7 @@ TEST_CASES.append( [ {"keys": "pred", "rankseg": False, "argmax": True}, - {"pred": p([[[0.45, 0.55]], [[0.55, 0.45]]])}, + {"pred": p([[[0.3, 0.6]], [[0.7, 0.4]]])}, {"pred": p([[[1.0, 0.0]]])}, (1, 1, 2), ] @@ -82,7 +82,7 @@ TEST_CASES.append( [ {"keys": "pred", "rankseg": True}, - {"pred": p([[[0.45, 0.55]], [[0.55, 0.45]]])}, + {"pred": p([[[0.3, 0.6]], [[0.7, 0.4]]])}, {"pred": p([[[1.0, 1.0]]])}, (1, 1, 2), ] @@ -92,7 +92,7 @@ [ {"keys": ["pred", "label"], "rankseg": [True, False]}, { - "pred": p([[[0.45, 0.55]], [[0.55, 0.45]]]), + "pred": p([[[0.3, 0.6]], [[0.7, 0.4]]]), "label": p([[[0.0, 1.0]]]), }, { @@ -117,13 +117,13 @@ def test_value_shape(self, input_param, test_input, output, expected_shape): def test_rankseg_argmax_incompatible(self): with self.assertRaises(ValueError): AsDiscreted(keys="pred", argmax=True, rankseg=True)( - {"pred": [[[0.45, 0.55]], [[0.55, 0.45]]]} + {"pred": [[[0.3, 0.6]], [[0.7, 0.4]]]} ) def test_rankseg_missing_dependency(self): with mock.patch("monai.transforms.post.array.has_rankseg", False): with self.assertRaises(OptionalImportError): - AsDiscreted(keys="pred", rankseg=True)({"pred": [[[0.45, 0.55]], [[0.55, 0.45]]]}) + AsDiscreted(keys="pred", rankseg=True)({"pred": [[[0.3, 0.6]], [[0.7, 0.4]]]}) if __name__ == "__main__": From 1866926bc8212a1a7906b88ec3030c2d64ee2980 Mon Sep 17 00:00:00 2001 From: Zixun Wang Date: Tue, 9 Jun 2026 21:51:59 +0800 Subject: [PATCH 7/8] Apply codeformat fixes and tighten validation Signed-off-by: Zixun Wang --- monai/transforms/post/dictionary.py | 2 +- tests/transforms/test_as_discrete.py | 16 ++-------------- tests/transforms/test_as_discreted.py | 14 +++----------- 3 files changed, 6 insertions(+), 26 deletions(-) diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index a551449436..c975c9e75c 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -196,7 +196,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.argmax = ensure_tuple_rep(argmax, len(self.keys)) self.rankseg = ensure_tuple_rep(rankseg, len(self.keys)) - if any(argmax_ and rankseg_ for argmax_, rankseg_ in zip(self.argmax, self.rankseg)): + if any(argmax_ and rankseg_ for argmax_, rankseg_ in zip(self.argmax, self.rankseg, strict=True)): raise ValueError("`rankseg=True` is incompatible with `argmax=True`.") self.to_onehot = [] for flag in ensure_tuple_rep(to_onehot, len(self.keys)): diff --git a/tests/transforms/test_as_discrete.py b/tests/transforms/test_as_discrete.py index f55bbc9589..e3f70000ec 100644 --- a/tests/transforms/test_as_discrete.py +++ b/tests/transforms/test_as_discrete.py @@ -67,23 +67,11 @@ ) TEST_CASES.append( - [ - {"rankseg": False, "argmax": True}, - p([[[0.3, 0.6]], [[0.7, 0.4]]]), - p([[[1.0, 0.0]]]), - (1, 1, 2), - ] + [{"rankseg": False, "argmax": True}, p([[[0.3, 0.6]], [[0.7, 0.4]]]), p([[[1.0, 0.0]]]), (1, 1, 2)] ) if post_array.has_rankseg: - TEST_CASES.append( - [ - {"rankseg": True}, - p([[[0.3, 0.6]], [[0.7, 0.4]]]), - p([[[1.0, 1.0]]]), - (1, 1, 2), - ] - ) + TEST_CASES.append([{"rankseg": True}, p([[[0.3, 0.6]], [[0.7, 0.4]]]), p([[[1.0, 1.0]]]), (1, 1, 2)]) class TestAsDiscrete(unittest.TestCase): diff --git a/tests/transforms/test_as_discreted.py b/tests/transforms/test_as_discreted.py index 4365fde01d..9832efc82f 100644 --- a/tests/transforms/test_as_discreted.py +++ b/tests/transforms/test_as_discreted.py @@ -91,14 +91,8 @@ TEST_CASES.append( [ {"keys": ["pred", "label"], "rankseg": [True, False]}, - { - "pred": p([[[0.3, 0.6]], [[0.7, 0.4]]]), - "label": p([[[0.0, 1.0]]]), - }, - { - "pred": p([[[1.0, 1.0]]]), - "label": p([[[0.0, 1.0]]]), - }, + {"pred": p([[[0.3, 0.6]], [[0.7, 0.4]]]), "label": p([[[0.0, 1.0]]])}, + {"pred": p([[[1.0, 1.0]]]), "label": p([[[0.0, 1.0]]])}, (1, 1, 2), ] ) @@ -116,9 +110,7 @@ def test_value_shape(self, input_param, test_input, output, expected_shape): def test_rankseg_argmax_incompatible(self): with self.assertRaises(ValueError): - AsDiscreted(keys="pred", argmax=True, rankseg=True)( - {"pred": [[[0.3, 0.6]], [[0.7, 0.4]]]} - ) + AsDiscreted(keys="pred", argmax=True, rankseg=True)({"pred": [[[0.3, 0.6]], [[0.7, 0.4]]]}) def test_rankseg_missing_dependency(self): with mock.patch("monai.transforms.post.array.has_rankseg", False): From f08231b483f13c70f09cba848d5f6388ffb309fc Mon Sep 17 00:00:00 2001 From: Zixun Wang Date: Wed, 10 Jun 2026 12:09:09 +0800 Subject: [PATCH 8/8] Add 3D test cases and expose RankSEG metric argument Signed-off-by: Zixun Wang --- monai/transforms/post/array.py | 11 +++++++---- monai/transforms/post/dictionary.py | 4 ++-- tests/transforms/test_as_discrete.py | 4 ++++ tests/transforms/test_as_discreted.py | 18 ++++++++++++++++++ 4 files changed, 31 insertions(+), 6 deletions(-) diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 58168192f5..90d3474f86 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -169,8 +169,8 @@ class AsDiscrete(Transform): with ``argmax``. This option is incompatible with ``argmax=True``. Defaults to ``False``. kwargs: additional parameters to `torch.argmax`, `monai.networks.one_hot`. - currently ``dim``, ``keepdim``, ``dtype`` are supported, unrecognized parameters will be ignored. - These default to ``0``, ``True``, ``torch.float`` respectively. + currently ``dim``, ``keepdim``, ``dtype``, and RankSEG ``metric`` are supported, unrecognized parameters + will be ignored. These default to ``0``, ``True``, ``torch.float``, and ``"dice"`` respectively. Example: @@ -237,7 +237,8 @@ def __call__( Defaults to ``self.threshold``. rankseg: whether to apply RankSEG decoding. Requires installing the optional ``rankseg`` package. Applies RankSEG to a channel-first probability map by default and uses the same ``dim`` and - ``keepdim`` shape handling as ``argmax``. This option is incompatible with ``argmax=True``. + ``keepdim`` shape handling as ``argmax``. The RankSEG ``metric`` can be specified in ``kwargs``. + This option is incompatible with ``argmax=True``. Defaults to ``self.rankseg``. rounding: if not None, round the data according to the specified option, available options: ["torchrounding"]. @@ -262,7 +263,9 @@ def __call__( # Adjust shape to meet RankSEG's [B, C, *spatial] input requirement. channel_dim = self.kwargs.get("dim", 0) % img_t.ndim keepdim = self.kwargs.get("keepdim", True) - img_t = rankseg_fn(img_t.movedim(channel_dim, 0).unsqueeze(0)).squeeze(0) + img_t = rankseg_fn( + img_t.movedim(channel_dim, 0).unsqueeze(0), metric=self.kwargs.get("metric", "dice") + ).squeeze(0) if keepdim: img_t = img_t.unsqueeze(channel_dim) diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index c975c9e75c..4bf3a53fb9 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -189,8 +189,8 @@ def __init__( as ``argmax``. This option is incompatible with ``argmax=True``. allow_missing_keys: don't raise exception if key is missing. kwargs: additional parameters to ``AsDiscrete``. - ``dim``, ``keepdim``, ``dtype`` are supported, unrecognized parameters will be ignored. - These default to ``0``, ``True``, ``torch.float`` respectively. + ``dim``, ``keepdim``, ``dtype``, and RankSEG ``metric`` are supported, unrecognized parameters will + be ignored. These default to ``0``, ``True``, ``torch.float``, and ``"dice"`` respectively. """ super().__init__(keys, allow_missing_keys) diff --git a/tests/transforms/test_as_discrete.py b/tests/transforms/test_as_discrete.py index e3f70000ec..92721642dd 100644 --- a/tests/transforms/test_as_discrete.py +++ b/tests/transforms/test_as_discrete.py @@ -72,6 +72,10 @@ if post_array.has_rankseg: TEST_CASES.append([{"rankseg": True}, p([[[0.3, 0.6]], [[0.7, 0.4]]]), p([[[1.0, 1.0]]]), (1, 1, 2)]) + TEST_CASES.append( + [{"rankseg": True, "metric": "iou"}, p([[[0.3, 0.6]], [[0.7, 0.4]]]), p([[[1.0, 1.0]]]), (1, 1, 2)] + ) + TEST_CASES.append([{"rankseg": True}, p([[[[0.3, 0.6]]], [[[0.7, 0.4]]]]), p([[[[1.0, 1.0]]]]), (1, 1, 1, 2)]) class TestAsDiscrete(unittest.TestCase): diff --git a/tests/transforms/test_as_discreted.py b/tests/transforms/test_as_discreted.py index 9832efc82f..e1ad06c17d 100644 --- a/tests/transforms/test_as_discreted.py +++ b/tests/transforms/test_as_discreted.py @@ -97,6 +97,24 @@ ] ) + TEST_CASES.append( + [ + {"keys": "pred", "rankseg": True, "metric": "iou"}, + {"pred": p([[[0.3, 0.6]], [[0.7, 0.4]]])}, + {"pred": p([[[1.0, 1.0]]])}, + (1, 1, 2), + ] + ) + + TEST_CASES.append( + [ + {"keys": "pred", "rankseg": True}, + {"pred": p([[[[0.3, 0.6]]], [[[0.7, 0.4]]]])}, + {"pred": p([[[[1.0, 1.0]]]])}, + (1, 1, 1, 2), + ] + ) + class TestAsDiscreted(unittest.TestCase): @parameterized.expand(TEST_CASES)