Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/source/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -254,10 +254,10 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is
- The options are

```
[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr, lpips, pynvml, huggingface_hub]
[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr, lpips, pynvml, huggingface_hub, rankseg]
```

which correspond to `nibabel`, `scikit-image`,`scipy`, `pillow`, `tensorboard`,
`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, `zarr`, `lpips`, `nvidia-ml-py`, `huggingface_hub` and `pyamg` respectively.
`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, `zarr`, `lpips`, `nvidia-ml-py`, `huggingface_hub`, `pyamg` and `rankseg` respectively.

- `pip install 'monai[all]'` installs all the optional dependencies.
50 changes: 48 additions & 2 deletions monai/transforms/post/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,19 @@
)
from monai.transforms.utils_pytorch_numpy_unification import unravel_index
from monai.utils import (
OptionalImportError,
TransformBackends,
convert_data_type,
convert_to_tensor,
ensure_tuple,
get_equivalent_dtype,
look_up_option,
optional_import,
)
from monai.utils.type_conversion import convert_to_dst_type

rankseg_fn, has_rankseg = optional_import("rankseg.functional", name="rankseg")

__all__ = [
"Activations",
"AsDiscrete",
Expand Down Expand Up @@ -142,6 +146,7 @@ class AsDiscrete(Transform):
Convert the input tensor/array into discrete values, possible operations are:

- `argmax`.
- `rankseg`.
- threshold input value to binary values.
- convert input value to One-Hot format (set ``to_one_hot=N``, `N` is the number of classes).
- round the value to the closest integer.
Expand All @@ -155,9 +160,17 @@ class AsDiscrete(Transform):
Defaults to ``None``.
rounding: if not None, round the data according to the specified option,
available options: ["torchrounding"].
rankseg: whether to apply RankSEG decoding. Requires installing the optional ``rankseg`` package.
RankSEG is applied to a channel-first probability map for one image; ``dim`` identifies the
class/channel dimension and is moved to the front before decoding. For the common MONAI
post-processing input shape ``(C, *spatial)``, use the default ``dim=0``.
The output is a label map. With the default ``keepdim=True``, the output shape is ``(1, *spatial)``;
with ``keepdim=False``, it is ``(*spatial)``. The ``dim`` and ``keepdim`` shape handling is aligned
with ``argmax``. This option is incompatible with ``argmax=True``.
Defaults to ``False``.
kwargs: additional parameters to `torch.argmax`, `monai.networks.one_hot`.
currently ``dim``, ``keepdim``, ``dtype`` are supported, unrecognized parameters will be ignored.
These default to ``0``, ``True``, ``torch.float`` respectively.
currently ``dim``, ``keepdim``, ``dtype``, and RankSEG ``metric`` are supported, unrecognized parameters
will be ignored. These default to ``0``, ``True``, ``torch.float``, and ``"dice"`` respectively.

Example:

Expand All @@ -173,6 +186,12 @@ class AsDiscrete(Transform):
>>> print(transform(np.array([[[0.0, 1.0]], [[2.0, 3.0]]])))
# [[[0.0, 0.0]], [[1.0, 1.0]]]

RankSEG decoding requires the optional ``rankseg`` package:

>>> transform = AsDiscrete(rankseg=True)
>>> print(transform(np.array([[[0.3, 0.6]], [[0.7, 0.4]]])))
# [[[1.0, 1.0]]]

"""

backend = [TransformBackends.TORCH]
Expand All @@ -183,9 +202,13 @@ def __init__(
to_onehot: int | None = None,
threshold: float | None = None,
rounding: str | None = None,
rankseg: bool = False,
**kwargs,
) -> None:
if argmax and rankseg:
raise ValueError("`rankseg=True` is incompatible with `argmax=True`.")
self.argmax = argmax
self.rankseg = rankseg
if isinstance(to_onehot, bool): # for backward compatibility
raise ValueError("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.")
self.to_onehot = to_onehot
Expand All @@ -200,6 +223,7 @@ def __call__(
to_onehot: int | None = None,
threshold: float | None = None,
rounding: str | None = None,
rankseg: bool | None = None,
) -> NdarrayOrTensor:
"""
Args:
Expand All @@ -211,6 +235,11 @@ def __call__(
Defaults to ``self.to_onehot``.
threshold: if not None, threshold the float values to int number 0 or 1 with specified threshold value.
Defaults to ``self.threshold``.
rankseg: whether to apply RankSEG decoding. Requires installing the optional ``rankseg`` package.
Applies RankSEG to a channel-first probability map by default and uses the same ``dim`` and
``keepdim`` shape handling as ``argmax``. The RankSEG ``metric`` can be specified in ``kwargs``.
This option is incompatible with ``argmax=True``.
Defaults to ``self.rankseg``.
rounding: if not None, round the data according to the specified option,
available options: ["torchrounding"].

Expand All @@ -220,9 +249,26 @@ def __call__(
img = convert_to_tensor(img, track_meta=get_track_meta())
img_t, *_ = convert_data_type(img, torch.Tensor)
argmax = self.argmax if argmax is None else argmax
rankseg = self.rankseg if rankseg is None else rankseg

if argmax and rankseg:
raise ValueError("`rankseg=True` is incompatible with `argmax=True`.")

if argmax:
img_t = torch.argmax(img_t, dim=self.kwargs.get("dim", 0), keepdim=self.kwargs.get("keepdim", True))

if rankseg:
if not has_rankseg:
raise OptionalImportError("`rankseg=True` requires the `rankseg` package, but it is not installed.")
# Adjust shape to meet RankSEG's [B, C, *spatial] input requirement.
channel_dim = self.kwargs.get("dim", 0) % img_t.ndim
keepdim = self.kwargs.get("keepdim", True)
img_t = rankseg_fn(
img_t.movedim(channel_dim, 0).unsqueeze(0), metric=self.kwargs.get("metric", "dice")
).squeeze(0)
if keepdim:
img_t = img_t.unsqueeze(channel_dim)

to_onehot = self.to_onehot if to_onehot is None else to_onehot
if to_onehot is not None:
if not isinstance(to_onehot, int):
Expand Down
20 changes: 15 additions & 5 deletions monai/transforms/post/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def __init__(
to_onehot: Sequence[int | None] | int | None = None,
threshold: Sequence[float | None] | float | None = None,
rounding: Sequence[str | None] | str | None = None,
rankseg: Sequence[bool] | bool = False,
allow_missing_keys: bool = False,
**kwargs,
) -> None:
Expand All @@ -182,14 +183,21 @@ def __init__(
rounding: if not None, round the data according to the specified option,
available options: ["torchrounding"]. it also can be a sequence of str or None,
each element corresponds to a key in ``keys``.
rankseg: whether to apply RankSEG decoding. Requires installing the optional ``rankseg`` package.
RankSEG expects channel-first probability maps for one image. It also can be a sequence of bool,
each element corresponds to a key in ``keys``. Uses the same ``dim`` and ``keepdim`` shape handling
as ``argmax``. This option is incompatible with ``argmax=True``.
allow_missing_keys: don't raise exception if key is missing.
kwargs: additional parameters to ``AsDiscrete``.
``dim``, ``keepdim``, ``dtype`` are supported, unrecognized parameters will be ignored.
These default to ``0``, ``True``, ``torch.float`` respectively.
``dim``, ``keepdim``, ``dtype``, and RankSEG ``metric`` are supported, unrecognized parameters will
be ignored. These default to ``0``, ``True``, ``torch.float``, and ``"dice"`` respectively.

"""
super().__init__(keys, allow_missing_keys)
self.argmax = ensure_tuple_rep(argmax, len(self.keys))
self.rankseg = ensure_tuple_rep(rankseg, len(self.keys))
if any(argmax_ and rankseg_ for argmax_, rankseg_ in zip(self.argmax, self.rankseg, strict=True)):
raise ValueError("`rankseg=True` is incompatible with `argmax=True`.")
self.to_onehot = []
for flag in ensure_tuple_rep(to_onehot, len(self.keys)):
if isinstance(flag, bool):
Expand All @@ -208,10 +216,12 @@ def __init__(

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key, argmax, to_onehot, threshold, rounding in self.key_iterator(
d, self.argmax, self.to_onehot, self.threshold, self.rounding
for key, argmax, to_onehot, threshold, rounding, rankseg in self.key_iterator(
d, self.argmax, self.to_onehot, self.threshold, self.rounding, self.rankseg
):
d[key] = self.converter(d[key], argmax, to_onehot, threshold, rounding)
d[key] = self.converter(
d[key], argmax=argmax, to_onehot=to_onehot, threshold=threshold, rounding=rounding, rankseg=rankseg
)
return d


Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ lpips==0.1.4
nvidia-ml-py
huggingface_hub
pyamg>=5.0.0, <5.3.0
rankseg>=0.0.5
git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588
onnx_graphsurgeon
polygraphy
Expand Down
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ all =
nvidia-ml-py
huggingface_hub
pyamg>=5.0.0, <5.3.0
rankseg>=0.0.5
nibabel =
nibabel
ninja =
Expand Down Expand Up @@ -179,6 +180,8 @@ huggingface_hub =
huggingface_hub
pyamg =
pyamg>=5.0.0, <5.3.0
rankseg =
rankseg>=0.0.5
# segment-anything =
# segment-anything @ git+https://github.com/facebookresearch/segment-anything@6fdee8f2727f4506cfbbe553e23b895e27956588#egg=segment-anything

Expand Down
26 changes: 26 additions & 0 deletions tests/transforms/test_as_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
from __future__ import annotations

import unittest
from unittest import mock

from parameterized import parameterized

from monai.transforms import AsDiscrete
from monai.transforms.post import array as post_array
from monai.utils import OptionalImportError
from tests.test_utils import TEST_NDARRAYS, assert_allclose

TEST_CASES = []
Expand Down Expand Up @@ -63,6 +66,17 @@
[{"rounding": "torchrounding"}, p([[[0.123, 1.345], [2.567, 3.789]]]), p([[[0.0, 1.0], [3.0, 4.0]]]), (1, 2, 2)]
)

TEST_CASES.append(
[{"rankseg": False, "argmax": True}, p([[[0.3, 0.6]], [[0.7, 0.4]]]), p([[[1.0, 0.0]]]), (1, 1, 2)]
)

if post_array.has_rankseg:
TEST_CASES.append([{"rankseg": True}, p([[[0.3, 0.6]], [[0.7, 0.4]]]), p([[[1.0, 1.0]]]), (1, 1, 2)])
TEST_CASES.append(
[{"rankseg": True, "metric": "iou"}, p([[[0.3, 0.6]], [[0.7, 0.4]]]), p([[[1.0, 1.0]]]), (1, 1, 2)]
)
TEST_CASES.append([{"rankseg": True}, p([[[[0.3, 0.6]]], [[[0.7, 0.4]]]]), p([[[[1.0, 1.0]]]]), (1, 1, 1, 2)])


class TestAsDiscrete(unittest.TestCase):
@parameterized.expand(TEST_CASES)
Expand All @@ -76,6 +90,18 @@ def test_additional(self):
out = AsDiscrete(argmax=True, dim=1, keepdim=False)(p([[[0.0, 1.0]], [[2.0, 3.0]]]))
assert_allclose(out, p([[0.0, 0.0], [0.0, 0.0]]), type_test=False)

def test_rankseg_argmax_incompatible(self):
with self.assertRaises(ValueError):
AsDiscrete(argmax=True, rankseg=True)

with self.assertRaises(ValueError):
AsDiscrete(argmax=True)([[[0.3, 0.6]], [[0.7, 0.4]]], rankseg=True)

def test_rankseg_missing_dependency(self):
with mock.patch("monai.transforms.post.array.has_rankseg", False):
with self.assertRaises(OptionalImportError):
AsDiscrete(rankseg=True)([[[0.3, 0.6]], [[0.7, 0.4]]])


if __name__ == "__main__":
unittest.main()
58 changes: 58 additions & 0 deletions tests/transforms/test_as_discreted.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
from __future__ import annotations

import unittest
from unittest import mock

from parameterized import parameterized

from monai.transforms import AsDiscreted
from monai.transforms.post import array as post_array
from monai.utils import OptionalImportError
from tests.test_utils import TEST_NDARRAYS, assert_allclose

TEST_CASES = []
Expand Down Expand Up @@ -66,6 +69,52 @@
]
)

TEST_CASES.append(
[
{"keys": "pred", "rankseg": False, "argmax": True},
{"pred": p([[[0.3, 0.6]], [[0.7, 0.4]]])},
{"pred": p([[[1.0, 0.0]]])},
(1, 1, 2),
]
)

if post_array.has_rankseg:
TEST_CASES.append(
[
{"keys": "pred", "rankseg": True},
{"pred": p([[[0.3, 0.6]], [[0.7, 0.4]]])},
{"pred": p([[[1.0, 1.0]]])},
(1, 1, 2),
]
)

TEST_CASES.append(
[
{"keys": ["pred", "label"], "rankseg": [True, False]},
{"pred": p([[[0.3, 0.6]], [[0.7, 0.4]]]), "label": p([[[0.0, 1.0]]])},
{"pred": p([[[1.0, 1.0]]]), "label": p([[[0.0, 1.0]]])},
(1, 1, 2),
]
)

TEST_CASES.append(
[
{"keys": "pred", "rankseg": True, "metric": "iou"},
{"pred": p([[[0.3, 0.6]], [[0.7, 0.4]]])},
{"pred": p([[[1.0, 1.0]]])},
(1, 1, 2),
]
)

TEST_CASES.append(
[
{"keys": "pred", "rankseg": True},
{"pred": p([[[[0.3, 0.6]]], [[[0.7, 0.4]]]])},
{"pred": p([[[[1.0, 1.0]]]])},
(1, 1, 1, 2),
]
)


class TestAsDiscreted(unittest.TestCase):
@parameterized.expand(TEST_CASES)
Expand All @@ -77,6 +126,15 @@ def test_value_shape(self, input_param, test_input, output, expected_shape):
assert_allclose(result["label"], output["label"], rtol=1e-3, type_test="tensor")
self.assertTupleEqual(result["label"].shape, expected_shape)

def test_rankseg_argmax_incompatible(self):
with self.assertRaises(ValueError):
AsDiscreted(keys="pred", argmax=True, rankseg=True)({"pred": [[[0.3, 0.6]], [[0.7, 0.4]]]})

def test_rankseg_missing_dependency(self):
with mock.patch("monai.transforms.post.array.has_rankseg", False):
with self.assertRaises(OptionalImportError):
AsDiscreted(keys="pred", rankseg=True)({"pred": [[[0.3, 0.6]], [[0.7, 0.4]]]})


if __name__ == "__main__":
unittest.main()
Loading