Skip to content

Feat/channel wise transforms#8898

Open
ugbotueferhire wants to merge 2 commits into
Project-MONAI:devfrom
ugbotueferhire:feat/channel-wise-transforms
Open

Feat/channel wise transforms#8898
ugbotueferhire wants to merge 2 commits into
Project-MONAI:devfrom
ugbotueferhire:feat/channel-wise-transforms

Conversation

@ugbotueferhire

Copy link
Copy Markdown
Contributor

Fixes #8311.

Description

Adds new wrapper transforms ChannelWise, RandChannelWise, ChannelWised, and RandChannelWised to independently apply an array-based transform to each channel of an input array. This resolves issues surrounding applying data augmentations channel-wise, which is a common requirement for early fusion models where different 3D volumes or modalities are concatenated along the channel axis.

The ChannelWise transform ensures the inner transform receives slices with a singleton channel dimension to maintain expected shape invariants, and successfully maintains independent PRNG states for random augmentations per-channel.

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

@coderabbitai

coderabbitai Bot commented Jun 5, 2026

Copy link
Copy Markdown
Contributor
📝 Walkthrough

Walkthrough

This PR introduces channel-wise transform wrappers that apply a given transform independently to each channel of an input, then concatenate results. The array-level ChannelWise handles per-channel application with shape validation; RandChannelWise wraps it with probability sampling and randomizable state. Dictionary variants ChannelWised and RandChannelWised apply these across specified keys. All classes are exported from monai.transforms and tested with deterministic, randomized, edge-case, and error scenarios on both NumPy and torch data.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 44.19% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly describes the main change: addition of channel-wise transform wrapper classes to MONAI.
Description check ✅ Passed The PR description addresses the linked issue #8311, explains the purpose and implementation, includes appropriate type-of-change selections, and confirms inline docstrings were updated.
Linked Issues check ✅ Passed The PR implements all coding requirements from #8311: four wrapper transforms (ChannelWise, RandChannelWise, ChannelWised, RandChannelWised) that apply transforms independently per channel with proper shape handling and PRNG state management.
Out of Scope Changes check ✅ Passed All changes directly support the PR objectives: new array and dictionary transforms, re-exports in init, and comprehensive test coverage for the new functionality.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 12

🧹 Nitpick comments (2)
monai/transforms/utility/dictionary.py (1)

349-411: ⚡ Quick win

Complete docstrings on ChannelWised and RandChannelWised methods.

__call__ and set_random_state should include full Google-style Args/Returns/Raises to match project standards.

As per coding guidelines, "Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/transforms/utility/dictionary.py` around lines 349 - 411, The
docstrings for ChannelWised.__call__, RandChannelWised.__call__, and
RandChannelWised.set_random_state are incomplete; update them to full
Google-style docstrings including Args (describe parameters like data, seed,
state, their types and behavior), Returns (describe returned dict[Hashable,
NdarrayOrTensor] and its contents), and Raises (document possible exceptions,
e.g., KeyError for missing keys when allow_missing_keys is False, TypeError for
invalid input types). Ensure ChannelWised and RandChannelWised class docstrings
mention their converter attributes and any randomness behavior, and for
RandChannelWised.set_random_state specify it returns self (RandChannelWised) and
that it delegates to converter.set_random_state when available.
monai/transforms/utility/array.py (1)

293-366: ⚡ Quick win

Add complete Google-style docstrings for new definitions.

The new class/method docstrings are minimal and omit full Args/Returns/Raises coverage for definitions like __call__ and set_random_state.

