Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 69 additions & 7 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@
from monai.config import DtypeLike
from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor
from monai.data.meta_obj import get_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.data.ultrasound_confidence_map import UltrasoundConfidenceMap
from monai.data.utils import get_random_patch, get_valid_patch_size
from monai.networks.layers import GaussianFilter, HilbertTransform, MedianFilter, SavitzkyGolayFilter
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.transform import RandomizableTransform, Transform
from monai.transforms.utils import Fourier, equalize_hist, is_positive, rescale_array, soft_clip
from monai.transforms.utils_pytorch_numpy_unification import clip, percentile, where
from monai.utils.enums import TransformBackends
from monai.utils.enums import TraceKeys, TransformBackends
from monai.utils.misc import ensure_tuple, ensure_tuple_rep, ensure_tuple_size, fall_back_tuple
from monai.utils.module import min_version, optional_import
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_tensor, get_equivalent_dtype
Expand Down Expand Up @@ -836,7 +838,7 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen
return out


class NormalizeIntensity(Transform):
class NormalizeIntensity(InvertibleTransform):
"""
Normalize input based on the `subtrahend` and `divisor`: `(img - subtrahend) / divisor`.
Use calculated mean or std value of the input image if no `subtrahend` or `divisor` provided.
Expand All @@ -846,6 +848,11 @@ class NormalizeIntensity(Transform):
be the number of image channels if they are not None.
If the input is not of floating point type, it will be converted to float32

The subtrahend and divisor actually used (whether provided or computed) are stored in the
transform's meta information, so the transform is invertible via :meth:`inverse`, recovering
``img * divisor + subtrahend``. Inversion is not supported when ``nonzero=True``, because the
zero-voxel mask would be required to reverse the operation exactly.

Args:
subtrahend: the amount to subtract by (usually the mean).
divisor: the amount to divide by (usually the standard deviation).
Expand Down Expand Up @@ -885,14 +892,14 @@ def _std(x):
x = torch.std(x.float(), unbiased=False)
return x.item() if x.numel() == 1 else x

def _normalize(self, img: NdarrayOrTensor, sub=None, div=None) -> NdarrayOrTensor:
def _normalize(self, img: NdarrayOrTensor, sub=None, div=None):
img, *_ = convert_data_type(img, dtype=torch.float32)

if self.nonzero:
slices = img != 0
masked_img = img[slices]
if not slices.any():
return img
return img, None, None
else:
Comment on lines +902 to 903

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Avoid storing None in invertible extra_info.

For nonzero=True + all-zero mask, _normalize() returns None stats, then _push_transform_with_stats() serializes them into extra_info. That can break transform-history collation in batched pipelines.

Suggested fix
-            if not slices.any():
-                return img, None, None
+            if not slices.any():
+                # keep metadata collate-safe by storing identity normalization params
+                return img, 0.0, 1.0

Also applies to: 977-980

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/transforms/intensity/array.py` around lines 902 - 903, The code is
storing None stats into invertible extra_info (via _push_transform_with_stats)
when _normalize() returns None for nonzero=True with all-zero masks; update the
logic so that when _normalize(...) returns None you do not add any stats entries
to extra_info (i.e., skip serializing or setting keys) — modify callers (e.g.,
where _normalize is invoked in the normalization transform) or update
_push_transform_with_stats to check for None and return early/omit adding the
stats entry so extra_info never contains None values.

slices = None
masked_img = img
Expand All @@ -917,7 +924,8 @@ def _normalize(self, img: NdarrayOrTensor, sub=None, div=None) -> NdarrayOrTenso
img[slices] = (masked_img - _sub) / _div
else:
img = (img - _sub) / _div
return img
# Return the subtrahend/divisor actually used so the transform can be inverted.
return img, _sub, _div

def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Expand All @@ -926,6 +934,9 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
img_t: torch.Tensor = convert_to_tensor(img, track_meta=get_track_meta()) # type: ignore[assignment]
dtype = self.dtype or img.dtype
img_len = len(img_t)
# Subtrahend/divisor used per channel (channel_wise) or once (global), kept for inverse().
subs: list = []
divs: list = []
if self.channel_wise:
if self.subtrahend is not None and len(self.subtrahend) != img_len:
raise ValueError(f"img has {img_len} channels, but subtrahend has {len(self.subtrahend)} components.")
Expand All @@ -936,15 +947,66 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
img_t, *_ = convert_data_type(img_t, dtype=torch.float32)

for i, d in enumerate(img_t):
img_t[i] = self._normalize( # type: ignore
img_t[i], _sub, _div = self._normalize( # type: ignore
d,
sub=self.subtrahend[i] if self.subtrahend is not None else None,
div=self.divisor[i] if self.divisor is not None else None,
)
subs.append(_sub)
divs.append(_div)
else:
img_t = self._normalize(img_t, self.subtrahend, self.divisor) # type: ignore[assignment]
img_t, _sub, _div = self._normalize(img_t, self.subtrahend, self.divisor) # type: ignore[assignment]
subs.append(_sub)
divs.append(_div)

out = convert_to_dst_type(img_t, img_t, dtype=dtype)[0]
out = self._push_transform_with_stats(out, subs, divs)
return out

def _to_storable(self, value):
"""Convert a subtrahend/divisor to something storable in transform meta."""
if isinstance(value, torch.Tensor):
return value.detach().cpu()
if isinstance(value, np.ndarray):
return torch.as_tensor(value)
return value # python/numpy scalar

def _push_transform_with_stats(self, out, subs: list, divs: list):
if not isinstance(out, MetaTensor) or not get_track_meta():
return out
extra_info = {
"sub": [self._to_storable(s) for s in subs],
"div": [self._to_storable(d) for d in divs],
"channel_wise": self.channel_wise,
"nonzero": self.nonzero,
}
self.push_transform(out, extra_info=extra_info)
return out

def inverse(self, data: torch.Tensor) -> torch.Tensor:
transform = self.pop_transform(data)
info = transform[TraceKeys.EXTRA_INFO]
if info["nonzero"]:
raise NotImplementedError(
"NormalizeIntensity.inverse is not supported when nonzero=True, because the "
"zero-voxel mask is needed to reverse the normalization exactly."
)
subs, divs = info["sub"], info["div"]
out: torch.Tensor = convert_to_tensor(data, track_meta=get_track_meta()) # type: ignore[assignment]

def _restore(x, sub, div):
sub, *_ = convert_to_dst_type(sub, x)
div, *_ = convert_to_dst_type(div, x)
return x * div + sub

if info["channel_wise"]:
for i in range(len(out)):
if subs[i] is None or divs[i] is None: # all-zero channel skipped on the forward pass
continue
out[i] = _restore(out[i], subs[i], divs[i])
else:
if subs[0] is not None and divs[0] is not None:
out = _restore(out, subs[0], divs[0])
return out


Expand Down
12 changes: 10 additions & 2 deletions monai/transforms/intensity/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
StdShiftIntensity,
ThresholdIntensity,
)
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.transform import MapTransform, RandomizableTransform
from monai.transforms.utils import is_positive
from monai.utils import convert_to_tensor, ensure_tuple, ensure_tuple_rep
Expand Down Expand Up @@ -791,11 +792,12 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
return d


class NormalizeIntensityd(MapTransform):
class NormalizeIntensityd(MapTransform, InvertibleTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.NormalizeIntensity`.
This transform can normalize only non-zero values or entire image, and can also calculate
mean and std on each channel separately.
mean and std on each channel separately. It is invertible via :meth:`inverse` (except when
``nonzero=True``); see :py:class:`monai.transforms.NormalizeIntensity`.

Args:
keys: keys of the corresponding items to be transformed.
Expand Down Expand Up @@ -830,6 +832,12 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
d[key] = self.normalizer(d[key])
return d

def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.normalizer.inverse(d[key])
return d


class ThresholdIntensityd(MapTransform):
"""
Expand Down
25 changes: 25 additions & 0 deletions tests/transforms/test_normalize_intensity.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import torch
from parameterized import parameterized

