Skip to content

Add MAISI modality conditioning hooks#8903

Open
ugbotueferhire wants to merge 2 commits into
Project-MONAI:devfrom
ugbotueferhire:feat/maisi-modality-conditioning
Open

Add MAISI modality conditioning hooks#8903
ugbotueferhire wants to merge 2 commits into
Project-MONAI:devfrom
ugbotueferhire:feat/maisi-modality-conditioning

Conversation

@ugbotueferhire

Copy link
Copy Markdown
Contributor

Fixes #8170.

Description

This PR adds optional modality conditioning hooks to MAISI networks. DiffusionModelUNetMaisi and ControlNetMaisi now support include_modality_input with a scalar modality_tensor of shape (N, 1), enabling CT/MRI/multimodal training configurations while preserving current default behavior.

The change also adds validation for missing or incorrectly shaped modality and auxiliary conditioning tensors, plus targeted MAISI network tests covering 2D/3D shape behavior and missing-input errors.

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

Signed-off-by: ugbotueferhire <ugbotueferhire@gmail.com>
@coderabbitai

coderabbitai Bot commented Jun 7, 2026

Copy link
Copy Markdown
Contributor

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: bc6752ea-f0f5-44e3-9f32-fa0a4fe3a7a3

📥 Commits

Reviewing files that changed from the base of the PR and between c60d5f2 and 0840b27.

📒 Files selected for processing (1)
  • tests/apps/maisi/networks/test_controlnet_maisi.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/apps/maisi/networks/test_controlnet_maisi.py

📝 Walkthrough

Walkthrough

This PR expands MAISI network models to support optional modality conditioning. ControlNetMaisi was refactored to perform initialization directly and now accepts include_modality_input flag with modality_tensor in forward calls. DiffusionModelUNetMaisi gains parallel modality support with validation, embedding creation, and dimension adjustment. Both models concatenate modality embeddings into the conditioning stream when enabled. Tests verify correct shapes with modality input and error handling when modality tensors are missing but expected.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 4.35% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed Title clearly and concisely summarizes the main feature addition—modality conditioning hooks for MAISI networks.
Description check ✅ Passed Description covers objectives, change scope, and implementation details. All critical sections are present and informative.
Linked Issues check ✅ Passed Code changes directly implement modality input support for DiffusionModelUNetMaisi and ControlNetMaisi, fulfilling #8170's request to expand MAISI for multimodal synthetic data generation.
Out of Scope Changes check ✅ Passed All changes are scoped to modality conditioning support: network refactoring, validation logic, and targeted test coverage. No extraneous modifications detected.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Nitpick comments (5)
tests/apps/maisi/networks/test_diffusion_model_unet_maisi.py (1)

511-542: ⚡ Quick win

Add negative tests for malformed auxiliary tensor shapes.

Current additions test “missing tensor” only. Please also cover wrong-shape cases (e.g., modality_tensor (1, 2), spacing_tensor (1, 2)) to exercise the new shape-validation branch.

As per coding guidelines, “Ensure new or modified definitions will be covered by existing or new unit tests.”

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/apps/maisi/networks/test_diffusion_model_unet_maisi.py` around lines
511 - 542, Extend the unit tests to assert the shape-validation branch by adding
negative tests that pass malformed auxiliary tensors to DiffusionModelUNetMaisi:
in addition to test_modality_input_missing, add a test that constructs the same
net (include_modality_input=True) and calls net.forward with a malformed
modality_tensor of shape (1,2) (e.g., torch.rand((1,2))) inside eval_mode(net)
and assert it raises ValueError with the "modality_tensor should be provided"
(or the specific shape-related message); likewise, extend or add a test next to
test_additional_input_missing that builds the net with
include_spacing_input=True and calls net.forward with a malformed spacing_tensor
of shape (1,2) and asserts a ValueError with the "spacing_tensor should be
provided" (or specific shape message). Ensure you use the same call pattern
(net.forward(input, timestep, modality_tensor=..., spacing_tensor=...)) and the
same skipUnless(has_einops) decorator so these new tests exercise the
shape-validation branch in DiffusionModelUNetMaisi.

Source: Coding guidelines

tests/apps/maisi/networks/test_controlnet_maisi.py (1)

182-203: 💤 Low value

Consider adding a test for invalid modality_tensor shape.

The _validate_input_tensor method validates shape (N, 1) and raises ValueError for wrong dimensions. A test with modality_tensor=torch.ones((1, 2)) would cover that branch.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/apps/maisi/networks/test_controlnet_maisi.py` around lines 182 - 203,
Add a test case to verify that ControlNetMaisi._validate_input_tensor rejects
modality tensors with invalid shape: update or add a test in
test_modality_input_missing (or a new test function) that calls net.forward with
modality_tensor=torch.ones((1, 2)) and asserts a ValueError is raised (matching
the existing message or appropriate one); reference the ControlNetMaisi class
and its _validate_input_tensor method so the test covers the branch where
modality_tensor has shape (N, 2) instead of (N, 1).
monai/apps/generation/maisi/networks/controlnet_maisi.py (3)

