From 59a5099fa8e208ecf325f5a1ac33b4e13f8e8c54 Mon Sep 17 00:00:00 2001 From: ugbotueferhire Date: Fri, 5 Jun 2026 16:28:53 +0100 Subject: [PATCH 1/2] feat: add ChannelWise wrappers for independent channel augmentations (#8311) --- monai/transforms/__init__.py | 8 +++ monai/transforms/utility/array.py | 78 ++++++++++++++++++++++++++ monai/transforms/utility/dictionary.py | 74 ++++++++++++++++++++++++ tests/test_channel_wise.py | 50 +++++++++++++++++ tests/test_channel_wised.py | 50 +++++++++++++++++ 5 files changed, 260 insertions(+) create mode 100644 tests/test_channel_wise.py create mode 100644 tests/test_channel_wised.py diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index c7ac4b77e6..f9be2e5b45 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -518,6 +518,7 @@ ApplyTransformToPoints, AsChannelLast, CastToType, + ChannelWise, ClassesToIndices, ConvertToMultiChannelBasedOnBratsClasses, CuCIM, @@ -536,6 +537,7 @@ RandIdentity, RandImageFilter, RandLambda, + RandChannelWise, RandTorchIO, RandTorchVision, RemoveRepeatedChannel, @@ -568,6 +570,9 @@ CastToTyped, CastToTypeD, CastToTypeDict, + ChannelWised, + ChannelWiseD, + ChannelWiseDict, ClassesToIndicesd, ClassesToIndicesD, ClassesToIndicesDict, @@ -631,6 +636,9 @@ RandLambdad, RandLambdaD, RandLambdaDict, + RandChannelWised, + RandChannelWiseD, + RandChannelWiseDict, RandTorchIOd, RandTorchIOD, RandTorchIODict, diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index ed4b149e6b..ee298fcb60 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -81,6 +81,8 @@ "EnsureType", "RepeatChannel", "RemoveRepeatedChannel", + "ChannelWise", + "RandChannelWise", "SplitDim", "CastToType", "ToTensor", @@ -288,6 +290,82 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: return out +class ChannelWise(Transform): + """ + Apply a given transform to each channel of the input array independently and + concatenate the results back along the channel dimension. + + Args: + transform: a callable transform to apply to each channel. + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__(self, transform: Callable) -> None: + self.transform = transform + + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + """ + Apply the transform to `img`. + """ + if img.shape[0] == 0: + return img + + results = [] + for i in range(img.shape[0]): + res = self.transform(img[[i], ...]) + results.append(res) + + if isinstance(img, torch.Tensor): + return torch.cat(results, dim=0) + return np.concatenate(results, axis=0) + + +class RandChannelWise(RandomizableTransform): + """ + Randomizable version of :py:class:`monai.transforms.ChannelWise`, the input + `transform` will be applied independently to each channel. + + Args: + transform: a callable transform to apply to each channel. + prob: probability of applying the transform to the entire image. + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__(self, transform: Callable, prob: float = 1.0) -> None: + RandomizableTransform.__init__(self, prob) + self.transform = transform + + def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandChannelWise: + super().set_random_state(seed, state) + if hasattr(self.transform, "set_random_state"): + self.transform.set_random_state(seed, state) + return self + + def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: + """ + Apply the transform to `img`. + """ + if randomize: + self.randomize(None) + if not self._do_transform: + return img + + if img.shape[0] == 0: + return img + + results = [] + for i in range(img.shape[0]): + res = self.transform(img[[i], ...]) + results.append(res) + + if isinstance(img, torch.Tensor): + return torch.cat(results, dim=0) + return np.concatenate(results, axis=0) + + + class SplitDim(Transform, MultiSampleTrait): """ Given an image of size X along a certain dimension, return a list of length X containing diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 7dd24a3880..899207f74f 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -38,6 +38,7 @@ ApplyTransformToPoints, AsChannelLast, CastToType, + ChannelWise, ClassesToIndices, ConvertToMultiChannelBasedOnBratsClasses, CuCIM, @@ -52,6 +53,7 @@ LabelToMask, Lambda, MapLabelValue, + RandChannelWise, RemoveRepeatedChannel, RepeatChannel, SimulateDelay, @@ -88,6 +90,9 @@ "ConcatItemsD", "ConcatItemsDict", "ConcatItemsd", + "ChannelWiseD", + "ChannelWiseDict", + "ChannelWised", "ConvertToMultiChannelBasedOnBratsClassesD", "ConvertToMultiChannelBasedOnBratsClassesDict", "ConvertToMultiChannelBasedOnBratsClassesd", @@ -131,6 +136,9 @@ "FlattenSubKeysd", "FlattenSubKeysD", "FlattenSubKeysDict", + "RandChannelWiseD", + "RandChannelWiseDict", + "RandChannelWised", "RandCuCIMd", "RandCuCIMD", "RandCuCIMDict", @@ -338,6 +346,70 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N return d +class ChannelWised(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.ChannelWise`. + """ + + backend = ChannelWise.backend + + def __init__(self, keys: KeysCollection, transform: Callable, allow_missing_keys: bool = False) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + transform: a callable transform to apply to each channel. + allow_missing_keys: don't raise exception if key is missing. + """ + super().__init__(keys, allow_missing_keys) + self.converter = ChannelWise(transform=transform) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.converter(d[key]) + return d + + +class RandChannelWised(MapTransform, RandomizableTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.RandChannelWise`. + """ + + backend = RandChannelWise.backend + + def __init__(self, keys: KeysCollection, transform: Callable, prob: float = 1.0, allow_missing_keys: bool = False) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + transform: a callable transform to apply to each channel. + prob: probability of applying the transform to the entire image. + allow_missing_keys: don't raise exception if key is missing. + """ + MapTransform.__init__(self, keys, allow_missing_keys) + RandomizableTransform.__init__(self, prob) + self.converter = RandChannelWise(transform=transform, prob=1.0) + + def set_random_state( + self, seed: int | None = None, state: np.random.RandomState | None = None + ) -> RandChannelWised: + super().set_random_state(seed, state) + if hasattr(self.converter, "set_random_state"): + self.converter.set_random_state(seed, state) + return self + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: + d = dict(data) + self.randomize(None) + if not self._do_transform: + return d + + for key in self.key_iterator(d): + d[key] = self.converter(d[key], randomize=False) + return d + + class SplitDimd(MapTransform, MultiSampleTrait): backend = SplitDim.backend @@ -2032,6 +2104,8 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N AsChannelLastD = AsChannelLastDict = AsChannelLastd EnsureChannelFirstD = EnsureChannelFirstDict = EnsureChannelFirstd RemoveRepeatedChannelD = RemoveRepeatedChannelDict = RemoveRepeatedChanneld +ChannelWiseD = ChannelWiseDict = ChannelWised +RandChannelWiseD = RandChannelWiseDict = RandChannelWised RepeatChannelD = RepeatChannelDict = RepeatChanneld SplitDimD = SplitDimDict = SplitDimd CastToTypeD = CastToTypeDict = CastToTyped diff --git a/tests/test_channel_wise.py b/tests/test_channel_wise.py new file mode 100644 index 0000000000..e3c12f698f --- /dev/null +++ b/tests/test_channel_wise.py @@ -0,0 +1,50 @@ +import unittest + +import numpy as np +import torch + +from monai.transforms import ChannelWise, RandChannelWise, RandGaussianNoise, ScaleIntensity +from monai.utils import set_determinism + + +class TestChannelWise(unittest.TestCase): + def test_channel_wise_deterministic(self): + # Test applying a deterministic transform channel-wise + data = np.array([[[1.0, 2.0], [3.0, 4.0]], [[10.0, 20.0], [30.0, 40.0]]]) # shape (2, 2, 2) + + # ScaleIntensity applies to the whole input array independently + transform = ChannelWise(transform=ScaleIntensity()) + out = transform(data) + + # Channel 0 scaled + np.testing.assert_allclose(out[0], np.array([[0.0, 0.3333333], [0.6666667, 1.0]]), atol=1e-5) + # Channel 1 scaled + np.testing.assert_allclose(out[1], np.array([[0.0, 0.3333333], [0.6666667, 1.0]]), atol=1e-5) + self.assertEqual(out.shape, data.shape) + + def test_rand_channel_wise(self): + # Test applying a randomized transform channel-wise + data = np.zeros((3, 4, 4)) + + set_determinism(seed=0) + # Apply random noise with high standard deviation to see the difference + transform = RandChannelWise(transform=RandGaussianNoise(prob=1.0, std=1.0)) + out = transform(data) + + # All channels should have different noise values + self.assertFalse(np.allclose(out[0], out[1])) + self.assertFalse(np.allclose(out[1], out[2])) + self.assertFalse(np.allclose(out[0], out[2])) + + # Output shape should be exactly the same + self.assertEqual(out.shape, data.shape) + + def test_prob_zero(self): + # Test when RandChannelWise prob is 0.0 + data = np.zeros((2, 2, 2)) + transform = RandChannelWise(transform=RandGaussianNoise(prob=1.0, std=1.0), prob=0.0) + out = transform(data) + np.testing.assert_allclose(out, data) + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_channel_wised.py b/tests/test_channel_wised.py new file mode 100644 index 0000000000..8ed56ed7df --- /dev/null +++ b/tests/test_channel_wised.py @@ -0,0 +1,50 @@ +import unittest + +import numpy as np +import torch + +from monai.transforms import ChannelWised, RandChannelWised, RandGaussianNoise, ScaleIntensity +from monai.utils import set_determinism + + +class TestChannelWised(unittest.TestCase): + def test_channel_wise_deterministic(self): + # Test applying a deterministic transform channel-wise + data = {"image": np.array([[[1.0, 2.0], [3.0, 4.0]], [[10.0, 20.0], [30.0, 40.0]]])} # shape (2, 2, 2) + + # ScaleIntensity applies to the whole input array independently + transform = ChannelWised(keys=["image"], transform=ScaleIntensity()) + out = transform(data) + + # Channel 0 scaled + np.testing.assert_allclose(out["image"][0], np.array([[0.0, 0.3333333], [0.6666667, 1.0]]), atol=1e-5) + # Channel 1 scaled + np.testing.assert_allclose(out["image"][1], np.array([[0.0, 0.3333333], [0.6666667, 1.0]]), atol=1e-5) + self.assertEqual(out["image"].shape, data["image"].shape) + + def test_rand_channel_wise(self): + # Test applying a randomized transform channel-wise + data = {"image": np.zeros((3, 4, 4))} + + set_determinism(seed=0) + # Apply random noise with high standard deviation to see the difference + transform = RandChannelWised(keys=["image"], transform=RandGaussianNoise(prob=1.0, std=1.0)) + out = transform(data) + + # All channels should have different noise values + self.assertFalse(np.allclose(out["image"][0], out["image"][1])) + self.assertFalse(np.allclose(out["image"][1], out["image"][2])) + self.assertFalse(np.allclose(out["image"][0], out["image"][2])) + + # Output shape should be exactly the same + self.assertEqual(out["image"].shape, data["image"].shape) + + def test_prob_zero(self): + # Test when RandChannelWised prob is 0.0 + data = {"image": np.zeros((2, 2, 2))} + transform = RandChannelWised(keys=["image"], transform=RandGaussianNoise(prob=1.0, std=1.0), prob=0.0) + out = transform(data) + np.testing.assert_allclose(out["image"], data["image"]) + +if __name__ == "__main__": + unittest.main() From 2781305c3f1328f43861226ca83b8e61a0290ffb Mon Sep 17 00:00:00 2001 From: ugbotueferhire Date: Sun, 7 Jun 2026 22:29:13 +0100 Subject: [PATCH 2/2] Address channel-wise transform review feedback Signed-off-by: ugbotueferhire --- monai/transforms/utility/array.py | 169 ++++++++++++++++++++----- monai/transforms/utility/dictionary.py | 59 ++++++++- tests/test_channel_wise.py | 89 +++++++++---- tests/test_channel_wised.py | 68 ++++++---- 4 files changed, 304 insertions(+), 81 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index ee298fcb60..a53a5c0eea 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -292,33 +292,125 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: class ChannelWise(Transform): """ - Apply a given transform to each channel of the input array independently and - concatenate the results back along the channel dimension. - + Apply a transform to each channel of a channel-first array independently. + Args: - transform: a callable transform to apply to each channel. + transform: Callable transform to apply to each channel. The transform receives + a single-channel array or tensor with a singleton leading channel dimension. """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__(self, transform: Callable) -> None: + """ + Args: + transform: Callable transform to apply to each channel. + """ self.transform = transform - def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + @staticmethod + def _normalize_channel_result( + result: NdarrayOrTensor, expected_ndim: int, channel_index: int + ) -> NdarrayOrTensor: """ - Apply the transform to `img`. + Ensure a per-channel transform result has a singleton channel dimension. + + Args: + result: Output from applying the wrapped transform to one input channel. + expected_ndim: Expected dimensionality of the channel-first output. + channel_index: Index of the input channel being transformed. + + Returns: + The transform result with a singleton leading channel dimension. + + Raises: + ValueError: If the result cannot be concatenated along the channel dimension. + TypeError: If the result is not a NumPy array or torch tensor. + """ + if isinstance(result, torch.Tensor): + if result.ndim == expected_ndim - 1: + return result.unsqueeze(0) + if result.ndim == expected_ndim and result.shape[0] == 1: + return result + elif isinstance(result, np.ndarray): + if result.ndim == expected_ndim - 1: + return np.expand_dims(result, axis=0) + if result.ndim == expected_ndim and result.shape[0] == 1: + return result + else: + raise TypeError( + f"Channel {channel_index} transform output must be a NumPy array or torch tensor, " + f"got {type(result).__name__}." + ) + + raise ValueError( + f"Channel {channel_index} transform output must preserve a singleton leading channel " + f"dimension or return a squeezed channel with {expected_ndim - 1} dimensions, " + f"got shape {tuple(result.shape)}." + ) + + @staticmethod + def _concatenate_channel_results(results: list[NdarrayOrTensor]) -> NdarrayOrTensor: + """ + Concatenate normalized per-channel transform results. + + Args: + results: Sequence of normalized per-channel NumPy arrays or torch tensors. + + Returns: + A NumPy array or torch tensor concatenated along the leading channel dimension. + + Raises: + TypeError: If results contain mixed array and tensor types. + """ + if all(isinstance(result, torch.Tensor) for result in results): + return torch.cat(results, dim=0) + if all(isinstance(result, np.ndarray) for result in results): + return np.concatenate(results, axis=0) + raise TypeError("All channel-wise transform outputs must have the same array or tensor type.") + + @classmethod + def _apply_channel_wise(cls, img: NdarrayOrTensor, transform: Callable) -> NdarrayOrTensor: + """ + Apply a callable independently to each channel and concatenate the outputs. + + Args: + img: Channel-first NumPy array or torch tensor to transform. + transform: Callable transform to apply to each channel. + + Returns: + A NumPy array or torch tensor containing the transformed channels. + + Raises: + ValueError: If `img` has no channel dimension, or a transformed channel has an incompatible shape. + TypeError: If a transformed channel has an invalid type or mixed output types are produced. """ + if len(img.shape) == 0: + raise ValueError("Image must have a channel dimension.") if img.shape[0] == 0: return img - - results = [] - for i in range(img.shape[0]): - res = self.transform(img[[i], ...]) - results.append(res) - - if isinstance(img, torch.Tensor): - return torch.cat(results, dim=0) - return np.concatenate(results, axis=0) + + results = [ + cls._normalize_channel_result(transform(img[[i], ...]), expected_ndim=len(img.shape), channel_index=i) + for i in range(img.shape[0]) + ] + return cls._concatenate_channel_results(results) + + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + """ + Apply the wrapped transform independently to each channel. + + Args: + img: Channel-first NumPy array or torch tensor to transform. + + Returns: + A NumPy array or torch tensor containing the transformed channels. + + Raises: + ValueError: If `img` has no channel dimension, or a transformed channel has an incompatible shape. + TypeError: If a transformed channel is not a NumPy array or torch tensor. + """ + return self._apply_channel_wise(img, self.transform) class RandChannelWise(RandomizableTransform): @@ -327,17 +419,33 @@ class RandChannelWise(RandomizableTransform): `transform` will be applied independently to each channel. Args: - transform: a callable transform to apply to each channel. - prob: probability of applying the transform to the entire image. + transform: Callable transform to apply to each channel. If it defines + ``set_random_state``, the state is synchronized by this wrapper. + prob: Probability of applying the transform to the entire image. """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__(self, transform: Callable, prob: float = 1.0) -> None: + """ + Args: + transform: Callable transform to apply to each channel. + prob: Probability of applying the transform to the entire image. + """ RandomizableTransform.__init__(self, prob) self.transform = transform def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandChannelWise: + """ + Set the random state for this transform and its wrapped transform. + + Args: + seed: Seed to use for a new NumPy random state. + state: Existing NumPy random state to use. If provided, it takes precedence over ``seed``. + + Returns: + This ``RandChannelWise`` instance. + """ super().set_random_state(seed, state) if hasattr(self.transform, "set_random_state"): self.transform.set_random_state(seed, state) @@ -345,25 +453,26 @@ def set_random_state(self, seed: int | None = None, state: np.random.RandomState def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: """ - Apply the transform to `img`. + Apply the wrapped transform independently to each channel when sampled. + + Args: + img: Channel-first NumPy array or torch tensor to transform. + randomize: Whether to sample this wrapper's random state before applying the transform. + + Returns: + A NumPy array or torch tensor containing the transformed channels, or the input unchanged + when this transform is not sampled. + + Raises: + ValueError: If `img` has no channel dimension, or a transformed channel has an incompatible shape. + TypeError: If a transformed channel is not a NumPy array or torch tensor. """ if randomize: self.randomize(None) if not self._do_transform: return img - - if img.shape[0] == 0: - return img - - results = [] - for i in range(img.shape[0]): - res = self.transform(img[[i], ...]) - results.append(res) - - if isinstance(img, torch.Tensor): - return torch.cat(results, dim=0) - return np.concatenate(results, axis=0) + return ChannelWise._apply_channel_wise(img, self.transform) class SplitDim(Transform, MultiSampleTrait): diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 899207f74f..5c7b7b7b6c 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -349,6 +349,14 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N class ChannelWised(MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.ChannelWise`. + + This transform stores a ``ChannelWise`` converter and applies it to each selected key. + + Args: + keys: Keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform`. + transform: Callable transform to apply independently to each channel. + allow_missing_keys: Don't raise an exception if a key is missing. """ backend = ChannelWise.backend @@ -365,6 +373,20 @@ def __init__(self, keys: KeysCollection, transform: Callable, allow_missing_keys self.converter = ChannelWise(transform=transform) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: + """ + Apply the channel-wise converter to each selected item. + + Args: + data: Mapping containing channel-first arrays or tensors for the configured keys. + + Returns: + A dictionary with transformed values for the configured keys and unchanged values for other keys. + + Raises: + KeyError: If a configured key is missing and ``allow_missing_keys`` is ``False``. + TypeError: If ``data`` cannot be copied to a dictionary or a transformed channel has an invalid type. + ValueError: If an input or transformed channel has an incompatible shape. + """ d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) @@ -374,6 +396,16 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N class RandChannelWised(MapTransform, RandomizableTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.RandChannelWise`. + + This transform samples once per call and, when selected, delegates per-channel + randomized processing to its ``RandChannelWise`` converter. + + Args: + keys: Keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform`. + transform: Callable transform to apply independently to each channel. + prob: Probability of applying the transform to the selected items. + allow_missing_keys: Don't raise an exception if a key is missing. """ backend = RandChannelWise.backend @@ -394,17 +426,42 @@ def __init__(self, keys: KeysCollection, transform: Callable, prob: float = 1.0, def set_random_state( self, seed: int | None = None, state: np.random.RandomState | None = None ) -> RandChannelWised: + """ + Set the random state for this transform and its converter. + + Args: + seed: Seed to use for a new NumPy random state. + state: Existing NumPy random state to use. If provided, it takes precedence over ``seed``. + + Returns: + This ``RandChannelWised`` instance. + """ super().set_random_state(seed, state) if hasattr(self.converter, "set_random_state"): self.converter.set_random_state(seed, state) return self def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: + """ + Apply the randomized channel-wise converter to each selected item when sampled. + + Args: + data: Mapping containing channel-first arrays or tensors for the configured keys. + + Returns: + A dictionary with transformed values for the configured keys when sampled, or a shallow + copy of ``data`` when not sampled. + + Raises: + KeyError: If a configured key is missing and ``allow_missing_keys`` is ``False``. + TypeError: If ``data`` cannot be copied to a dictionary or a transformed channel has an invalid type. + ValueError: If an input or transformed channel has an incompatible shape. + """ d = dict(data) self.randomize(None) if not self._do_transform: return d - + for key in self.key_iterator(d): d[key] = self.converter(d[key], randomize=False) return d diff --git a/tests/test_channel_wise.py b/tests/test_channel_wise.py index e3c12f698f..e3766f5b9b 100644 --- a/tests/test_channel_wise.py +++ b/tests/test_channel_wise.py @@ -7,44 +7,83 @@ from monai.utils import set_determinism +EXPECTED_SCALED = np.array([[0.0, 0.3333333], [0.6666667, 1.0]]) + + class TestChannelWise(unittest.TestCase): def test_channel_wise_deterministic(self): - # Test applying a deterministic transform channel-wise - data = np.array([[[1.0, 2.0], [3.0, 4.0]], [[10.0, 20.0], [30.0, 40.0]]]) # shape (2, 2, 2) - - # ScaleIntensity applies to the whole input array independently + data = np.array([[[1.0, 2.0], [3.0, 4.0]], [[10.0, 20.0], [30.0, 40.0]]]) + transform = ChannelWise(transform=ScaleIntensity()) out = transform(data) - - # Channel 0 scaled - np.testing.assert_allclose(out[0], np.array([[0.0, 0.3333333], [0.6666667, 1.0]]), atol=1e-5) - # Channel 1 scaled - np.testing.assert_allclose(out[1], np.array([[0.0, 0.3333333], [0.6666667, 1.0]]), atol=1e-5) + + np.testing.assert_allclose(out[0], EXPECTED_SCALED, atol=1e-5) + np.testing.assert_allclose(out[1], EXPECTED_SCALED, atol=1e-5) self.assertEqual(out.shape, data.shape) + torch_data = torch.as_tensor(data) + torch_out = transform(torch_data) + + torch_expected = torch.as_tensor(EXPECTED_SCALED, dtype=torch_out.dtype) + self.assertTrue(torch.allclose(torch_out[0], torch_expected, atol=1e-5)) + self.assertTrue(torch.allclose(torch_out[1], torch_expected, atol=1e-5)) + self.assertEqual(torch_out.shape, torch_data.shape) + def test_rand_channel_wise(self): - # Test applying a randomized transform channel-wise - data = np.zeros((3, 4, 4)) - - set_determinism(seed=0) - # Apply random noise with high standard deviation to see the difference - transform = RandChannelWise(transform=RandGaussianNoise(prob=1.0, std=1.0)) - out = transform(data) - - # All channels should have different noise values - self.assertFalse(np.allclose(out[0], out[1])) - self.assertFalse(np.allclose(out[1], out[2])) - self.assertFalse(np.allclose(out[0], out[2])) - - # Output shape should be exactly the same - self.assertEqual(out.shape, data.shape) + try: + set_determinism(seed=0) + + data = np.zeros((3, 4, 4)) + transform = RandChannelWise(transform=RandGaussianNoise(prob=1.0, std=1.0)) + out = transform(data) + + self.assertFalse(np.allclose(out[0], out[1])) + self.assertFalse(np.allclose(out[1], out[2])) + self.assertFalse(np.allclose(out[0], out[2])) + self.assertEqual(out.shape, data.shape) + + torch_data = torch.zeros((3, 4, 4)) + torch_transform = RandChannelWise(transform=RandGaussianNoise(prob=1.0, std=1.0)) + torch_out = torch_transform(torch_data) + + self.assertFalse(torch.allclose(torch_out[0], torch_out[1])) + self.assertFalse(torch.allclose(torch_out[1], torch_out[2])) + self.assertFalse(torch.allclose(torch_out[0], torch_out[2])) + self.assertEqual(torch_out.shape, torch_data.shape) + finally: + set_determinism(None) def test_prob_zero(self): - # Test when RandChannelWise prob is 0.0 data = np.zeros((2, 2, 2)) transform = RandChannelWise(transform=RandGaussianNoise(prob=1.0, std=1.0), prob=0.0) out = transform(data) np.testing.assert_allclose(out, data) + torch_data = torch.zeros((2, 2, 2)) + torch_out = transform(torch_data) + self.assertTrue(torch.allclose(torch_out, torch_data)) + + def test_squeezed_channel_result(self): + data = np.arange(8.0).reshape(2, 2, 2) + transform = ChannelWise(transform=lambda img: img[0]) + out = transform(data) + np.testing.assert_allclose(out, data) + self.assertEqual(out.shape, data.shape) + + torch_data = torch.as_tensor(data) + torch_out = transform(torch_data) + self.assertTrue(torch.allclose(torch_out, torch_data)) + self.assertEqual(torch_out.shape, torch_data.shape) + + def test_invalid_channel_result_shape(self): + transform = ChannelWise(transform=lambda img: img[:0]) + + with self.assertRaises(ValueError): + transform(np.zeros((2, 2, 2))) + + with self.assertRaises(ValueError): + transform(torch.zeros((2, 2, 2))) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_channel_wised.py b/tests/test_channel_wised.py index 8ed56ed7df..136606e573 100644 --- a/tests/test_channel_wised.py +++ b/tests/test_channel_wised.py @@ -7,44 +7,62 @@ from monai.utils import set_determinism +EXPECTED_SCALED = np.array([[0.0, 0.3333333], [0.6666667, 1.0]]) + + class TestChannelWised(unittest.TestCase): def test_channel_wise_deterministic(self): - # Test applying a deterministic transform channel-wise - data = {"image": np.array([[[1.0, 2.0], [3.0, 4.0]], [[10.0, 20.0], [30.0, 40.0]]])} # shape (2, 2, 2) - - # ScaleIntensity applies to the whole input array independently + data = {"image": np.array([[[1.0, 2.0], [3.0, 4.0]], [[10.0, 20.0], [30.0, 40.0]]])} + transform = ChannelWised(keys=["image"], transform=ScaleIntensity()) out = transform(data) - - # Channel 0 scaled - np.testing.assert_allclose(out["image"][0], np.array([[0.0, 0.3333333], [0.6666667, 1.0]]), atol=1e-5) - # Channel 1 scaled - np.testing.assert_allclose(out["image"][1], np.array([[0.0, 0.3333333], [0.6666667, 1.0]]), atol=1e-5) + + np.testing.assert_allclose(out["image"][0], EXPECTED_SCALED, atol=1e-5) + np.testing.assert_allclose(out["image"][1], EXPECTED_SCALED, atol=1e-5) self.assertEqual(out["image"].shape, data["image"].shape) + torch_data = {"image": torch.as_tensor(data["image"])} + torch_out = transform(torch_data) + + torch_expected = torch.as_tensor(EXPECTED_SCALED, dtype=torch_out["image"].dtype) + self.assertTrue(torch.allclose(torch_out["image"][0], torch_expected, atol=1e-5)) + self.assertTrue(torch.allclose(torch_out["image"][1], torch_expected, atol=1e-5)) + self.assertEqual(torch_out["image"].shape, torch_data["image"].shape) + def test_rand_channel_wise(self): - # Test applying a randomized transform channel-wise - data = {"image": np.zeros((3, 4, 4))} - - set_determinism(seed=0) - # Apply random noise with high standard deviation to see the difference - transform = RandChannelWised(keys=["image"], transform=RandGaussianNoise(prob=1.0, std=1.0)) - out = transform(data) - - # All channels should have different noise values - self.assertFalse(np.allclose(out["image"][0], out["image"][1])) - self.assertFalse(np.allclose(out["image"][1], out["image"][2])) - self.assertFalse(np.allclose(out["image"][0], out["image"][2])) - - # Output shape should be exactly the same - self.assertEqual(out["image"].shape, data["image"].shape) + try: + set_determinism(seed=0) + + data = {"image": np.zeros((3, 4, 4))} + transform = RandChannelWised(keys=["image"], transform=RandGaussianNoise(prob=1.0, std=1.0)) + out = transform(data) + + self.assertFalse(np.allclose(out["image"][0], out["image"][1])) + self.assertFalse(np.allclose(out["image"][1], out["image"][2])) + self.assertFalse(np.allclose(out["image"][0], out["image"][2])) + self.assertEqual(out["image"].shape, data["image"].shape) + + torch_data = {"image": torch.zeros((3, 4, 4))} + torch_transform = RandChannelWised(keys=["image"], transform=RandGaussianNoise(prob=1.0, std=1.0)) + torch_out = torch_transform(torch_data) + + self.assertFalse(torch.allclose(torch_out["image"][0], torch_out["image"][1])) + self.assertFalse(torch.allclose(torch_out["image"][1], torch_out["image"][2])) + self.assertFalse(torch.allclose(torch_out["image"][0], torch_out["image"][2])) + self.assertEqual(torch_out["image"].shape, torch_data["image"].shape) + finally: + set_determinism(None) def test_prob_zero(self): - # Test when RandChannelWised prob is 0.0 data = {"image": np.zeros((2, 2, 2))} transform = RandChannelWised(keys=["image"], transform=RandGaussianNoise(prob=1.0, std=1.0), prob=0.0) out = transform(data) np.testing.assert_allclose(out["image"], data["image"]) + torch_data = {"image": torch.zeros((2, 2, 2))} + torch_out = transform(torch_data) + self.assertTrue(torch.allclose(torch_out["image"], torch_data["image"])) + + if __name__ == "__main__": unittest.main()