Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
226 changes: 201 additions & 25 deletions monai/apps/generation/maisi/networks/controlnet_maisi.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@
from collections.abc import Sequence

import torch
from torch import nn

from monai.networks.nets.controlnet import ControlNet
from monai.networks.nets.diffusion_model_unet import get_timestep_embedding
from monai.networks.blocks import Convolution
from monai.networks.nets.controlnet import ControlNet, ControlNetConditioningEmbedding, zero_module
from monai.networks.nets.diffusion_model_unet import get_down_block, get_mid_block, get_timestep_embedding
from monai.utils import ensure_tuple_rep


class ControlNetMaisi(ControlNet):
Expand Down Expand Up @@ -46,6 +49,7 @@ class ControlNetMaisi(ControlNet):
include_fc: whether to include the final linear layer. Default to False.
use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
include_modality_input: if True, use modality input.
"""

def __init__(
Expand All @@ -70,29 +74,199 @@ def __init__(
include_fc: bool = False,
use_combined_linear: bool = False,
use_flash_attention: bool = False,
include_modality_input: bool = False,
) -> None:
super().__init__(
spatial_dims,
in_channels,
num_res_blocks,
num_channels,
attention_levels,
norm_num_groups,
norm_eps,
resblock_updown,
num_head_channels,
with_conditioning,
transformer_num_layers,
cross_attention_dim,
num_class_embeds,
upcast_attention,
conditioning_embedding_in_channels,
conditioning_embedding_num_channels,
include_fc,
use_combined_linear,
use_flash_attention,
)
nn.Module.__init__(self)
if with_conditioning is True and cross_attention_dim is None:
raise ValueError(
"ControlNet expects dimension of the cross-attention conditioning (cross_attention_dim) "
"to be specified when with_conditioning=True."
)
if cross_attention_dim is not None and with_conditioning is False:
raise ValueError("ControlNet expects with_conditioning=True when specifying the cross_attention_dim.")

if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels):
raise ValueError(
f"ControlNet expects all channels to be a multiple of norm_num_groups, but got"
f" channels={num_channels} and norm_num_groups={norm_num_groups}"
)

if len(num_channels) != len(attention_levels):
raise ValueError(
f"ControlNet expects channels to have the same length as attention_levels, but got "
f"channels={num_channels} and attention_levels={attention_levels}"
)

if isinstance(num_head_channels, int):
num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels))

if len(num_head_channels) != len(attention_levels):
raise ValueError(
f"num_head_channels should have the same length as attention_levels, but got channels={num_channels} "
f"and attention_levels={attention_levels} . For the i levels without attention,"
" i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored."
)

if isinstance(num_res_blocks, int):
num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels))

if len(num_res_blocks) != len(num_channels):
raise ValueError(
f"`num_res_blocks` should be a single integer or a tuple of integers with the same length as "
f"`num_channels`, but got num_res_blocks={num_res_blocks} and channels={num_channels}."
)

self.in_channels = in_channels
self.block_out_channels = num_channels
self.num_res_blocks = num_res_blocks
self.attention_levels = attention_levels
self.num_head_channels = num_head_channels
self.with_conditioning = with_conditioning
self.use_checkpointing = use_checkpointing
self.include_modality_input = include_modality_input

self.conv_in = Convolution(
spatial_dims=spatial_dims,
in_channels=in_channels,
out_channels=num_channels[0],
strides=1,
kernel_size=3,
padding=1,
conv_only=True,
)

time_embed_dim = num_channels[0] * 4
self.time_embed = self._create_embedding_module(num_channels[0], time_embed_dim)

self.num_class_embeds = num_class_embeds
if num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)

new_time_embed_dim = time_embed_dim
if self.include_modality_input:
self.modality_layer = self._create_embedding_module(1, time_embed_dim)
new_time_embed_dim += time_embed_dim

self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
spatial_dims=spatial_dims,
in_channels=conditioning_embedding_in_channels,
channels=conditioning_embedding_num_channels,
out_channels=num_channels[0],
)

self.down_blocks = nn.ModuleList([])
self.controlnet_down_blocks = nn.ModuleList([])
output_channel = num_channels[0]

controlnet_block = Convolution(
spatial_dims=spatial_dims,
in_channels=output_channel,
out_channels=output_channel,
strides=1,
kernel_size=1,
padding=0,
conv_only=True,
)
controlnet_block = zero_module(controlnet_block.conv)
self.controlnet_down_blocks.append(controlnet_block)

for i in range(len(num_channels)):
input_channel = output_channel
output_channel = num_channels[i]
is_final_block = i == len(num_channels) - 1

down_block = get_down_block(
spatial_dims=spatial_dims,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=new_time_embed_dim,
num_res_blocks=num_res_blocks[i],
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
add_downsample=not is_final_block,
resblock_updown=resblock_updown,
with_attn=(attention_levels[i] and not with_conditioning),
with_cross_attn=(attention_levels[i] and with_conditioning),
num_head_channels=num_head_channels[i],
transformer_num_layers=transformer_num_layers,
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
include_fc=include_fc,
use_combined_linear=use_combined_linear,
use_flash_attention=use_flash_attention,
)
self.down_blocks.append(down_block)

for _ in range(num_res_blocks[i]):
controlnet_block = Convolution(
spatial_dims=spatial_dims,
in_channels=output_channel,
out_channels=output_channel,
strides=1,
kernel_size=1,
padding=0,
conv_only=True,
)
controlnet_block = zero_module(controlnet_block)
self.controlnet_down_blocks.append(controlnet_block)
if not is_final_block:
controlnet_block = Convolution(
spatial_dims=spatial_dims,
in_channels=output_channel,
out_channels=output_channel,
strides=1,
kernel_size=1,
padding=0,
conv_only=True,
)
controlnet_block = zero_module(controlnet_block)
self.controlnet_down_blocks.append(controlnet_block)

mid_block_channel = num_channels[-1]
self.middle_block = get_mid_block(
spatial_dims=spatial_dims,
in_channels=mid_block_channel,
temb_channels=new_time_embed_dim,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
with_conditioning=with_conditioning,
num_head_channels=num_head_channels[-1],
transformer_num_layers=transformer_num_layers,
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
include_fc=include_fc,
use_combined_linear=use_combined_linear,
use_flash_attention=use_flash_attention,
)

controlnet_block = Convolution(
spatial_dims=spatial_dims,
in_channels=output_channel,
out_channels=output_channel,
strides=1,
kernel_size=1,
padding=0,
conv_only=True,
)
self.controlnet_mid_block = zero_module(controlnet_block)

def _create_embedding_module(self, input_dim, embed_dim):
model = nn.Sequential(nn.Linear(input_dim, embed_dim), nn.SiLU(), nn.Linear(embed_dim, embed_dim))
return model

def _validate_input_tensor(self, tensor, tensor_name, include_flag_name, expected_last_dim, emb):
if tensor is None:
raise ValueError(f"{tensor_name} should be provided when {include_flag_name} is True.")
if tensor.dim() != 2 or tensor.shape[1] != expected_last_dim:
raise ValueError(f"{tensor_name} should have shape (N, {expected_last_dim}), got {tuple(tensor.shape)}.")
return tensor.to(dtype=emb.dtype)

def _get_input_embeddings(self, emb, modality):
if self.include_modality_input:
modality = self._validate_input_tensor(modality, "modality_tensor", "include_modality_input", 1, emb)
_emb = self.modality_layer(modality)
emb = torch.cat((emb, _emb), dim=1)
return emb

def forward(
self,
Expand All @@ -102,8 +276,9 @@ def forward(
conditioning_scale: float = 1.0,
context: torch.Tensor | None = None,
class_labels: torch.Tensor | None = None,
modality_tensor: torch.Tensor | None = None,
) -> tuple[list[torch.Tensor], torch.Tensor]:
emb = self._prepare_time_and_class_embedding(x, timesteps, class_labels)
emb = self._prepare_time_and_class_embedding(x, timesteps, class_labels, modality_tensor)
h = self._apply_initial_convolution(x)
if self.use_checkpointing:
controlnet_cond = torch.utils.checkpoint.checkpoint(
Expand All @@ -121,7 +296,7 @@ def forward(

return down_block_res_samples, mid_block_res_sample

def _prepare_time_and_class_embedding(self, x, timesteps, class_labels):
def _prepare_time_and_class_embedding(self, x, timesteps, class_labels, modality_tensor):
# 1. time
t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])

Expand All @@ -139,6 +314,7 @@ def _prepare_time_and_class_embedding(self, x, timesteps, class_labels):
class_emb = class_emb.to(dtype=x.dtype)
emb = emb + class_emb

emb = self._get_input_embeddings(emb, modality_tensor)
return emb

def _apply_initial_convolution(self, x):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class DiffusionModelUNetMaisi(nn.Module):
include_top_region_index_input: If True, use top region index input.
include_bottom_region_index_input: If True, use bottom region index input.
include_spacing_input: If True, use spacing input.
include_modality_input: If True, use modality input.
"""