As per coding guidelines, "Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/transforms/utility/array.py` around lines 293 - 366, Update the minimal
docstrings for ChannelWise and RandChannelWise to full Google-style docstrings:
for the classes add a short description and Args describing transform (callable)
and prob (float) where applicable; for ChannelWise.__call__ and
RandChannelWise.__call__ add Args (img: NdarrayOrTensor, randomize: bool for
RandChannelWise.__call__), Returns (NdarrayOrTensor) and Raises (e.g.,
ValueError if input shape invalid) sections with types and brief meanings; for
RandChannelWise.set_random_state add Args (seed, state), Returns (self /
RandChannelWise) and any raised exceptions; ensure wording matches existing type
hints (NdarrayOrTensor, np.random.RandomState) and keep examples/notes optional
but consistent with project Google-style docstring conventions, placing the
updated docstrings in the definitions of ChannelWise, ChannelWise.__call__,
RandChannelWise, RandChannelWise.set_random_state, and RandChannelWise.__call__.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@monai/inferers/inferer.py`:
- Around line 936-954: Add a docstring to the _scheduler_step method explaining
that it wraps a Scheduler.step call and normalizes its output into a previous
sample Tensor; document parameters (scheduler: Scheduler, model_output:
torch.Tensor, timestep: int|torch.Tensor, sample: torch.Tensor, next_timestep:
int|torch.Tensor|None) and mention that step_kwargs may include return_dict
based on _scheduler_step_supports_kwarg, and that RFlowScheduler is handled with
a different call signature (RFlowScheduler.step includes next_timestep). State
the return type (torch.Tensor) and note that the method returns the previous
sample via _get_previous_sample_from_step_output.
- Around line 914-935: The method _get_posterior_variance is missing a
docstring; add a concise docstring above the function explaining what it
computes (the posterior variance used in the diffusion reverse step), describing
parameters (scheduler: Scheduler, timestep: int|torch.Tensor,
predicted_variance: torch.Tensor|None) and return type (torch.Tensor), and
briefly note behavior for each variance_type branch ("fixed_small",
"fixed_large", "learned", "learned_range") including how predicted_variance is
used; keep it short, follow existing project style (one-line summary + short
param/return descriptions).
- Around line 865-871: Add a Google-style docstring to the static method
_scheduler_step_supports_kwarg explaining the parameter introspection: describe
the parameters (scheduler: Scheduler, kwarg: str), the return value (bool
indicating whether scheduler.step accepts kwarg), and the exceptions handled
(TypeError and ValueError caught from inspect.signature). Mention that the
function uses inspect.signature(scheduler.step).parameters to check for the
kwarg and that inspected errors are swallowed (returning False).
- Around line 902-913: Add a docstring to the _get_posterior_mean function that
briefly states what the function computes (the posterior mean used in the
diffusion/scheduler step), describes the inputs (scheduler, timestep, x_0, x_t)
and their expected types/shapes, documents the mathematical formula being
implemented (coefficients for x_0 and x_t using scheduler.alphas,
scheduler.alphas_cumprod and scheduler.betas) and states the return type
(torch.Tensor); keep it concise, one-line summary plus parameter and return
sections, and mention edge-case behavior for timestep==0 where scheduler.one is
used.
- Around line 883-887: Add a docstring to the helper function
_get_scheduler_name(scheduler: Scheduler) that succinctly describes its purpose
(returns a human-readable name for a Scheduler instance), documents the
parameter (scheduler: Scheduler) and the return type (str), and notes the lookup
order (prefers scheduler._get_name() if present, otherwise uses
scheduler.__class__.__name__); update the function definition for
_get_scheduler_name to include this docstring directly above the implementation.
- Around line 873-881: Add a Google-style docstring to
_get_previous_sample_from_step_output that documents the parameter step_output
(types expected), the returned torch.Tensor, and the TypeError raised; also
change the TypeError message to include the actual type encountered (e.g., using
type(step_output)) so the error reports the unsupported type. Ensure the
docstring briefly explains the three supported shapes (tuple where [0] is prev
sample, Mapping with "prev_sample" key, or object attribute prev_sample) and
mentions the raised TypeError when none match.
- Around line 889-900: Add a descriptive docstring to the helper function
_get_scheduler_config_value(scheduler, name, default) explaining its purpose
(resolve a configuration value by first checking scheduler.config mapping or
attributes, then scheduler attributes, and returning default if not found), and
document the parameters (scheduler: Scheduler, name: str, default: Any) and
return type (Any) plus any behavior/edge-cases (e.g., handles Mapping config and
attribute lookup order). Keep it concise and follow project's docstring style
(short summary, params, returns).

In `@monai/transforms/utility/array.py`:
- Around line 315-321: The per-channel loop in the method calling self.transform
collects results and blindly concatenates them (torch.cat / np.concatenate),
which can corrupt layout if a wrapped transform drops the singleton channel
dimension; before appending each res, validate that its dimensionality and
leading channel size preserve the expected singleton channel (e.g., for
torch.Tensor res.ndim and res.shape[0]==1, for np.ndarray res.ndim and
res.shape[0]==1) and either raise a clear error or reintroduce the missing
channel axis (unsqueeze or np.expand_dims) so that all items in results are
consistent for concatenation; apply the same guard logic to the analogous loop
around lines 359-365.

In `@tests/test_channel_wise.py`:
- Around line 10-46: Add equivalent tests that exercise the torch backend by
re-running the same scenarios with torch.Tensor inputs: create torch tensors for
data in test_channel_wise_deterministic, test_rand_channel_wise, and
test_prob_zero and invoke ChannelWise, RandChannelWise with ScaleIntensity and
RandGaussianNoise respectively (use set_determinism(seed=0) to control
randomness for torch too). For assertions use torch.allclose (or convert outputs
to numpy with .cpu().numpy()) to check per-channel scaling and inequality of
random channels, and assert tensor shapes match; ensure the
RandChannelWise(prob=0.0) case returns an identical torch tensor. Reference
ChannelWise, RandChannelWise, ScaleIntensity, RandGaussianNoise, and
set_determinism to locate code to test.
- Around line 28-31: The test sets global determinism via
set_determinism(seed=0) but never restores it, so wrap the deterministic section
(where you instantiate RandChannelWise and apply it to data) in a try/finally
and in the finally call set_determinism(None) (or the API call that disables
determinism) to restore global state; specifically modify the block around
set_determinism(seed=0), transform = RandChannelWise(...), and out =
transform(data) to ensure set_determinism is undone after the test.

In `@tests/test_channel_wised.py`:
- Around line 10-46: Add tests that exercise ChannelWised and RandChannelWised
with torch.Tensor inputs (not only numpy arrays) so the torch code paths are
covered: for test_channel_wise_deterministic create data as torch.tensor with
the same values and call ChannelWised(keys=["image"],
transform=ScaleIntensity()) then assert the per-channel scaled results and shape
using torch.allclose; for test_rand_channel_wise use set_determinism(seed=0) and
pass a torch.zeros tensor into RandChannelWised(keys=["image"],
transform=RandGaussianNoise(prob=1.0, std=1.0)) and assert channels differ with
torch comparisons and shape equality; similarly add a torch variant of
test_prob_zero using RandChannelWised(..., prob=0.0) to assert output equals
input. Ensure you import torch and use ChannelWised, RandChannelWised,
ScaleIntensity, RandGaussianNoise, and set_determinism names from the module
under test.
- Around line 28-31: The test sets global determinism with
set_determinism(seed=0) but never restores it; after running the randomized
transform (the RandChannelWised/RandGaussianNoise call in
tests/test_channel_wised.py where transform(data) is executed) call
set_determinism(None) (or the library's provided reset/unset call) to restore
the global RNG/determinism state so other tests are not affected; place that
call immediately after out = transform(data).

---

Nitpick comments:
In `@monai/transforms/utility/array.py`:
- Around line 293-366: Update the minimal docstrings for ChannelWise and
RandChannelWise to full Google-style docstrings: for the classes add a short
description and Args describing transform (callable) and prob (float) where
applicable; for ChannelWise.__call__ and RandChannelWise.__call__ add Args (img:
NdarrayOrTensor, randomize: bool for RandChannelWise.__call__), Returns
(NdarrayOrTensor) and Raises (e.g., ValueError if input shape invalid) sections
with types and brief meanings; for RandChannelWise.set_random_state add Args
(seed, state), Returns (self / RandChannelWise) and any raised exceptions;
ensure wording matches existing type hints (NdarrayOrTensor,
np.random.RandomState) and keep examples/notes optional but consistent with
project Google-style docstring conventions, placing the updated docstrings in
the definitions of ChannelWise, ChannelWise.__call__, RandChannelWise,
RandChannelWise.set_random_state, and RandChannelWise.__call__.

In `@monai/transforms/utility/dictionary.py`:
- Around line 349-411: The docstrings for ChannelWised.__call__,
RandChannelWised.__call__, and RandChannelWised.set_random_state are incomplete;
update them to full Google-style docstrings including Args (describe parameters
like data, seed, state, their types and behavior), Returns (describe returned
dict[Hashable, NdarrayOrTensor] and its contents), and Raises (document possible
exceptions, e.g., KeyError for missing keys when allow_missing_keys is False,
TypeError for invalid input types). Ensure ChannelWised and RandChannelWised
class docstrings mention their converter attributes and any randomness behavior,
and for RandChannelWised.set_random_state specify it returns self
(RandChannelWised) and that it delegates to converter.set_random_state when
available.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 63a24f4b-29bb-4185-ac59-093521eb8abb

📥 Commits

Reviewing files that changed from the base of the PR and between 2a7d0cf and dfe632f.

📒 Files selected for processing (8)
  • monai/inferers/inferer.py
  • monai/transforms/__init__.py
  • monai/transforms/utility/array.py
  • monai/transforms/utility/dictionary.py
  • tests/inferers/test_diffusion_inferer.py
  • tests/inferers/test_latent_diffusion_inferer.py
  • tests/test_channel_wise.py
  • tests/test_channel_wised.py

Comment thread monai/inferers/inferer.py Outdated
Comment thread monai/inferers/inferer.py Outdated
Comment thread monai/inferers/inferer.py Outdated
Comment thread monai/inferers/inferer.py Outdated
Comment thread monai/inferers/inferer.py Outdated
Comment thread monai/transforms/utility/array.py Outdated
Comment thread tests/test_channel_wise.py
Comment thread tests/test_channel_wise.py Outdated
Comment thread tests/test_channel_wised.py
Comment thread tests/test_channel_wised.py Outdated
@ugbotueferhire

Copy link
Copy Markdown
Contributor Author

Hi @ericspod please can you check this out?

@aymuos15

aymuos15 commented Jun 5, 2026

Copy link
Copy Markdown
Contributor

ChannelWise(ScaleIntensity()) duplicates the existing channel_wise=True flag already on most intensity transforms (ScaleIntensity, NormalizeIntensity, RandShiftIntensity, etc.)

I think that should be the base (its also more performant than the loop approach here)

@ugbotueferhire ugbotueferhire force-pushed the feat/channel-wise-transforms branch from dfe632f to 2781305 Compare June 7, 2026 21:30

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (3)
monai/transforms/utility/array.py (2)

430-434: 💤 Low value

Redundant docstring in __init__.

Same issue as ChannelWise: the __init__ docstring duplicates the class-level Args. Consider removing for brevity.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/transforms/utility/array.py` around lines 430 - 434, Remove the
redundant __init__ docstring that duplicates the class-level Args in the class
(same pattern as ChannelWise); locate the __init__ method in
monai/transforms/utility/array.py and delete the repeated Args block inside
__init__ while keeping the class-level docstring intact so parameter
documentation isn't duplicated.

