Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,7 @@
ApplyTransformToPoints,
AsChannelLast,
CastToType,
ChannelWise,
ClassesToIndices,
ConvertToMultiChannelBasedOnBratsClasses,
CuCIM,
Expand All @@ -536,6 +537,7 @@
RandIdentity,
RandImageFilter,
RandLambda,
RandChannelWise,
RandTorchIO,
RandTorchVision,
RemoveRepeatedChannel,
Expand Down Expand Up @@ -568,6 +570,9 @@
CastToTyped,
CastToTypeD,
CastToTypeDict,
ChannelWised,
ChannelWiseD,
ChannelWiseDict,
ClassesToIndicesd,
ClassesToIndicesD,
ClassesToIndicesDict,
Expand Down Expand Up @@ -631,6 +636,9 @@
RandLambdad,
RandLambdaD,
RandLambdaDict,
RandChannelWised,
RandChannelWiseD,
RandChannelWiseDict,
RandTorchIOd,
RandTorchIOD,
RandTorchIODict,
Expand Down
187 changes: 187 additions & 0 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@
"EnsureType",
"RepeatChannel",
"RemoveRepeatedChannel",
"ChannelWise",
"RandChannelWise",
"SplitDim",
"CastToType",
"ToTensor",
Expand Down Expand Up @@ -288,6 +290,191 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
return out


class ChannelWise(Transform):
"""
Apply a transform to each channel of a channel-first array independently.

Args:
transform: Callable transform to apply to each channel. The transform receives
a single-channel array or tensor with a singleton leading channel dimension.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, transform: Callable) -> None:
"""
Args:
transform: Callable transform to apply to each channel.
"""
self.transform = transform

@staticmethod
def _normalize_channel_result(
result: NdarrayOrTensor, expected_ndim: int, channel_index: int
) -> NdarrayOrTensor:
"""
Ensure a per-channel transform result has a singleton channel dimension.

Args:
result: Output from applying the wrapped transform to one input channel.
expected_ndim: Expected dimensionality of the channel-first output.
channel_index: Index of the input channel being transformed.

Returns:
The transform result with a singleton leading channel dimension.

Raises:
ValueError: If the result cannot be concatenated along the channel dimension.
TypeError: If the result is not a NumPy array or torch tensor.
"""
if isinstance(result, torch.Tensor):
if result.ndim == expected_ndim - 1:
return result.unsqueeze(0)
if result.ndim == expected_ndim and result.shape[0] == 1:
return result
elif isinstance(result, np.ndarray):
if result.ndim == expected_ndim - 1:
return np.expand_dims(result, axis=0)
if result.ndim == expected_ndim and result.shape[0] == 1:
return result
else:
raise TypeError(
f"Channel {channel_index} transform output must be a NumPy array or torch tensor, "
f"got {type(result).__name__}."
)

raise ValueError(
f"Channel {channel_index} transform output must preserve a singleton leading channel "
f"dimension or return a squeezed channel with {expected_ndim - 1} dimensions, "
f"got shape {tuple(result.shape)}."
)

@staticmethod
def _concatenate_channel_results(results: list[NdarrayOrTensor]) -> NdarrayOrTensor:
"""
Concatenate normalized per-channel transform results.

Args:
results: Sequence of normalized per-channel NumPy arrays or torch tensors.

Returns:
A NumPy array or torch tensor concatenated along the leading channel dimension.

Raises:
TypeError: If results contain mixed array and tensor types.
"""
if all(isinstance(result, torch.Tensor) for result in results):
return torch.cat(results, dim=0)
if all(isinstance(result, np.ndarray) for result in results):
return np.concatenate(results, axis=0)
raise TypeError("All channel-wise transform outputs must have the same array or tensor type.")

@classmethod
def _apply_channel_wise(cls, img: NdarrayOrTensor, transform: Callable) -> NdarrayOrTensor:
"""
Apply a callable independently to each channel and concatenate the outputs.

Args:
img: Channel-first NumPy array or torch tensor to transform.
transform: Callable transform to apply to each channel.

Returns:
A NumPy array or torch tensor containing the transformed channels.

Raises:
ValueError: If `img` has no channel dimension, or a transformed channel has an incompatible shape.
TypeError: If a transformed channel has an invalid type or mixed output types are produced.
"""
if len(img.shape) == 0:
raise ValueError("Image must have a channel dimension.")
if img.shape[0] == 0:
return img

results = [
cls._normalize_channel_result(transform(img[[i], ...]), expected_ndim=len(img.shape), channel_index=i)
for i in range(img.shape[0])
]
return cls._concatenate_channel_results(results)

def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Apply the wrapped transform independently to each channel.

Args:
img: Channel-first NumPy array or torch tensor to transform.

Returns:
A NumPy array or torch tensor containing the transformed channels.

Raises:
ValueError: If `img` has no channel dimension, or a transformed channel has an incompatible shape.
TypeError: If a transformed channel is not a NumPy array or torch tensor.
"""
return self._apply_channel_wise(img, self.transform)


class RandChannelWise(RandomizableTransform):
"""
Randomizable version of :py:class:`monai.transforms.ChannelWise`, the input
`transform` will be applied independently to each channel.

Args:
transform: Callable transform to apply to each channel. If it defines
``set_random_state``, the state is synchronized by this wrapper.
prob: Probability of applying the transform to the entire image.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, transform: Callable, prob: float = 1.0) -> None:
"""
Args:
transform: Callable transform to apply to each channel.
prob: Probability of applying the transform to the entire image.
"""
RandomizableTransform.__init__(self, prob)
self.transform = transform

def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandChannelWise:
"""
Set the random state for this transform and its wrapped transform.

Args:
seed: Seed to use for a new NumPy random state.
state: Existing NumPy random state to use. If provided, it takes precedence over ``seed``.

Returns:
This ``RandChannelWise`` instance.
"""
super().set_random_state(seed, state)
if hasattr(self.transform, "set_random_state"):
self.transform.set_random_state(seed, state)
return self

def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor:
"""
Apply the wrapped transform independently to each channel when sampled.

Args:
img: Channel-first NumPy array or torch tensor to transform.
randomize: Whether to sample this wrapper's random state before applying the transform.

Returns:
A NumPy array or torch tensor containing the transformed channels, or the input unchanged
when this transform is not sampled.

Raises:
ValueError: If `img` has no channel dimension, or a transformed channel has an incompatible shape.
TypeError: If a transformed channel is not a NumPy array or torch tensor.
"""
if randomize:
self.randomize(None)
if not self._do_transform:
return img

return ChannelWise._apply_channel_wise(img, self.transform)


class SplitDim(Transform, MultiSampleTrait):
"""
Given an image of size X along a certain dimension, return a list of length X containing
Expand Down
Loading
Loading