from monai.data import MetaTensor, set_track_meta
from monai.transforms import NormalizeIntensity
from tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose

Expand Down Expand Up @@ -138,6 +139,30 @@ def test_value_errors(self, im_type):
with self.assertRaises(ValueError):
normalizer(input_data)

@parameterized.expand(
[
["global_computed", {}],
["channelwise_computed", {"channel_wise": True}],
["global_explicit", {"subtrahend": 2.0, "divisor": 3.0}],
["channelwise_explicit", {"subtrahend": [1.0, 2.0, 3.0], "divisor": [2.0, 3.0, 4.0], "channel_wise": True}],
]
)
def test_inverse(self, _, args):
set_track_meta(True)
img = MetaTensor(torch.randn(3, 6, 6) * 5 + 2)
normalizer = NormalizeIntensity(**args)
out = normalizer(img.clone())
inv = normalizer.inverse(out)
assert_allclose(inv, img, type_test=False, rtol=1e-4, atol=1e-4)

def test_inverse_nonzero_not_implemented(self):
set_track_meta(True)
img = MetaTensor(torch.randn(2, 5, 5))
normalizer = NormalizeIntensity(nonzero=True)
out = normalizer(img.clone())
with self.assertRaises(NotImplementedError):
normalizer.inverse(out)
Comment on lines +151 to +164

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Restore track_meta global state after each test.