Source: Coding guidelines


305-308: 💤 Low value

Redundant docstring in __init__.

The __init__ docstring duplicates the class-level Args section. Google-style prefers documenting constructor parameters at the class level when they're simple pass-through assignments. Either remove this docstring or add implementation details that aren't in the class docstring.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/transforms/utility/array.py` around lines 305 - 308, The __init__
method currently contains a redundant Google-style docstring duplicating the
class-level Args; remove the duplicate docstring literal inside the __init__
method (the constructor named __init__) so parameter docs remain only at the
class level, or alternatively replace that constructor docstring with additional
implementation-specific details that are not already present in the class-level
docstring; update only the __init__ block in monai/transforms/utility/array.py
to either delete the docstring or enrich it with non-duplicative information.

Source: Coding guidelines

tests/test_channel_wise.py (1)

13-85: ⚡ Quick win

Add Google-style docstrings to class and test methods.

Both the TestChannelWise class and all five test methods lack docstrings. Briefly describe what each test validates.

As per coding guidelines, "Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_channel_wise.py` around lines 13 - 85, Add Google-style docstrings
to the TestChannelWise class and each test method
(test_channel_wise_deterministic, test_rand_channel_wise, test_prob_zero,
test_squeezed_channel_result, test_invalid_channel_result_shape): for the class
give a one-line summary of the test suite; for each method include a short
description of what is being validated, list key parameters/fixtures (e.g.,
input arrays like data / torch_data), expected outcome (what is asserted), and
any raised exceptions (e.g., ValueError in test_invalid_channel_result_shape).
Keep docstrings concise and follow Google-style sections (Args, Returns if any,
Raises where applicable).