271-297: ⚡ Quick win

forward docstring should document modality_tensor parameter.

The method signature now includes modality_tensor, but the docstring (inherited or implicit) doesn't describe it. Should match the pattern used in DiffusionModelUNetMaisi.forward which documents modality_tensor: Tensor representing modality of shape (N, 1).

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/apps/generation/maisi/networks/controlnet_maisi.py` around lines 271 -
297, The forward method in controlnet_maisi.py is missing documentation for the
new modality_tensor parameter; update the forward docstring for
ControlNetMaisi.forward (or the class's forward) to include a line describing
modality_tensor as "Tensor representing modality of shape (N, 1)" (matching
DiffusionModelUNetMaisi.forward) and mention its expected dtype and optional
nature (Tensor | None) and role in conditioning; place this description
alongside the existing parameter docs for x, timesteps, controlnet_cond,
context, and class_labels.

Source: Coding guidelines


253-269: ⚡ Quick win

Missing docstrings for new helper methods.

_create_embedding_module, _validate_input_tensor, and _get_input_embeddings lack docstrings describing their arguments, return values, and raised exceptions.

As per coding guidelines: "Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings."

📝 Example docstring for _validate_input_tensor
def _validate_input_tensor(self, tensor, tensor_name, include_flag_name, expected_last_dim, emb):
    """Validate and cast an optional input tensor.

    Args:
        tensor: Input tensor to validate.
        tensor_name: Name of the tensor for error messages.
        include_flag_name: Name of the flag that enables this input.
        expected_last_dim: Expected size of the last dimension.
        emb: Reference tensor for dtype casting.

    Returns:
        Tensor cast to the same dtype as emb.

    Raises:
        ValueError: If tensor is None or has incorrect shape.
    """
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/apps/generation/maisi/networks/controlnet_maisi.py` around lines 253 -
269, Add Google-style docstrings to the three new helper methods
_create_embedding_module, _validate_input_tensor, and _get_input_embeddings: for
each method document all arguments (with types/meaning), the return value, and
any exceptions raised (e.g., ValueError in _validate_input_tensor) using the
Args/Returns/Raises sections; ensure _validate_input_tensor explains tensor,
tensor_name, include_flag_name, expected_last_dim, emb, the cast to emb.dtype,
and the ValueError conditions, _create_embedding_module documents input_dim and
embed_dim and the returned nn.Sequential module, and _get_input_embeddings
documents modality behavior, when include_modality_input is used, what it
returns and that it concatenates modality embeddings to emb.

Source: Coding guidelines


161-171: zero_module usage is functionally equivalent for these blocks (no silent behavior difference).

  • In Convolution(..., conv_only=True), the wrapper only contains the submodule conv (no extra parameters), so zero_module(controlnet_block.conv) and zero_module(controlnet_block) both zero the same weights/biases; this mixed pattern matches existing usage in monai/networks/nets/controlnet.py.
  • Optional: standardize to one style for readability.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/apps/generation/maisi/networks/controlnet_maisi.py` around lines 161 -
171, The code inconsistently applies zero_module to the inner conv vs. the
wrapper: after creating Convolution(...) assigned to controlnet_block you call
zero_module(controlnet_block.conv) before appending to
self.controlnet_down_blocks; change this to a consistent style (either always
pass the wrapper controlnet_block or always pass controlnet_block.conv) across
this file to match the pattern used in monai/networks/nets/controlnet.py — e.g.,
replace zero_module(controlnet_block.conv) with zero_module(controlnet_block)
(or vice versa) for Convolution instances to standardize readability while
keeping behavior identical.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py`:
- Around line 316-321: Add Google-style docstrings for the helper functions
_validate_input_tensor and _get_input_embeddings: for each function include an
Args section describing every parameter (types and meaning), a Returns section
describing the return value and type, and a Raises section listing exceptions
(e.g., ValueError) and conditions that trigger them; place the docstrings
immediately above the def lines for _validate_input_tensor and
_get_input_embeddings and ensure they mention expected tensor shapes/dtypes, the
emb parameter usage, and the conditions when ValueError is raised.
- Around line 316-321: _validate_input_tensor currently only checks rank and
last dimension; enhance it to also verify the batch size and device to fail
early. Specifically, inside _validate_input_tensor check that tensor.shape[0]
matches emb.shape[0] (raise ValueError mentioning tensor_name and expected batch
size from emb), verify tensor.device == emb.device (raise ValueError if
mismatched), and then return tensor converted with both dtype and device (e.g.,
tensor.to(dtype=emb.dtype, device=emb.device)); keep existing checks for dim and
last dim and retain the include_flag_name-based null check.

