diff --git a/monai/apps/generation/maisi/networks/controlnet_maisi.py b/monai/apps/generation/maisi/networks/controlnet_maisi.py index 7c13fd7bc6..122b2d3837 100644 --- a/monai/apps/generation/maisi/networks/controlnet_maisi.py +++ b/monai/apps/generation/maisi/networks/controlnet_maisi.py @@ -14,9 +14,12 @@ from collections.abc import Sequence import torch +from torch import nn -from monai.networks.nets.controlnet import ControlNet -from monai.networks.nets.diffusion_model_unet import get_timestep_embedding +from monai.networks.blocks import Convolution +from monai.networks.nets.controlnet import ControlNet, ControlNetConditioningEmbedding, zero_module +from monai.networks.nets.diffusion_model_unet import get_down_block, get_mid_block, get_timestep_embedding +from monai.utils import ensure_tuple_rep class ControlNetMaisi(ControlNet): @@ -46,6 +49,7 @@ class ControlNetMaisi(ControlNet): include_fc: whether to include the final linear layer. Default to False. use_combined_linear: whether to use a single linear layer for qkv projection, default to False. use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + include_modality_input: if True, use modality input. """ def __init__( @@ -70,29 +74,199 @@ def __init__( include_fc: bool = False, use_combined_linear: bool = False, use_flash_attention: bool = False, + include_modality_input: bool = False, ) -> None: - super().__init__( - spatial_dims, - in_channels, - num_res_blocks, - num_channels, - attention_levels, - norm_num_groups, - norm_eps, - resblock_updown, - num_head_channels, - with_conditioning, - transformer_num_layers, - cross_attention_dim, - num_class_embeds, - upcast_attention, - conditioning_embedding_in_channels, - conditioning_embedding_num_channels, - include_fc, - use_combined_linear, - use_flash_attention, - ) + nn.Module.__init__(self) + if with_conditioning is True and cross_attention_dim is None: + raise ValueError( + "ControlNet expects dimension of the cross-attention conditioning (cross_attention_dim) " + "to be specified when with_conditioning=True." + ) + if cross_attention_dim is not None and with_conditioning is False: + raise ValueError("ControlNet expects with_conditioning=True when specifying the cross_attention_dim.") + + if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels): + raise ValueError( + f"ControlNet expects all channels to be a multiple of norm_num_groups, but got" + f" channels={num_channels} and norm_num_groups={norm_num_groups}" + ) + + if len(num_channels) != len(attention_levels): + raise ValueError( + f"ControlNet expects channels to have the same length as attention_levels, but got " + f"channels={num_channels} and attention_levels={attention_levels}" + ) + + if isinstance(num_head_channels, int): + num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels)) + + if len(num_head_channels) != len(attention_levels): + raise ValueError( + f"num_head_channels should have the same length as attention_levels, but got channels={num_channels} " + f"and attention_levels={attention_levels} . For the i levels without attention," + " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored." + ) + + if isinstance(num_res_blocks, int): + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels)) + + if len(num_res_blocks) != len(num_channels): + raise ValueError( + f"`num_res_blocks` should be a single integer or a tuple of integers with the same length as " + f"`num_channels`, but got num_res_blocks={num_res_blocks} and channels={num_channels}." + ) + + self.in_channels = in_channels + self.block_out_channels = num_channels + self.num_res_blocks = num_res_blocks + self.attention_levels = attention_levels + self.num_head_channels = num_head_channels + self.with_conditioning = with_conditioning self.use_checkpointing = use_checkpointing + self.include_modality_input = include_modality_input + + self.conv_in = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=num_channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + time_embed_dim = num_channels[0] * 4 + self.time_embed = self._create_embedding_module(num_channels[0], time_embed_dim) + + self.num_class_embeds = num_class_embeds + if num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + + new_time_embed_dim = time_embed_dim + if self.include_modality_input: + self.modality_layer = self._create_embedding_module(1, time_embed_dim) + new_time_embed_dim += time_embed_dim + + self.controlnet_cond_embedding = ControlNetConditioningEmbedding( + spatial_dims=spatial_dims, + in_channels=conditioning_embedding_in_channels, + channels=conditioning_embedding_num_channels, + out_channels=num_channels[0], + ) + + self.down_blocks = nn.ModuleList([]) + self.controlnet_down_blocks = nn.ModuleList([]) + output_channel = num_channels[0] + + controlnet_block = Convolution( + spatial_dims=spatial_dims, + in_channels=output_channel, + out_channels=output_channel, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + controlnet_block = zero_module(controlnet_block.conv) + self.controlnet_down_blocks.append(controlnet_block) + + for i in range(len(num_channels)): + input_channel = output_channel + output_channel = num_channels[i] + is_final_block = i == len(num_channels) - 1 + + down_block = get_down_block( + spatial_dims=spatial_dims, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=new_time_embed_dim, + num_res_blocks=num_res_blocks[i], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=not is_final_block, + resblock_updown=resblock_updown, + with_attn=(attention_levels[i] and not with_conditioning), + with_cross_attn=(attention_levels[i] and with_conditioning), + num_head_channels=num_head_channels[i], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, + ) + self.down_blocks.append(down_block) + + for _ in range(num_res_blocks[i]): + controlnet_block = Convolution( + spatial_dims=spatial_dims, + in_channels=output_channel, + out_channels=output_channel, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + if not is_final_block: + controlnet_block = Convolution( + spatial_dims=spatial_dims, + in_channels=output_channel, + out_channels=output_channel, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + mid_block_channel = num_channels[-1] + self.middle_block = get_mid_block( + spatial_dims=spatial_dims, + in_channels=mid_block_channel, + temb_channels=new_time_embed_dim, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + with_conditioning=with_conditioning, + num_head_channels=num_head_channels[-1], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, + ) + + controlnet_block = Convolution( + spatial_dims=spatial_dims, + in_channels=output_channel, + out_channels=output_channel, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + self.controlnet_mid_block = zero_module(controlnet_block) + + def _create_embedding_module(self, input_dim, embed_dim): + model = nn.Sequential(nn.Linear(input_dim, embed_dim), nn.SiLU(), nn.Linear(embed_dim, embed_dim)) + return model + + def _validate_input_tensor(self, tensor, tensor_name, include_flag_name, expected_last_dim, emb): + if tensor is None: + raise ValueError(f"{tensor_name} should be provided when {include_flag_name} is True.") + if tensor.dim() != 2 or tensor.shape[1] != expected_last_dim: + raise ValueError(f"{tensor_name} should have shape (N, {expected_last_dim}), got {tuple(tensor.shape)}.") + return tensor.to(dtype=emb.dtype) + + def _get_input_embeddings(self, emb, modality): + if self.include_modality_input: + modality = self._validate_input_tensor(modality, "modality_tensor", "include_modality_input", 1, emb) + _emb = self.modality_layer(modality) + emb = torch.cat((emb, _emb), dim=1) + return emb def forward( self, @@ -102,8 +276,9 @@ def forward( conditioning_scale: float = 1.0, context: torch.Tensor | None = None, class_labels: torch.Tensor | None = None, + modality_tensor: torch.Tensor | None = None, ) -> tuple[list[torch.Tensor], torch.Tensor]: - emb = self._prepare_time_and_class_embedding(x, timesteps, class_labels) + emb = self._prepare_time_and_class_embedding(x, timesteps, class_labels, modality_tensor) h = self._apply_initial_convolution(x) if self.use_checkpointing: controlnet_cond = torch.utils.checkpoint.checkpoint( @@ -121,7 +296,7 @@ def forward( return down_block_res_samples, mid_block_res_sample - def _prepare_time_and_class_embedding(self, x, timesteps, class_labels): + def _prepare_time_and_class_embedding(self, x, timesteps, class_labels, modality_tensor): # 1. time t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) @@ -139,6 +314,7 @@ def _prepare_time_and_class_embedding(self, x, timesteps, class_labels): class_emb = class_emb.to(dtype=x.dtype) emb = emb + class_emb + emb = self._get_input_embeddings(emb, modality_tensor) return emb def _apply_initial_convolution(self, x): diff --git a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py index 4eac17b870..85c1c4799b 100644 --- a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +++ b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py @@ -79,6 +79,7 @@ class DiffusionModelUNetMaisi(nn.Module): include_top_region_index_input: If True, use top region index input. include_bottom_region_index_input: If True, use bottom region index input. include_spacing_input: If True, use spacing input. + include_modality_input: If True, use modality input. """ def __init__( @@ -105,6 +106,7 @@ def __init__( include_top_region_index_input: bool = False, include_bottom_region_index_input: bool = False, include_spacing_input: bool = False, + include_modality_input: bool = False, ) -> None: super().__init__() if with_conditioning is True and cross_attention_dim is None: @@ -186,6 +188,7 @@ def __init__( self.include_top_region_index_input = include_top_region_index_input self.include_bottom_region_index_input = include_bottom_region_index_input self.include_spacing_input = include_spacing_input + self.include_modality_input = include_modality_input new_time_embed_dim = time_embed_dim if self.include_top_region_index_input: @@ -197,6 +200,9 @@ def __init__( if self.include_spacing_input: self.spacing_layer = self._create_embedding_module(3, time_embed_dim) new_time_embed_dim += time_embed_dim + if self.include_modality_input: + self.modality_layer = self._create_embedding_module(1, time_embed_dim) + new_time_embed_dim += time_embed_dim # down self.down_blocks = nn.ModuleList([]) @@ -307,6 +313,13 @@ def _create_embedding_module(self, input_dim, embed_dim): model = nn.Sequential(nn.Linear(input_dim, embed_dim), nn.SiLU(), nn.Linear(embed_dim, embed_dim)) return model + def _validate_input_tensor(self, tensor, tensor_name, include_flag_name, expected_last_dim, emb): + if tensor is None: + raise ValueError(f"{tensor_name} should be provided when {include_flag_name} is True.") + if tensor.dim() != 2 or tensor.shape[1] != expected_last_dim: + raise ValueError(f"{tensor_name} should have shape (N, {expected_last_dim}), got {tuple(tensor.shape)}.") + return tensor.to(dtype=emb.dtype) + def _get_time_and_class_embedding(self, x, timesteps, class_labels): t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) @@ -324,16 +337,27 @@ def _get_time_and_class_embedding(self, x, timesteps, class_labels): emb += class_emb return emb - def _get_input_embeddings(self, emb, top_index, bottom_index, spacing): + def _get_input_embeddings(self, emb, top_index, bottom_index, spacing, modality): if self.include_top_region_index_input: + top_index = self._validate_input_tensor( + top_index, "top_region_index_tensor", "include_top_region_index_input", 4, emb + ) _emb = self.top_region_index_layer(top_index) emb = torch.cat((emb, _emb), dim=1) if self.include_bottom_region_index_input: + bottom_index = self._validate_input_tensor( + bottom_index, "bottom_region_index_tensor", "include_bottom_region_index_input", 4, emb + ) _emb = self.bottom_region_index_layer(bottom_index) emb = torch.cat((emb, _emb), dim=1) if self.include_spacing_input: + spacing = self._validate_input_tensor(spacing, "spacing_tensor", "include_spacing_input", 3, emb) _emb = self.spacing_layer(spacing) emb = torch.cat((emb, _emb), dim=1) + if self.include_modality_input: + modality = self._validate_input_tensor(modality, "modality_tensor", "include_modality_input", 1, emb) + _emb = self.modality_layer(modality) + emb = torch.cat((emb, _emb), dim=1) return emb def _apply_down_blocks(self, h, emb, context, down_block_additional_residuals): @@ -376,6 +400,7 @@ def forward( top_region_index_tensor: torch.Tensor | None = None, bottom_region_index_tensor: torch.Tensor | None = None, spacing_tensor: torch.Tensor | None = None, + modality_tensor: torch.Tensor | None = None, ) -> torch.Tensor: """ Forward pass through the UNet model. @@ -390,13 +415,16 @@ def forward( top_region_index_tensor: Tensor representing top region index of shape (N, 4). bottom_region_index_tensor: Tensor representing bottom region index of shape (N, 4). spacing_tensor: Tensor representing spacing of shape (N, 3). + modality_tensor: Tensor representing modality of shape (N, 1). Returns: A tensor representing the output of the UNet model. """ emb = self._get_time_and_class_embedding(x, timesteps, class_labels) - emb = self._get_input_embeddings(emb, top_region_index_tensor, bottom_region_index_tensor, spacing_tensor) + emb = self._get_input_embeddings( + emb, top_region_index_tensor, bottom_region_index_tensor, spacing_tensor, modality_tensor + ) h = self.conv_in(x) h, _updated_down_block_res_samples = self._apply_down_blocks(h, emb, context, down_block_additional_residuals) h = self.middle_block(h, emb, context) diff --git a/tests/apps/maisi/networks/test_controlnet_maisi.py b/tests/apps/maisi/networks/test_controlnet_maisi.py index 5868d1e308..2d5d8a2729 100644 --- a/tests/apps/maisi/networks/test_controlnet_maisi.py +++ b/tests/apps/maisi/networks/test_controlnet_maisi.py @@ -141,6 +141,22 @@ def test_shape_unconditioned_models(self, input_param, expected_num_down_blocks_ self.assertEqual(len(result[0]), expected_num_down_blocks_residuals) self.assertEqual(result[1].shape, expected_shape) + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_shape_with_modality_input(self, input_param, expected_num_down_blocks_residuals, expected_shape): + input_param = dict(input_param) + input_param["include_modality_input"] = True + net = ControlNetMaisi(**input_param) + with eval_mode(net): + x = torch.rand((1, 1, 16, 16)) if input_param["spatial_dims"] == 2 else torch.rand((1, 1, 16, 16, 16)) + timesteps = torch.randint(0, 1000, (1,)).long() + controlnet_cond = ( + torch.rand((1, 1, 32, 32)) if input_param["spatial_dims"] == 2 else torch.rand((1, 1, 32, 32, 32)) + ) + result = net.forward(x, timesteps, controlnet_cond, modality_tensor=torch.ones((1, 1))) + self.assertEqual(len(result[0]), expected_num_down_blocks_residuals) + self.assertEqual(result[1].shape, expected_shape) + @parameterized.expand(TEST_CASES_CONDITIONAL) @skipUnless(has_einops, "Requires einops") def test_shape_conditioned_models(self, input_param, expected_num_down_blocks_residuals, expected_shape): @@ -163,6 +179,25 @@ def test_error_input(self, input_param, expected_error): runtime_error = context.exception self.assertEqual(str(runtime_error), expected_error) + @skipUnless(has_einops, "Requires einops") + def test_modality_input_missing(self): + net = ControlNetMaisi( + spatial_dims=2, + in_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + num_head_channels=8, + norm_num_groups=8, + conditioning_embedding_in_channels=1, + conditioning_embedding_num_channels=(8, 8), + use_checkpointing=False, + include_modality_input=True, + ) + with self.assertRaisesRegex(ValueError, "modality_tensor should be provided"): + with eval_mode(net): + net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 1, 32, 32))) + if __name__ == "__main__": unittest.main() diff --git a/tests/apps/maisi/networks/test_diffusion_model_unet_maisi.py b/tests/apps/maisi/networks/test_diffusion_model_unet_maisi.py index f9384e6d82..3c2c3d626f 100644 --- a/tests/apps/maisi/networks/test_diffusion_model_unet_maisi.py +++ b/tests/apps/maisi/networks/test_diffusion_model_unet_maisi.py @@ -491,9 +491,11 @@ def test_conditioned_2d_models_shape(self, input_param): @parameterized.expand(UNCOND_CASES_2D) @skipUnless(has_einops, "Requires einops") def test_shape_with_additional_inputs(self, input_param): + input_param = dict(input_param) input_param["include_top_region_index_input"] = True input_param["include_bottom_region_index_input"] = True input_param["include_spacing_input"] = True + input_param["include_modality_input"] = True net = DiffusionModelUNetMaisi(**input_param) with eval_mode(net): result = net.forward( @@ -502,9 +504,42 @@ def test_shape_with_additional_inputs(self, input_param): top_region_index_tensor=torch.rand((1, 4)), bottom_region_index_tensor=torch.rand((1, 4)), spacing_tensor=torch.rand((1, 3)), + modality_tensor=torch.ones((1, 1)), ) self.assertEqual(result.shape, (1, 1, 16, 16)) + @skipUnless(has_einops, "Requires einops") + def test_modality_input_missing(self): + net = DiffusionModelUNetMaisi( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + include_modality_input=True, + ) + with self.assertRaisesRegex(ValueError, "modality_tensor should be provided"): + with eval_mode(net): + net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long()) + + @skipUnless(has_einops, "Requires einops") + def test_additional_input_missing(self): + net = DiffusionModelUNetMaisi( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + include_spacing_input=True, + ) + with self.assertRaisesRegex(ValueError, "spacing_tensor should be provided"): + with eval_mode(net): + net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long()) + class TestDiffusionModelUNetMaisi3D(unittest.TestCase): @@ -569,9 +604,11 @@ def test_right_dropout(self, input_param): @parameterized.expand(UNCOND_CASES_3D) @skipUnless(has_einops, "Requires einops") def test_shape_with_additional_inputs(self, input_param): + input_param = dict(input_param) input_param["include_top_region_index_input"] = True input_param["include_bottom_region_index_input"] = True input_param["include_spacing_input"] = True + input_param["include_modality_input"] = True net = DiffusionModelUNetMaisi(**input_param) with eval_mode(net): result = net.forward( @@ -580,6 +617,7 @@ def test_shape_with_additional_inputs(self, input_param): top_region_index_tensor=torch.rand((1, 4)), bottom_region_index_tensor=torch.rand((1, 4)), spacing_tensor=torch.rand((1, 3)), + modality_tensor=torch.ones((1, 1)), ) self.assertEqual(result.shape, (1, 1, 16, 16, 16))