Source: Coding guidelines

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@monai/transforms/utility/array.py`:
- Around line 430-434: Remove the redundant __init__ docstring that duplicates
the class-level Args in the class (same pattern as ChannelWise); locate the
__init__ method in monai/transforms/utility/array.py and delete the repeated
Args block inside __init__ while keeping the class-level docstring intact so
parameter documentation isn't duplicated.
- Around line 305-308: The __init__ method currently contains a redundant
Google-style docstring duplicating the class-level Args; remove the duplicate
docstring literal inside the __init__ method (the constructor named __init__) so
parameter docs remain only at the class level, or alternatively replace that
constructor docstring with additional implementation-specific details that are
not already present in the class-level docstring; update only the __init__ block
in monai/transforms/utility/array.py to either delete the docstring or enrich it
with non-duplicative information.

In `@tests/test_channel_wise.py`:
- Around line 13-85: Add Google-style docstrings to the TestChannelWise class
and each test method (test_channel_wise_deterministic, test_rand_channel_wise,
test_prob_zero, test_squeezed_channel_result,
test_invalid_channel_result_shape): for the class give a one-line summary of the
test suite; for each method include a short description of what is being
validated, list key parameters/fixtures (e.g., input arrays like data /
torch_data), expected outcome (what is asserted), and any raised exceptions
(e.g., ValueError in test_invalid_channel_result_shape). Keep docstrings concise
and follow Google-style sections (Args, Returns if any, Raises where
applicable).

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 778a6c24-29e5-4032-8493-7393a1a13c60

📥 Commits

Reviewing files that changed from the base of the PR and between dfe632f and 2781305.

📒 Files selected for processing (5)
  • monai/transforms/__init__.py
  • monai/transforms/utility/array.py
  • monai/transforms/utility/dictionary.py
  • tests/test_channel_wise.py
  • tests/test_channel_wised.py
✅ Files skipped from review due to trivial changes (1)
  • monai/transforms/init.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • tests/test_channel_wised.py
  • monai/transforms/utility/dictionary.py

@ugbotueferhire

Copy link
Copy Markdown
Contributor Author

ChannelWise(ScaleIntensity()) duplicates the existing channel_wise=True flag already on most intensity transforms (ScaleIntensity, NormalizeIntensity, RandShiftIntensity, etc.)

I think that should be the base (its also more performant than the loop approach here)

Thanks, that’s a fair point. For transforms that already support channel_wise=True, that should remain the preferred path since it avoids the Python-level loop and is more efficient.

The intent of this wrapper is to cover the more general case: arbitrary array transforms, custom callables, or transforms that don’t expose a channel_wise option but still need to be applied independently per channel. ScaleIntensity is only used in the tests as a simple deterministic transform, not as the primary motivating use case.

I can add wording to the docstring to make that distinction explicit: use native channel_wise=True when available; use ChannelWise/RandChannelWise as a generic fallback wrapper.

@aymuos15

aymuos15 commented Jun 7, 2026

Copy link
Copy Markdown
Contributor

Why would we want a fallback there? And is it not better to just expand what already exists?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Implementing Channel-Wise Transforms

2 participants