---

Nitpick comments:
In `@monai/apps/generation/maisi/networks/controlnet_maisi.py`:
- Around line 271-297: The forward method in controlnet_maisi.py is missing
documentation for the new modality_tensor parameter; update the forward
docstring for ControlNetMaisi.forward (or the class's forward) to include a line
describing modality_tensor as "Tensor representing modality of shape (N, 1)"
(matching DiffusionModelUNetMaisi.forward) and mention its expected dtype and
optional nature (Tensor | None) and role in conditioning; place this description
alongside the existing parameter docs for x, timesteps, controlnet_cond,
context, and class_labels.
- Around line 253-269: Add Google-style docstrings to the three new helper
methods _create_embedding_module, _validate_input_tensor, and
_get_input_embeddings: for each method document all arguments (with
types/meaning), the return value, and any exceptions raised (e.g., ValueError in
_validate_input_tensor) using the Args/Returns/Raises sections; ensure
_validate_input_tensor explains tensor, tensor_name, include_flag_name,
expected_last_dim, emb, the cast to emb.dtype, and the ValueError conditions,
_create_embedding_module documents input_dim and embed_dim and the returned
nn.Sequential module, and _get_input_embeddings documents modality behavior,
when include_modality_input is used, what it returns and that it concatenates
modality embeddings to emb.
- Around line 161-171: The code inconsistently applies zero_module to the inner
conv vs. the wrapper: after creating Convolution(...) assigned to
controlnet_block you call zero_module(controlnet_block.conv) before appending to
self.controlnet_down_blocks; change this to a consistent style (either always
pass the wrapper controlnet_block or always pass controlnet_block.conv) across
this file to match the pattern used in monai/networks/nets/controlnet.py — e.g.,
replace zero_module(controlnet_block.conv) with zero_module(controlnet_block)
(or vice versa) for Convolution instances to standardize readability while
keeping behavior identical.

In `@tests/apps/maisi/networks/test_controlnet_maisi.py`:
- Around line 182-203: Add a test case to verify that
ControlNetMaisi._validate_input_tensor rejects modality tensors with invalid
shape: update or add a test in test_modality_input_missing (or a new test
function) that calls net.forward with modality_tensor=torch.ones((1, 2)) and
asserts a ValueError is raised (matching the existing message or appropriate
one); reference the ControlNetMaisi class and its _validate_input_tensor method
so the test covers the branch where modality_tensor has shape (N, 2) instead of
(N, 1).

In `@tests/apps/maisi/networks/test_diffusion_model_unet_maisi.py`:
- Around line 511-542: Extend the unit tests to assert the shape-validation
branch by adding negative tests that pass malformed auxiliary tensors to
DiffusionModelUNetMaisi: in addition to test_modality_input_missing, add a test
that constructs the same net (include_modality_input=True) and calls net.forward
with a malformed modality_tensor of shape (1,2) (e.g., torch.rand((1,2))) inside
eval_mode(net) and assert it raises ValueError with the "modality_tensor should
be provided" (or the specific shape-related message); likewise, extend or add a
test next to test_additional_input_missing that builds the net with
include_spacing_input=True and calls net.forward with a malformed spacing_tensor
of shape (1,2) and asserts a ValueError with the "spacing_tensor should be
provided" (or specific shape message). Ensure you use the same call pattern
(net.forward(input, timestep, modality_tensor=..., spacing_tensor=...)) and the
same skipUnless(has_einops) decorator so these new tests exercise the
shape-validation branch in DiffusionModelUNetMaisi.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 908991df-c3c5-449c-a0c6-036423c3b0be

📥 Commits

Reviewing files that changed from the base of the PR and between 8a89dd5 and c60d5f2.

📒 Files selected for processing (4)
  • monai/apps/generation/maisi/networks/controlnet_maisi.py
  • monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py
  • tests/apps/maisi/networks/test_controlnet_maisi.py
  • tests/apps/maisi/networks/test_diffusion_model_unet_maisi.py

Comment on lines +316 to +321
def _validate_input_tensor(self, tensor, tensor_name, include_flag_name, expected_last_dim, emb):
if tensor is None:
raise ValueError(f"{tensor_name} should be provided when {include_flag_name} is True.")
if tensor.dim() != 2 or tensor.shape[1] != expected_last_dim:
raise ValueError(f"{tensor_name} should have shape (N, {expected_last_dim}), got {tuple(tensor.shape)}.")
return tensor.to(dtype=emb.dtype)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major | ⚡ Quick win

Add Google-style docstrings for new helper definitions.

_validate_input_tensor and _get_input_embeddings are new/modified definitions but have no docstrings describing args/returns/raised exceptions.

As per coding guidelines, “Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings.”

Also applies to: 340-361

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py` around
lines 316 - 321, Add Google-style docstrings for the helper functions
_validate_input_tensor and _get_input_embeddings: for each function include an
Args section describing every parameter (types and meaning), a Returns section
describing the return value and type, and a Raises section listing exceptions
(e.g., ValueError) and conditions that trigger them; place the docstrings
immediately above the def lines for _validate_input_tensor and
_get_input_embeddings and ensure they mention expected tensor shapes/dtypes, the
emb parameter usage, and the conditions when ValueError is raised.

Source: Coding guidelines


⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Strengthen tensor validation to prevent runtime crashes.

Line 319 only validates rank/last-dim. Add batch-size and device checks to fail early with clear errors; otherwise torch.cat/Linear can crash later with opaque messages.

Proposed fix
 def _validate_input_tensor(self, tensor, tensor_name, include_flag_name, expected_last_dim, emb):
     if tensor is None:
         raise ValueError(f"{tensor_name} should be provided when {include_flag_name} is True.")
     if tensor.dim() != 2 or tensor.shape[1] != expected_last_dim:
         raise ValueError(f"{tensor_name} should have shape (N, {expected_last_dim}), got {tuple(tensor.shape)}.")
-    return tensor.to(dtype=emb.dtype)
+    if tensor.shape[0] != emb.shape[0]:
+        raise ValueError(f"{tensor_name} should have batch size {emb.shape[0]}, got {tensor.shape[0]}.")
+    return tensor.to(device=emb.device, dtype=emb.dtype)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def _validate_input_tensor(self, tensor, tensor_name, include_flag_name, expected_last_dim, emb):
if tensor is None:
raise ValueError(f"{tensor_name} should be provided when {include_flag_name} is True.")
if tensor.dim() != 2 or tensor.shape[1] != expected_last_dim:
raise ValueError(f"{tensor_name} should have shape (N, {expected_last_dim}), got {tuple(tensor.shape)}.")
return tensor.to(dtype=emb.dtype)
def _validate_input_tensor(self, tensor, tensor_name, include_flag_name, expected_last_dim, emb):
if tensor is None:
raise ValueError(f"{tensor_name} should be provided when {include_flag_name} is True.")
if tensor.dim() != 2 or tensor.shape[1] != expected_last_dim:
raise ValueError(f"{tensor_name} should have shape (N, {expected_last_dim}), got {tuple(tensor.shape)}.")
if tensor.shape[0] != emb.shape[0]:
raise ValueError(f"{tensor_name} should have batch size {emb.shape[0]}, got {tensor.shape[0]}.")
return tensor.to(device=emb.device, dtype=emb.dtype)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py` around
lines 316 - 321, _validate_input_tensor currently only checks rank and last
dimension; enhance it to also verify the batch size and device to fail early.
Specifically, inside _validate_input_tensor check that tensor.shape[0] matches
emb.shape[0] (raise ValueError mentioning tensor_name and expected batch size
from emb), verify tensor.device == emb.device (raise ValueError if mismatched),
and then return tensor converted with both dtype and device (e.g.,
tensor.to(dtype=emb.dtype, device=emb.device)); keep existing checks for dim and
last dim and retain the include_flag_name-based null check.

Signed-off-by: ugbotueferhire <ugbotueferhire@gmail.com>
@ugbotueferhire

Copy link
Copy Markdown
Contributor Author

Hi @ericspod @KumoLiu

This PR is intended as a core-framework step toward #8170: it adds optional modality conditioning hooks to MAISI so future CT/MRI/multimodal training configs can pass a scalar modality code without changing current CT-only behavior.

I kept the scope intentionally limited to MONAI core APIs and tests, so this does not add pretrained MRI weights, datasets, or model-zoo/tutorial updates. Happy to adjust the scope if you’d prefer this PR to include documentation or config examples as well.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Expand MAISI to support a wider variety of synthetic data generation

1 participant