These tests set global metadata tracking to True and never restore prior state, which can make later tests order-dependent.

Suggested fix
-from monai.data import MetaTensor, set_track_meta
+from monai.data import MetaTensor, get_track_meta, set_track_meta
...
     def test_inverse(self, _, args):
-        set_track_meta(True)
+        prev = get_track_meta()
+        self.addCleanup(set_track_meta, prev)
+        set_track_meta(True)
...
     def test_inverse_nonzero_not_implemented(self):
-        set_track_meta(True)
+        prev = get_track_meta()
+        self.addCleanup(set_track_meta, prev)
+        set_track_meta(True)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
set_track_meta(True)
img = MetaTensor(torch.randn(3, 6, 6) * 5 + 2)
normalizer = NormalizeIntensity(**args)
out = normalizer(img.clone())
inv = normalizer.inverse(out)
assert_allclose(inv, img, type_test=False, rtol=1e-4, atol=1e-4)
def test_inverse_nonzero_not_implemented(self):
set_track_meta(True)
img = MetaTensor(torch.randn(2, 5, 5))
normalizer = NormalizeIntensity(nonzero=True)
out = normalizer(img.clone())
with self.assertRaises(NotImplementedError):
normalizer.inverse(out)
prev = get_track_meta()
self.addCleanup(set_track_meta, prev)
set_track_meta(True)
img = MetaTensor(torch.randn(3, 6, 6) * 5 + 2)
normalizer = NormalizeIntensity(**args)
out = normalizer(img.clone())
inv = normalizer.inverse(out)
assert_allclose(inv, img, type_test=False, rtol=1e-4, atol=1e-4)
def test_inverse_nonzero_not_implemented(self):
prev = get_track_meta()
self.addCleanup(set_track_meta, prev)
set_track_meta(True)
img = MetaTensor(torch.randn(2, 5, 5))
normalizer = NormalizeIntensity(nonzero=True)
out = normalizer(img.clone())
with self.assertRaises(NotImplementedError):
normalizer.inverse(out)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/transforms/test_normalize_intensity.py` around lines 151 - 164, These
tests call set_track_meta(True) but never restore previous global state; wrap
the body of each test (the ones creating MetaTensor and using
NormalizeIntensity/out/inv) by capturing the prior state with get_track_meta()
(or equivalent getter), call set_track_meta(True), then run the test assertions
in a try/finally and restore the original state with set_track_meta(prev) in the
finally block so global metadata tracking is always returned to its prior value
after test execution.



if __name__ == "__main__":
unittest.main()
13 changes: 13 additions & 0 deletions tests/transforms/test_normalize_intensityd.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.data import MetaTensor, set_track_meta
from monai.transforms import NormalizeIntensityd
from tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose

Expand Down Expand Up @@ -76,6 +78,17 @@ def test_channel_wise(self, im_type):
expected = np.array([[0.0, -1.0, 0.0, 1.0], [0.0, -1.0, 0.0, 1.0]])
assert_allclose(normalized, im_type(expected), type_test="tensor")

@parameterized.expand([["global", {}], ["channelwise", {"channel_wise": True}]])
def test_inverse(self, _, args):
set_track_meta(True)
key = "img"
normalizer = NormalizeIntensityd(keys=key, **args)
data = {key: MetaTensor(torch.randn(3, 6, 6) * 4 + 1)}
original = data[key].clone()
out = normalizer(dict(data))
inv = normalizer.inverse(out)
assert_allclose(inv[key], original, type_test=False, rtol=1e-4, atol=1e-4)


if __name__ == "__main__":
unittest.main()
Loading