def __init__(
Expand All @@ -105,6 +106,7 @@ def __init__(
include_top_region_index_input: bool = False,
include_bottom_region_index_input: bool = False,
include_spacing_input: bool = False,
include_modality_input: bool = False,
) -> None:
super().__init__()
if with_conditioning is True and cross_attention_dim is None:
Expand Down Expand Up @@ -186,6 +188,7 @@ def __init__(
self.include_top_region_index_input = include_top_region_index_input
self.include_bottom_region_index_input = include_bottom_region_index_input
self.include_spacing_input = include_spacing_input
self.include_modality_input = include_modality_input

new_time_embed_dim = time_embed_dim
if self.include_top_region_index_input:
Expand All @@ -197,6 +200,9 @@ def __init__(
if self.include_spacing_input:
self.spacing_layer = self._create_embedding_module(3, time_embed_dim)
new_time_embed_dim += time_embed_dim
if self.include_modality_input:
self.modality_layer = self._create_embedding_module(1, time_embed_dim)
new_time_embed_dim += time_embed_dim

# down
self.down_blocks = nn.ModuleList([])
Expand Down Expand Up @@ -307,6 +313,13 @@ def _create_embedding_module(self, input_dim, embed_dim):
model = nn.Sequential(nn.Linear(input_dim, embed_dim), nn.SiLU(), nn.Linear(embed_dim, embed_dim))
return model

def _validate_input_tensor(self, tensor, tensor_name, include_flag_name, expected_last_dim, emb):
if tensor is None:
raise ValueError(f"{tensor_name} should be provided when {include_flag_name} is True.")
if tensor.dim() != 2 or tensor.shape[1] != expected_last_dim:
raise ValueError(f"{tensor_name} should have shape (N, {expected_last_dim}), got {tuple(tensor.shape)}.")
return tensor.to(dtype=emb.dtype)
Comment on lines +316 to +321

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major | ⚡ Quick win

Add Google-style docstrings for new helper definitions.

_validate_input_tensor and _get_input_embeddings are new/modified definitions but have no docstrings describing args/returns/raised exceptions.

As per coding guidelines, “Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings.”

Also applies to: 340-361

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py` around
lines 316 - 321, Add Google-style docstrings for the helper functions
_validate_input_tensor and _get_input_embeddings: for each function include an
Args section describing every parameter (types and meaning), a Returns section
describing the return value and type, and a Raises section listing exceptions
(e.g., ValueError) and conditions that trigger them; place each docstring
immediately below its def line, as the first statement in the body of
_validate_input_tensor and _get_input_embeddings, and ensure they mention expected
tensor shapes/dtypes, the emb parameter usage, and the conditions when ValueError is raised.

Source: Coding guidelines


⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Strengthen tensor validation to prevent runtime crashes.

Line 319 only validates rank/last-dim. Add batch-size and device checks to fail early with clear errors; otherwise torch.cat/Linear can crash later with opaque messages.

Proposed fix
 def _validate_input_tensor(self, tensor, tensor_name, include_flag_name, expected_last_dim, emb):
     if tensor is None:
         raise ValueError(f"{tensor_name} should be provided when {include_flag_name} is True.")
     if tensor.dim() != 2 or tensor.shape[1] != expected_last_dim:
         raise ValueError(f"{tensor_name} should have shape (N, {expected_last_dim}), got {tuple(tensor.shape)}.")
-    return tensor.to(dtype=emb.dtype)
+    if tensor.shape[0] != emb.shape[0]:
+        raise ValueError(f"{tensor_name} should have batch size {emb.shape[0]}, got {tensor.shape[0]}.")
+    return tensor.to(device=emb.device, dtype=emb.dtype)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def _validate_input_tensor(self, tensor, tensor_name, include_flag_name, expected_last_dim, emb):
if tensor is None:
raise ValueError(f"{tensor_name} should be provided when {include_flag_name} is True.")
if tensor.dim() != 2 or tensor.shape[1] != expected_last_dim:
raise ValueError(f"{tensor_name} should have shape (N, {expected_last_dim}), got {tuple(tensor.shape)}.")
return tensor.to(dtype=emb.dtype)
def _validate_input_tensor(self, tensor, tensor_name, include_flag_name, expected_last_dim, emb):
if tensor is None:
raise ValueError(f"{tensor_name} should be provided when {include_flag_name} is True.")
if tensor.dim() != 2 or tensor.shape[1] != expected_last_dim:
raise ValueError(f"{tensor_name} should have shape (N, {expected_last_dim}), got {tuple(tensor.shape)}.")
if tensor.shape[0] != emb.shape[0]:
raise ValueError(f"{tensor_name} should have batch size {emb.shape[0]}, got {tensor.shape[0]}.")
return tensor.to(device=emb.device, dtype=emb.dtype)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py` around
lines 316 - 321, _validate_input_tensor currently only checks rank and last
dimension; enhance it to also verify the batch size and device to fail early.
Specifically, inside _validate_input_tensor check that tensor.shape[0] matches
emb.shape[0] (raise ValueError mentioning tensor_name and expected batch size
from emb), and then return the tensor converted to emb's dtype and device (e.g.,
tensor.to(dtype=emb.dtype, device=emb.device)) rather than raising on a device
mismatch, consistent with the proposed fix; keep existing checks for dim and
last dim and retain the include_flag_name-based null check.


def _get_time_and_class_embedding(self, x, timesteps, class_labels):
t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])

