Add MAISI modality conditioning hooks #8903
Signed-off-by: ugbotueferhire <ugbotueferhire@gmail.com>
No actionable comments were generated in the recent review. 🎉
📝 Walkthrough
This PR expands MAISI network models to support optional modality conditioning. `ControlNetMaisi` was refactored to perform initialization directly and now accepts an `include_modality_input` flag, with `modality_tensor` passed in forward calls. `DiffusionModelUNetMaisi` gains parallel modality support with validation, embedding creation, and dimension adjustment. Both models concatenate modality embeddings into the conditioning stream when enabled. Tests verify correct shapes with modality input and error handling when modality tensors are missing but expected.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 2
🧹 Nitpick comments (5)
tests/apps/maisi/networks/test_diffusion_model_unet_maisi.py (1)
511-542: ⚡ Quick win
Add negative tests for malformed auxiliary tensor shapes.
Current additions test “missing tensor” only. Please also cover wrong-shape cases (e.g., `modality_tensor` of shape `(1, 2)`, `spacing_tensor` of shape `(1, 2)`) to exercise the new shape-validation branch.
As per coding guidelines, “Ensure new or modified definitions will be covered by existing or new unit tests.”
Source: Coding guidelines
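A minimal sketch of the negative tests requested above. To stay self-contained it exercises a free-function copy of `_validate_input_tensor` as quoted in this review, not a real `DiffusionModelUNetMaisi`. The names `validate_input_tensor` and `TestMalformedAuxiliaryTensors`, and the expected last dimension of 3 for `spacing_tensor`, are assumptions made for illustration.

```python
import unittest

import torch


def validate_input_tensor(tensor, tensor_name, include_flag_name, expected_last_dim, emb):
    """Free-function copy of the helper quoted in this review (illustration only)."""
    if tensor is None:
        raise ValueError(f"{tensor_name} should be provided when {include_flag_name} is True.")
    if tensor.dim() != 2 or tensor.shape[1] != expected_last_dim:
        raise ValueError(f"{tensor_name} should have shape (N, {expected_last_dim}), got {tuple(tensor.shape)}.")
    return tensor.to(dtype=emb.dtype)


class TestMalformedAuxiliaryTensors(unittest.TestCase):
    def setUp(self):
        # Stand-in for the timestep embedding the helper casts against.
        self.emb = torch.zeros((1, 8))

    def test_modality_tensor_wrong_shape(self):
        with self.assertRaisesRegex(ValueError, r"modality_tensor should have shape \(N, 1\)"):
            validate_input_tensor(torch.rand((1, 2)), "modality_tensor", "include_modality_input", 1, self.emb)

    def test_spacing_tensor_wrong_shape(self):
        # Assumes spacing is expected as (N, 3); adjust to the real model's contract.
        with self.assertRaisesRegex(ValueError, r"spacing_tensor should have shape \(N, 3\)"):
            validate_input_tensor(torch.rand((1, 2)), "spacing_tensor", "include_spacing_input", 3, self.emb)
```

Matching on the shape-specific message keeps these tests from passing by accident on the "should be provided" branch.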
tests/apps/maisi/networks/test_controlnet_maisi.py (1)
182-203: 💤 Low value
Consider adding a test for invalid `modality_tensor` shape.
The `_validate_input_tensor` method validates shape `(N, 1)` and raises `ValueError` for wrong dimensions. A test with `modality_tensor=torch.ones((1, 2))` would cover that branch.
monai/apps/generation/maisi/networks/controlnet_maisi.py (3)
271-297: ⚡ Quick win
`forward` docstring should document `modality_tensor` parameter.
The method signature now includes `modality_tensor`, but the docstring (inherited or implicit) doesn't describe it. Should match the pattern used in `DiffusionModelUNetMaisi.forward`, which documents `modality_tensor: Tensor representing modality of shape (N, 1)`.
Source: Coding guidelines
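A sketch of what the requested docstring entry might look like. The function is a hypothetical stub, not the real `ControlNetMaisi.forward`. The parameter order and the one-line descriptions of the other arguments are assumptions; only the `modality_tensor` wording follows the review.

```python
from __future__ import annotations

from typing import Any


def forward(
    x: Any,
    timesteps: Any,
    controlnet_cond: Any,
    context: Any | None = None,
    class_labels: Any | None = None,
    modality_tensor: Any | None = None,
) -> None:
    """Hypothetical stub that only illustrates the docstring layout.

    Args:
        x: Input tensor.
        timesteps: Timestep tensor.
        controlnet_cond: ControlNet conditioning tensor.
        context: Optional context tensor.
        class_labels: Optional class label tensor.
        modality_tensor: Tensor representing modality of shape (N, 1).
            Optional (``Tensor | None``); expected when
            ``include_modality_input`` is True. It is cast to the embedding
            dtype and its embedding is concatenated into the conditioning.

    Raises:
        ValueError: If ``include_modality_input`` is True and
            ``modality_tensor`` is missing or not of shape (N, 1).
    """
    return None
```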
253-269: ⚡ Quick win
Missing docstrings for new helper methods.
`_create_embedding_module`, `_validate_input_tensor`, and `_get_input_embeddings` lack docstrings describing their arguments, return values, and raised exceptions.
As per coding guidelines: "Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings."
📝 Example docstring for `_validate_input_tensor`

    def _validate_input_tensor(self, tensor, tensor_name, include_flag_name, expected_last_dim, emb):
        """Validate and cast an optional input tensor.

        Args:
            tensor: Input tensor to validate.
            tensor_name: Name of the tensor for error messages.
            include_flag_name: Name of the flag that enables this input.
            expected_last_dim: Expected size of the last dimension.
            emb: Reference tensor for dtype casting.

        Returns:
            Tensor cast to the same dtype as emb.

        Raises:
            ValueError: If tensor is None or has incorrect shape.
        """

Source: Coding guidelines
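In the same spirit, here is a sketch of a Google-style docstring for `_create_embedding_module`, written as a free function so it runs on its own. The Linear, SiLU, Linear layout is an assumption about the module's structure; the review only states that it takes `input_dim` and `embed_dim` and returns an `nn.Sequential`.

```python
import torch
from torch import nn


def create_embedding_module(input_dim: int, embed_dim: int) -> nn.Sequential:
    """Build a small MLP that embeds an auxiliary conditioning tensor.

    Args:
        input_dim: Size of the last dimension of the input tensor, e.g. 1 for
            a scalar modality code of shape (N, 1).
        embed_dim: Size of the output embedding.

    Returns:
        An ``nn.Sequential`` mapping tensors of shape (N, input_dim) to
        (N, embed_dim).
    """
    # Assumed layer layout; the module in the PR may differ.
    return nn.Sequential(nn.Linear(input_dim, embed_dim), nn.SiLU(), nn.Linear(embed_dim, embed_dim))


# Usage: embed a batch of two scalar modality codes.
modality_emb = create_embedding_module(1, 16)(torch.ones((2, 1)))
```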
161-171: `zero_module` usage is functionally equivalent for these blocks (no silent behavior difference).
- In `Convolution(..., conv_only=True)`, the wrapper only contains the submodule `conv` (no extra parameters), so `zero_module(controlnet_block.conv)` and `zero_module(controlnet_block)` both zero the same weights/biases; this mixed pattern matches existing usage in `monai/networks/nets/controlnet.py`.
- Optional: standardize to one style for readability.
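The equivalence claimed here can be checked with a small stand-in. `ConvOnlyWrapper` is a hypothetical class that mimics a wrapper whose only child is `conv`, and `zero_module` below is a local re-implementation written for this sketch, not an import from MONAI.

```python
import torch
from torch import nn


def zero_module(module: nn.Module) -> nn.Module:
    """Stand-in helper: zero every parameter of ``module`` in place and return it."""
    for p in module.parameters():
        p.detach().zero_()
    return module


class ConvOnlyWrapper(nn.Module):
    """Hypothetical wrapper whose only child is ``conv`` (no extra parameters)."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(4, 4, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


def all_zero(module: nn.Module) -> bool:
    """Return True when every parameter of ``module`` is exactly zero."""
    return all(bool((p == 0).all()) for p in module.parameters())


# Style 1: zero through the wrapper.
via_wrapper = zero_module(ConvOnlyWrapper())

# Style 2: zero through the inner conv only.
via_inner = ConvOnlyWrapper()
zero_module(via_inner.conv)
```

Both objects end up with all-zero weights and biases because the wrapper's `parameters()` yields exactly the parameters of `conv`. This would not hold for a wrapper that also owned, say, a normalization layer.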
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py`:
- Around line 316-321: Add Google-style docstrings for the helper functions
_validate_input_tensor and _get_input_embeddings: for each function include an
Args section describing every parameter (types and meaning), a Returns section
describing the return value and type, and a Raises section listing exceptions
(e.g., ValueError) and conditions that trigger them; place the docstrings
immediately above the def lines for _validate_input_tensor and
_get_input_embeddings and ensure they mention expected tensor shapes/dtypes, the
emb parameter usage, and the conditions when ValueError is raised.
- Around line 316-321: _validate_input_tensor currently only checks rank and
last dimension; enhance it to also verify the batch size and device to fail
early. Specifically, inside _validate_input_tensor check that tensor.shape[0]
matches emb.shape[0] (raise ValueError mentioning tensor_name and expected batch
size from emb), verify tensor.device == emb.device (raise ValueError if
mismatched), and then return tensor converted with both dtype and device (e.g.,
tensor.to(dtype=emb.dtype, device=emb.device)); keep existing checks for dim and
last dim and retain the include_flag_name-based null check.
---
Nitpick comments:
In `@monai/apps/generation/maisi/networks/controlnet_maisi.py`:
- Around line 271-297: The forward method in controlnet_maisi.py is missing
documentation for the new modality_tensor parameter; update the forward
docstring for ControlNetMaisi.forward (or the class's forward) to include a line
describing modality_tensor as "Tensor representing modality of shape (N, 1)"
(matching DiffusionModelUNetMaisi.forward) and mention its expected dtype and
optional nature (Tensor | None) and role in conditioning; place this description
alongside the existing parameter docs for x, timesteps, controlnet_cond,
context, and class_labels.
- Around line 253-269: Add Google-style docstrings to the three new helper
methods _create_embedding_module, _validate_input_tensor, and
_get_input_embeddings: for each method document all arguments (with
types/meaning), the return value, and any exceptions raised (e.g., ValueError in
_validate_input_tensor) using the Args/Returns/Raises sections; ensure
_validate_input_tensor explains tensor, tensor_name, include_flag_name,
expected_last_dim, emb, the cast to emb.dtype, and the ValueError conditions,
_create_embedding_module documents input_dim and embed_dim and the returned
nn.Sequential module, and _get_input_embeddings documents modality behavior,
when include_modality_input is used, what it returns and that it concatenates
modality embeddings to emb.
- Around line 161-171: The code inconsistently applies zero_module to the inner
conv vs. the wrapper: after creating Convolution(...) assigned to
controlnet_block you call zero_module(controlnet_block.conv) before appending to
self.controlnet_down_blocks; change this to a consistent style (either always
pass the wrapper controlnet_block or always pass controlnet_block.conv) across
this file to match the pattern used in monai/networks/nets/controlnet.py — e.g.,
replace zero_module(controlnet_block.conv) with zero_module(controlnet_block)
(or vice versa) for Convolution instances to standardize readability while
keeping behavior identical.
In `@tests/apps/maisi/networks/test_controlnet_maisi.py`:
- Around line 182-203: Add a test case to verify that
ControlNetMaisi._validate_input_tensor rejects modality tensors with invalid
shape: update or add a test in test_modality_input_missing (or a new test
function) that calls net.forward with modality_tensor=torch.ones((1, 2)) and
asserts a ValueError is raised (matching the existing message or appropriate
one); reference the ControlNetMaisi class and its _validate_input_tensor method
so the test covers the branch where modality_tensor has shape (N, 2) instead of
(N, 1).
In `@tests/apps/maisi/networks/test_diffusion_model_unet_maisi.py`:
- Around line 511-542: Extend the unit tests to assert the shape-validation
branch by adding negative tests that pass malformed auxiliary tensors to
DiffusionModelUNetMaisi: in addition to test_modality_input_missing, add a test
that constructs the same net (include_modality_input=True) and calls net.forward
with a malformed modality_tensor of shape (1,2) (e.g., torch.rand((1,2))) inside
eval_mode(net) and assert it raises ValueError with the "modality_tensor should
be provided" (or the specific shape-related message); likewise, extend or add a
test next to test_additional_input_missing that builds the net with
include_spacing_input=True and calls net.forward with a malformed spacing_tensor
of shape (1,2) and asserts a ValueError with the "spacing_tensor should be
provided" (or specific shape message). Ensure you use the same call pattern
(net.forward(input, timestep, modality_tensor=..., spacing_tensor=...)) and the
same skipUnless(has_einops) decorator so these new tests exercise the
shape-validation branch in DiffusionModelUNetMaisi.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 908991df-c3c5-449c-a0c6-036423c3b0be
📒 Files selected for processing (4)
monai/apps/generation/maisi/networks/controlnet_maisi.py
monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py
tests/apps/maisi/networks/test_controlnet_maisi.py
tests/apps/maisi/networks/test_diffusion_model_unet_maisi.py

    def _validate_input_tensor(self, tensor, tensor_name, include_flag_name, expected_last_dim, emb):
        if tensor is None:
            raise ValueError(f"{tensor_name} should be provided when {include_flag_name} is True.")
        if tensor.dim() != 2 or tensor.shape[1] != expected_last_dim:
            raise ValueError(f"{tensor_name} should have shape (N, {expected_last_dim}), got {tuple(tensor.shape)}.")
        return tensor.to(dtype=emb.dtype)

🛠️ Refactor suggestion | 🟠 Major | ⚡ Quick win
Add Google-style docstrings for new helper definitions.
_validate_input_tensor and _get_input_embeddings are new/modified definitions but have no docstrings describing args/returns/raised exceptions.
As per coding guidelines, “Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings.”
Also applies to: 340-361
Source: Coding guidelines
Strengthen tensor validation to prevent runtime crashes.
Line 319 only validates rank/last-dim. Add batch-size and device checks to fail early with clear errors; otherwise torch.cat/Linear can crash later with opaque messages.
Proposed fix

     def _validate_input_tensor(self, tensor, tensor_name, include_flag_name, expected_last_dim, emb):
         if tensor is None:
             raise ValueError(f"{tensor_name} should be provided when {include_flag_name} is True.")
         if tensor.dim() != 2 or tensor.shape[1] != expected_last_dim:
             raise ValueError(f"{tensor_name} should have shape (N, {expected_last_dim}), got {tuple(tensor.shape)}.")
    -    return tensor.to(dtype=emb.dtype)
    +    if tensor.shape[0] != emb.shape[0]:
    +        raise ValueError(f"{tensor_name} should have batch size {emb.shape[0]}, got {tensor.shape[0]}.")
    +    return tensor.to(device=emb.device, dtype=emb.dtype)

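To see what the stricter check buys, here is a free-function version of the proposed fix with a short usage example. The name `validate_input_tensor_strict` and the float16 reference embedding are assumptions for illustration; the body follows the suggested diff, which moves the tensor to the embedding's device and does not raise on a device mismatch.

```python
import torch


def validate_input_tensor_strict(tensor, tensor_name, include_flag_name, expected_last_dim, emb):
    """Free-function sketch of the proposed fix (illustration only)."""
    if tensor is None:
        raise ValueError(f"{tensor_name} should be provided when {include_flag_name} is True.")
    if tensor.dim() != 2 or tensor.shape[1] != expected_last_dim:
        raise ValueError(f"{tensor_name} should have shape (N, {expected_last_dim}), got {tuple(tensor.shape)}.")
    # New check: fail early when the batch size differs from the embedding's.
    if tensor.shape[0] != emb.shape[0]:
        raise ValueError(f"{tensor_name} should have batch size {emb.shape[0]}, got {tensor.shape[0]}.")
    # Align both device and dtype with the reference embedding.
    return tensor.to(device=emb.device, dtype=emb.dtype)


emb = torch.zeros((2, 8), dtype=torch.float16)

# Matching batch size: the tensor is returned, cast to the embedding's dtype.
ok = validate_input_tensor_strict(torch.ones((2, 1)), "modality_tensor", "include_modality_input", 1, emb)

# Mismatched batch size: rejected here and not later inside torch.cat.
try:
    validate_input_tensor_strict(torch.ones((3, 1)), "modality_tensor", "include_modality_input", 1, emb)
except ValueError as err:
    message = str(err)
```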
Signed-off-by: ugbotueferhire <ugbotueferhire@gmail.com>
This PR is intended as a core-framework step toward #8170: it adds optional modality conditioning hooks to MAISI so future CT/MRI/multimodal training configs can pass a scalar modality code without changing current CT-only behavior. I kept the scope intentionally limited to MONAI core APIs and tests, so this does not add pretrained MRI weights, datasets, or model-zoo/tutorial updates. Happy to adjust the scope if you’d prefer this PR to include documentation or config examples as well.
Fixes #8170.
Description
This PR adds optional modality conditioning hooks to MAISI networks. `DiffusionModelUNetMaisi` and `ControlNetMaisi` now support `include_modality_input` with a scalar `modality_tensor` of shape `(N, 1)`, enabling CT/MRI/multimodal training configurations while preserving current default behavior.
The change also adds validation for missing or incorrectly shaped modality and auxiliary conditioning tensors, plus targeted MAISI network tests covering 2D/3D shape behavior and missing-input errors.
Types of changes
- `./runtests.sh -f -u --net --coverage`.
- `./runtests.sh --quick --unittests --disttests`.
- `make html` command in the `docs/` folder.