diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index c7ac4b77e6..f9be2e5b45 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -518,6 +518,7 @@ ApplyTransformToPoints, AsChannelLast, CastToType, + ChannelWise, ClassesToIndices, ConvertToMultiChannelBasedOnBratsClasses, CuCIM, @@ -536,6 +537,7 @@ RandIdentity, RandImageFilter, RandLambda, + RandChannelWise, RandTorchIO, RandTorchVision, RemoveRepeatedChannel, @@ -568,6 +570,9 @@ CastToTyped, CastToTypeD, CastToTypeDict, + ChannelWised, + ChannelWiseD, + ChannelWiseDict, ClassesToIndicesd, ClassesToIndicesD, ClassesToIndicesDict, @@ -631,6 +636,9 @@ RandLambdad, RandLambdaD, RandLambdaDict, + RandChannelWised, + RandChannelWiseD, + RandChannelWiseDict, RandTorchIOd, RandTorchIOD, RandTorchIODict, diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index ed4b149e6b..a53a5c0eea 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -81,6 +81,8 @@ "EnsureType", "RepeatChannel", "RemoveRepeatedChannel", + "ChannelWise", + "RandChannelWise", "SplitDim", "CastToType", "ToTensor", @@ -288,6 +290,191 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: return out +class ChannelWise(Transform): + """ + Apply a transform to each channel of a channel-first array independently. + + Args: + transform: Callable transform to apply to each channel. The transform receives + a single-channel array or tensor with a singleton leading channel dimension. + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__(self, transform: Callable) -> None: + """ + Args: + transform: Callable transform to apply to each channel. + """ + self.transform = transform + + @staticmethod + def _normalize_channel_result( + result: NdarrayOrTensor, expected_ndim: int, channel_index: int + ) -> NdarrayOrTensor: + """ + Ensure a per-channel transform result has a singleton channel dimension. + + Args: + result: Output from applying the wrapped transform to one input channel. + expected_ndim: Expected dimensionality of the channel-first output. + channel_index: Index of the input channel being transformed. + + Returns: + The transform result with a singleton leading channel dimension. + + Raises: + ValueError: If the result cannot be concatenated along the channel dimension. + TypeError: If the result is not a NumPy array or torch tensor. + """ + if isinstance(result, torch.Tensor): + if result.ndim == expected_ndim - 1: + return result.unsqueeze(0) + if result.ndim == expected_ndim and result.shape[0] == 1: + return result + elif isinstance(result, np.ndarray): + if result.ndim == expected_ndim - 1: + return np.expand_dims(result, axis=0) + if result.ndim == expected_ndim and result.shape[0] == 1: + return result + else: + raise TypeError( + f"Channel {channel_index} transform output must be a NumPy array or torch tensor, " + f"got {type(result).__name__}." + ) + + raise ValueError( + f"Channel {channel_index} transform output must preserve a singleton leading channel " + f"dimension or return a squeezed channel with {expected_ndim - 1} dimensions, " + f"got shape {tuple(result.shape)}." + ) + + @staticmethod + def _concatenate_channel_results(results: list[NdarrayOrTensor]) -> NdarrayOrTensor: + """ + Concatenate normalized per-channel transform results. + + Args: + results: Sequence of normalized per-channel NumPy arrays or torch tensors. + + Returns: + A NumPy array or torch tensor concatenated along the leading channel dimension. + + Raises: + TypeError: If results contain mixed array and tensor types. + """ + if all(isinstance(result, torch.Tensor) for result in results): + return torch.cat(results, dim=0) + if all(isinstance(result, np.ndarray) for result in results): + return np.concatenate(results, axis=0) + raise TypeError("All channel-wise transform outputs must have the same array or tensor type.") + + @classmethod + def _apply_channel_wise(cls, img: NdarrayOrTensor, transform: Callable) -> NdarrayOrTensor: + """ + Apply a callable independently to each channel and concatenate the outputs. + + Args: + img: Channel-first NumPy array or torch tensor to transform. + transform: Callable transform to apply to each channel. + + Returns: + A NumPy array or torch tensor containing the transformed channels. + + Raises: + ValueError: If `img` has no channel dimension, or a transformed channel has an incompatible shape. + TypeError: If a transformed channel has an invalid type or mixed output types are produced. + """ + if len(img.shape) == 0: + raise ValueError("Image must have a channel dimension.") + if img.shape[0] == 0: + return img + + results = [ + cls._normalize_channel_result(transform(img[[i], ...]), expected_ndim=len(img.shape), channel_index=i) + for i in range(img.shape[0]) + ] + return cls._concatenate_channel_results(results) + + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + """ + Apply the wrapped transform independently to each channel. + + Args: + img: Channel-first NumPy array or torch tensor to transform. + + Returns: + A NumPy array or torch tensor containing the transformed channels. + + Raises: + ValueError: If `img` has no channel dimension, or a transformed channel has an incompatible shape. + TypeError: If a transformed channel is not a NumPy array or torch tensor. + """ + return self._apply_channel_wise(img, self.transform) + + +class RandChannelWise(RandomizableTransform): + """ + Randomizable version of :py:class:`monai.transforms.ChannelWise`, the input + `transform` will be applied independently to each channel. + + Args: + transform: Callable transform to apply to each channel. If it defines + ``set_random_state``, the state is synchronized by this wrapper. + prob: Probability of applying the transform to the entire image. + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__(self, transform: Callable, prob: float = 1.0) -> None: + """ + Args: + transform: Callable transform to apply to each channel. + prob: Probability of applying the transform to the entire image. + """ + RandomizableTransform.__init__(self, prob) + self.transform = transform + + def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandChannelWise: + """ + Set the random state for this transform and its wrapped transform. + + Args: + seed: Seed to use for a new NumPy random state. + state: Existing NumPy random state to use. If provided, it takes precedence over ``seed``. + + Returns: + This ``RandChannelWise`` instance. + """ + super().set_random_state(seed, state) + if hasattr(self.transform, "set_random_state"): + self.transform.set_random_state(seed, state) + return self + + def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: + """ + Apply the wrapped transform independently to each channel when sampled. + + Args: + img: Channel-first NumPy array or torch tensor to transform. + randomize: Whether to sample this wrapper's random state before applying the transform. + + Returns: + A NumPy array or torch tensor containing the transformed channels, or the input unchanged + when this transform is not sampled. + + Raises: + ValueError: If `img` has no channel dimension, or a transformed channel has an incompatible shape. + TypeError: If a transformed channel is not a NumPy array or torch tensor. + """ + if randomize: + self.randomize(None) + if not self._do_transform: + return img + + return ChannelWise._apply_channel_wise(img, self.transform) + + class SplitDim(Transform, MultiSampleTrait): """ Given an image of size X along a certain dimension, return a list of length X containing diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 7dd24a3880..5c7b7b7b6c 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -38,6 +38,7 @@ ApplyTransformToPoints, AsChannelLast, CastToType, + ChannelWise, ClassesToIndices, ConvertToMultiChannelBasedOnBratsClasses, CuCIM, @@ -52,6 +53,7 @@ LabelToMask, Lambda, MapLabelValue, + RandChannelWise, RemoveRepeatedChannel, RepeatChannel, SimulateDelay, @@ -88,6 +90,9 @@ "ConcatItemsD", "ConcatItemsDict", "ConcatItemsd", + "ChannelWiseD", + "ChannelWiseDict", + "ChannelWised", "ConvertToMultiChannelBasedOnBratsClassesD", "ConvertToMultiChannelBasedOnBratsClassesDict", "ConvertToMultiChannelBasedOnBratsClassesd", @@ -131,6 +136,9 @@ "FlattenSubKeysd", "FlattenSubKeysD", "FlattenSubKeysDict", + "RandChannelWiseD", + "RandChannelWiseDict", + "RandChannelWised", "RandCuCIMd", "RandCuCIMD", "RandCuCIMDict", @@ -338,6 +346,127 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N return d +class ChannelWised(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.ChannelWise`. + + This transform stores a ``ChannelWise`` converter and applies it to each selected key. + + Args: + keys: Keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform`. + transform: Callable transform to apply independently to each channel. + allow_missing_keys: Don't raise an exception if a key is missing. + """ + + backend = ChannelWise.backend + + def __init__(self, keys: KeysCollection, transform: Callable, allow_missing_keys: bool = False) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + transform: a callable transform to apply to each channel. + allow_missing_keys: don't raise exception if key is missing. + """ + super().__init__(keys, allow_missing_keys) + self.converter = ChannelWise(transform=transform) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: + """ + Apply the channel-wise converter to each selected item. + + Args: + data: Mapping containing channel-first arrays or tensors for the configured keys. + + Returns: + A dictionary with transformed values for the configured keys and unchanged values for other keys. + + Raises: + KeyError: If a configured key is missing and ``allow_missing_keys`` is ``False``. + TypeError: If ``data`` cannot be copied to a dictionary or a transformed channel has an invalid type. + ValueError: If an input or transformed channel has an incompatible shape. + """ + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.converter(d[key]) + return d + + +class RandChannelWised(MapTransform, RandomizableTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.RandChannelWise`. + + This transform samples once per call and, when selected, delegates per-channel + randomized processing to its ``RandChannelWise`` converter. + + Args: + keys: Keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform`. + transform: Callable transform to apply independently to each channel. + prob: Probability of applying the transform to the selected items. + allow_missing_keys: Don't raise an exception if a key is missing. + """ + + backend = RandChannelWise.backend + + def __init__(self, keys: KeysCollection, transform: Callable, prob: float = 1.0, allow_missing_keys: bool = False) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + transform: a callable transform to apply to each channel. + prob: probability of applying the transform to the entire image. + allow_missing_keys: don't raise exception if key is missing. + """ + MapTransform.__init__(self, keys, allow_missing_keys) + RandomizableTransform.__init__(self, prob) + self.converter = RandChannelWise(transform=transform, prob=1.0) + + def set_random_state( + self, seed: int | None = None, state: np.random.RandomState | None = None + ) -> RandChannelWised: + """ + Set the random state for this transform and its converter. + + Args: + seed: Seed to use for a new NumPy random state. + state: Existing NumPy random state to use. If provided, it takes precedence over ``seed``. + + Returns: + This ``RandChannelWised`` instance. + """ + super().set_random_state(seed, state) + if hasattr(self.converter, "set_random_state"): + self.converter.set_random_state(seed, state) + return self + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: + """ + Apply the randomized channel-wise converter to each selected item when sampled. + + Args: + data: Mapping containing channel-first arrays or tensors for the configured keys. + + Returns: + A dictionary with transformed values for the configured keys when sampled, or a shallow + copy of ``data`` when not sampled. + + Raises: + KeyError: If a configured key is missing and ``allow_missing_keys`` is ``False``. + TypeError: If ``data`` cannot be copied to a dictionary or a transformed channel has an invalid type. + ValueError: If an input or transformed channel has an incompatible shape. + """ + d = dict(data) + self.randomize(None) + if not self._do_transform: + return d + + for key in self.key_iterator(d): + d[key] = self.converter(d[key], randomize=False) + return d + + class SplitDimd(MapTransform, MultiSampleTrait): backend = SplitDim.backend @@ -2032,6 +2161,8 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N AsChannelLastD = AsChannelLastDict = AsChannelLastd EnsureChannelFirstD = EnsureChannelFirstDict = EnsureChannelFirstd RemoveRepeatedChannelD = RemoveRepeatedChannelDict = RemoveRepeatedChanneld +ChannelWiseD = ChannelWiseDict = ChannelWised +RandChannelWiseD = RandChannelWiseDict = RandChannelWised RepeatChannelD = RepeatChannelDict = RepeatChanneld SplitDimD = SplitDimDict = SplitDimd CastToTypeD = CastToTypeDict = CastToTyped diff --git a/tests/test_channel_wise.py b/tests/test_channel_wise.py new file mode 100644 index 0000000000..e3766f5b9b --- /dev/null +++ b/tests/test_channel_wise.py @@ -0,0 +1,89 @@ +import unittest + +import numpy as np +import torch + +from monai.transforms import ChannelWise, RandChannelWise, RandGaussianNoise, ScaleIntensity +from monai.utils import set_determinism + + +EXPECTED_SCALED = np.array([[0.0, 0.3333333], [0.6666667, 1.0]]) + + +class TestChannelWise(unittest.TestCase): + def test_channel_wise_deterministic(self): + data = np.array([[[1.0, 2.0], [3.0, 4.0]], [[10.0, 20.0], [30.0, 40.0]]]) + + transform = ChannelWise(transform=ScaleIntensity()) + out = transform(data) + + np.testing.assert_allclose(out[0], EXPECTED_SCALED, atol=1e-5) + np.testing.assert_allclose(out[1], EXPECTED_SCALED, atol=1e-5) + self.assertEqual(out.shape, data.shape) + + torch_data = torch.as_tensor(data) + torch_out = transform(torch_data) + + torch_expected = torch.as_tensor(EXPECTED_SCALED, dtype=torch_out.dtype) + self.assertTrue(torch.allclose(torch_out[0], torch_expected, atol=1e-5)) + self.assertTrue(torch.allclose(torch_out[1], torch_expected, atol=1e-5)) + self.assertEqual(torch_out.shape, torch_data.shape) + + def test_rand_channel_wise(self): + try: + set_determinism(seed=0) + + data = np.zeros((3, 4, 4)) + transform = RandChannelWise(transform=RandGaussianNoise(prob=1.0, std=1.0)) + out = transform(data) + + self.assertFalse(np.allclose(out[0], out[1])) + self.assertFalse(np.allclose(out[1], out[2])) + self.assertFalse(np.allclose(out[0], out[2])) + self.assertEqual(out.shape, data.shape) + + torch_data = torch.zeros((3, 4, 4)) + torch_transform = RandChannelWise(transform=RandGaussianNoise(prob=1.0, std=1.0)) + torch_out = torch_transform(torch_data) + + self.assertFalse(torch.allclose(torch_out[0], torch_out[1])) + self.assertFalse(torch.allclose(torch_out[1], torch_out[2])) + self.assertFalse(torch.allclose(torch_out[0], torch_out[2])) + self.assertEqual(torch_out.shape, torch_data.shape) + finally: + set_determinism(None) + + def test_prob_zero(self): + data = np.zeros((2, 2, 2)) + transform = RandChannelWise(transform=RandGaussianNoise(prob=1.0, std=1.0), prob=0.0) + out = transform(data) + np.testing.assert_allclose(out, data) + + torch_data = torch.zeros((2, 2, 2)) + torch_out = transform(torch_data) + self.assertTrue(torch.allclose(torch_out, torch_data)) + + def test_squeezed_channel_result(self): + data = np.arange(8.0).reshape(2, 2, 2) + transform = ChannelWise(transform=lambda img: img[0]) + out = transform(data) + np.testing.assert_allclose(out, data) + self.assertEqual(out.shape, data.shape) + + torch_data = torch.as_tensor(data) + torch_out = transform(torch_data) + self.assertTrue(torch.allclose(torch_out, torch_data)) + self.assertEqual(torch_out.shape, torch_data.shape) + + def test_invalid_channel_result_shape(self): + transform = ChannelWise(transform=lambda img: img[:0]) + + with self.assertRaises(ValueError): + transform(np.zeros((2, 2, 2))) + + with self.assertRaises(ValueError): + transform(torch.zeros((2, 2, 2))) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_channel_wised.py b/tests/test_channel_wised.py new file mode 100644 index 0000000000..136606e573 --- /dev/null +++ b/tests/test_channel_wised.py @@ -0,0 +1,68 @@ +import unittest + +import numpy as np +import torch + +from monai.transforms import ChannelWised, RandChannelWised, RandGaussianNoise, ScaleIntensity +from monai.utils import set_determinism + + +EXPECTED_SCALED = np.array([[0.0, 0.3333333], [0.6666667, 1.0]]) + + +class TestChannelWised(unittest.TestCase): + def test_channel_wise_deterministic(self): + data = {"image": np.array([[[1.0, 2.0], [3.0, 4.0]], [[10.0, 20.0], [30.0, 40.0]]])} + + transform = ChannelWised(keys=["image"], transform=ScaleIntensity()) + out = transform(data) + + np.testing.assert_allclose(out["image"][0], EXPECTED_SCALED, atol=1e-5) + np.testing.assert_allclose(out["image"][1], EXPECTED_SCALED, atol=1e-5) + self.assertEqual(out["image"].shape, data["image"].shape) + + torch_data = {"image": torch.as_tensor(data["image"])} + torch_out = transform(torch_data) + + torch_expected = torch.as_tensor(EXPECTED_SCALED, dtype=torch_out["image"].dtype) + self.assertTrue(torch.allclose(torch_out["image"][0], torch_expected, atol=1e-5)) + self.assertTrue(torch.allclose(torch_out["image"][1], torch_expected, atol=1e-5)) + self.assertEqual(torch_out["image"].shape, torch_data["image"].shape) + + def test_rand_channel_wise(self): + try: + set_determinism(seed=0) + + data = {"image": np.zeros((3, 4, 4))} + transform = RandChannelWised(keys=["image"], transform=RandGaussianNoise(prob=1.0, std=1.0)) + out = transform(data) + + self.assertFalse(np.allclose(out["image"][0], out["image"][1])) + self.assertFalse(np.allclose(out["image"][1], out["image"][2])) + self.assertFalse(np.allclose(out["image"][0], out["image"][2])) + self.assertEqual(out["image"].shape, data["image"].shape) + + torch_data = {"image": torch.zeros((3, 4, 4))} + torch_transform = RandChannelWised(keys=["image"], transform=RandGaussianNoise(prob=1.0, std=1.0)) + torch_out = torch_transform(torch_data) + + self.assertFalse(torch.allclose(torch_out["image"][0], torch_out["image"][1])) + self.assertFalse(torch.allclose(torch_out["image"][1], torch_out["image"][2])) + self.assertFalse(torch.allclose(torch_out["image"][0], torch_out["image"][2])) + self.assertEqual(torch_out["image"].shape, torch_data["image"].shape) + finally: + set_determinism(None) + + def test_prob_zero(self): + data = {"image": np.zeros((2, 2, 2))} + transform = RandChannelWised(keys=["image"], transform=RandGaussianNoise(prob=1.0, std=1.0), prob=0.0) + out = transform(data) + np.testing.assert_allclose(out["image"], data["image"]) + + torch_data = {"image": torch.zeros((2, 2, 2))} + torch_out = transform(torch_data) + self.assertTrue(torch.allclose(torch_out["image"], torch_data["image"])) + + +if __name__ == "__main__": + unittest.main()