Expand All @@ -324,16 +337,27 @@ def _get_time_and_class_embedding(self, x, timesteps, class_labels):
emb += class_emb
return emb

def _get_input_embeddings(self, emb, top_index, bottom_index, spacing):
def _get_input_embeddings(self, emb, top_index, bottom_index, spacing, modality):
if self.include_top_region_index_input:
top_index = self._validate_input_tensor(
top_index, "top_region_index_tensor", "include_top_region_index_input", 4, emb
)
_emb = self.top_region_index_layer(top_index)
emb = torch.cat((emb, _emb), dim=1)
if self.include_bottom_region_index_input:
bottom_index = self._validate_input_tensor(
bottom_index, "bottom_region_index_tensor", "include_bottom_region_index_input", 4, emb
)
_emb = self.bottom_region_index_layer(bottom_index)
emb = torch.cat((emb, _emb), dim=1)
if self.include_spacing_input:
spacing = self._validate_input_tensor(spacing, "spacing_tensor", "include_spacing_input", 3, emb)
_emb = self.spacing_layer(spacing)
emb = torch.cat((emb, _emb), dim=1)
if self.include_modality_input:
modality = self._validate_input_tensor(modality, "modality_tensor", "include_modality_input", 1, emb)
_emb = self.modality_layer(modality)
emb = torch.cat((emb, _emb), dim=1)
return emb

def _apply_down_blocks(self, h, emb, context, down_block_additional_residuals):
Expand Down Expand Up @@ -376,6 +400,7 @@ def forward(
top_region_index_tensor: torch.Tensor | None = None,
bottom_region_index_tensor: torch.Tensor | None = None,
spacing_tensor: torch.Tensor | None = None,
modality_tensor: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Forward pass through the UNet model.
Expand All @@ -390,13 +415,16 @@ def forward(
top_region_index_tensor: Tensor representing top region index of shape (N, 4).
bottom_region_index_tensor: Tensor representing bottom region index of shape (N, 4).
spacing_tensor: Tensor representing spacing of shape (N, 3).
modality_tensor: Tensor representing modality of shape (N, 1).

Returns:
A tensor representing the output of the UNet model.
"""

emb = self._get_time_and_class_embedding(x, timesteps, class_labels)
emb = self._get_input_embeddings(emb, top_region_index_tensor, bottom_region_index_tensor, spacing_tensor)
emb = self._get_input_embeddings(
emb, top_region_index_tensor, bottom_region_index_tensor, spacing_tensor, modality_tensor
)
h = self.conv_in(x)
h, _updated_down_block_res_samples = self._apply_down_blocks(h, emb, context, down_block_additional_residuals)
h = self.middle_block(h, emb, context)
Expand Down
Loading
Loading