diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 23a57ae9fb..3f4c8cd4a4 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -26,13 +26,15 @@ from monai.config import DtypeLike from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor from monai.data.meta_obj import get_track_meta +from monai.data.meta_tensor import MetaTensor from monai.data.ultrasound_confidence_map import UltrasoundConfidenceMap from monai.data.utils import get_random_patch, get_valid_patch_size from monai.networks.layers import GaussianFilter, HilbertTransform, MedianFilter, SavitzkyGolayFilter +from monai.transforms.inverse import InvertibleTransform from monai.transforms.transform import RandomizableTransform, Transform from monai.transforms.utils import Fourier, equalize_hist, is_positive, rescale_array, soft_clip from monai.transforms.utils_pytorch_numpy_unification import clip, percentile, where -from monai.utils.enums import TransformBackends +from monai.utils.enums import TraceKeys, TransformBackends from monai.utils.misc import ensure_tuple, ensure_tuple_rep, ensure_tuple_size, fall_back_tuple from monai.utils.module import min_version, optional_import from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_tensor, get_equivalent_dtype @@ -836,7 +838,7 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen return out -class NormalizeIntensity(Transform): +class NormalizeIntensity(InvertibleTransform): """ Normalize input based on the `subtrahend` and `divisor`: `(img - subtrahend) / divisor`. Use calculated mean or std value of the input image if no `subtrahend` or `divisor` provided. @@ -846,6 +848,11 @@ class NormalizeIntensity(Transform): be the number of image channels if they are not None. If the input is not of floating point type, it will be converted to float32 + The subtrahend and divisor actually used (whether provided or computed) are stored in the + transform's meta information, so the transform is invertible via :meth:`inverse`, recovering + ``img * divisor + subtrahend``. Inversion is not supported when ``nonzero=True``, because the + zero-voxel mask would be required to reverse the operation exactly. + Args: subtrahend: the amount to subtract by (usually the mean). divisor: the amount to divide by (usually the standard deviation). @@ -885,14 +892,14 @@ def _std(x): x = torch.std(x.float(), unbiased=False) return x.item() if x.numel() == 1 else x - def _normalize(self, img: NdarrayOrTensor, sub=None, div=None) -> NdarrayOrTensor: + def _normalize(self, img: NdarrayOrTensor, sub=None, div=None): img, *_ = convert_data_type(img, dtype=torch.float32) if self.nonzero: slices = img != 0 masked_img = img[slices] if not slices.any(): - return img + return img, None, None else: slices = None masked_img = img @@ -917,7 +924,8 @@ def _normalize(self, img: NdarrayOrTensor, sub=None, div=None) -> NdarrayOrTenso img[slices] = (masked_img - _sub) / _div else: img = (img - _sub) / _div - return img + # Return the subtrahend/divisor actually used so the transform can be inverted. + return img, _sub, _div def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ @@ -926,6 +934,9 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: img_t: torch.Tensor = convert_to_tensor(img, track_meta=get_track_meta()) # type: ignore[assignment] dtype = self.dtype or img.dtype img_len = len(img_t) + # Subtrahend/divisor used per channel (channel_wise) or once (global), kept for inverse(). + subs: list = [] + divs: list = [] if self.channel_wise: if self.subtrahend is not None and len(self.subtrahend) != img_len: raise ValueError(f"img has {img_len} channels, but subtrahend has {len(self.subtrahend)} components.") @@ -936,15 +947,66 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: img_t, *_ = convert_data_type(img_t, dtype=torch.float32) for i, d in enumerate(img_t): - img_t[i] = self._normalize( # type: ignore + img_t[i], _sub, _div = self._normalize( # type: ignore d, sub=self.subtrahend[i] if self.subtrahend is not None else None, div=self.divisor[i] if self.divisor is not None else None, ) + subs.append(_sub) + divs.append(_div) else: - img_t = self._normalize(img_t, self.subtrahend, self.divisor) # type: ignore[assignment] + img_t, _sub, _div = self._normalize(img_t, self.subtrahend, self.divisor) # type: ignore[assignment] + subs.append(_sub) + divs.append(_div) out = convert_to_dst_type(img_t, img_t, dtype=dtype)[0] + out = self._push_transform_with_stats(out, subs, divs) + return out + + def _to_storable(self, value): + """Convert a subtrahend/divisor to something storable in transform meta.""" + if isinstance(value, torch.Tensor): + return value.detach().cpu() + if isinstance(value, np.ndarray): + return torch.as_tensor(value) + return value # python/numpy scalar + + def _push_transform_with_stats(self, out, subs: list, divs: list): + if not isinstance(out, MetaTensor) or not get_track_meta(): + return out + extra_info = { + "sub": [self._to_storable(s) for s in subs], + "div": [self._to_storable(d) for d in divs], + "channel_wise": self.channel_wise, + "nonzero": self.nonzero, + } + self.push_transform(out, extra_info=extra_info) + return out + + def inverse(self, data: torch.Tensor) -> torch.Tensor: + transform = self.pop_transform(data) + info = transform[TraceKeys.EXTRA_INFO] + if info["nonzero"]: + raise NotImplementedError( + "NormalizeIntensity.inverse is not supported when nonzero=True, because the " + "zero-voxel mask is needed to reverse the normalization exactly." + ) + subs, divs = info["sub"], info["div"] + out: torch.Tensor = convert_to_tensor(data, track_meta=get_track_meta()) # type: ignore[assignment] + + def _restore(x, sub, div): + sub, *_ = convert_to_dst_type(sub, x) + div, *_ = convert_to_dst_type(div, x) + return x * div + sub + + if info["channel_wise"]: + for i in range(len(out)): + if subs[i] is None or divs[i] is None: # all-zero channel skipped on the forward pass + continue + out[i] = _restore(out[i], subs[i], divs[i]) + else: + if subs[0] is not None and divs[0] is not None: + out = _restore(out, subs[0], divs[0]) return out diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 0c25d4ac99..4fa5f0e2da 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -60,6 +60,7 @@ StdShiftIntensity, ThresholdIntensity, ) +from monai.transforms.inverse import InvertibleTransform from monai.transforms.transform import MapTransform, RandomizableTransform from monai.transforms.utils import is_positive from monai.utils import convert_to_tensor, ensure_tuple, ensure_tuple_rep @@ -791,11 +792,12 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N return d -class NormalizeIntensityd(MapTransform): +class NormalizeIntensityd(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.NormalizeIntensity`. This transform can normalize only non-zero values or entire image, and can also calculate - mean and std on each channel separately. + mean and std on each channel separately. It is invertible via :meth:`inverse` (except when + ``nonzero=True``); see :py:class:`monai.transforms.NormalizeIntensity`. Args: keys: keys of the corresponding items to be transformed. @@ -830,6 +832,12 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N d[key] = self.normalizer(d[key]) return d + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.normalizer.inverse(d[key]) + return d + class ThresholdIntensityd(MapTransform): """ diff --git a/tests/transforms/test_normalize_intensity.py b/tests/transforms/test_normalize_intensity.py index c58bc587f2..ee0295a1e2 100644 --- a/tests/transforms/test_normalize_intensity.py +++ b/tests/transforms/test_normalize_intensity.py @@ -17,6 +17,7 @@ import torch from parameterized import parameterized +from monai.data import MetaTensor, set_track_meta from monai.transforms import NormalizeIntensity from tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose @@ -138,6 +139,30 @@ def test_value_errors(self, im_type): with self.assertRaises(ValueError): normalizer(input_data) + @parameterized.expand( + [ + ["global_computed", {}], + ["channelwise_computed", {"channel_wise": True}], + ["global_explicit", {"subtrahend": 2.0, "divisor": 3.0}], + ["channelwise_explicit", {"subtrahend": [1.0, 2.0, 3.0], "divisor": [2.0, 3.0, 4.0], "channel_wise": True}], + ] + ) + def test_inverse(self, _, args): + set_track_meta(True) + img = MetaTensor(torch.randn(3, 6, 6) * 5 + 2) + normalizer = NormalizeIntensity(**args) + out = normalizer(img.clone()) + inv = normalizer.inverse(out) + assert_allclose(inv, img, type_test=False, rtol=1e-4, atol=1e-4) + + def test_inverse_nonzero_not_implemented(self): + set_track_meta(True) + img = MetaTensor(torch.randn(2, 5, 5)) + normalizer = NormalizeIntensity(nonzero=True) + out = normalizer(img.clone()) + with self.assertRaises(NotImplementedError): + normalizer.inverse(out) + if __name__ == "__main__": unittest.main() diff --git a/tests/transforms/test_normalize_intensityd.py b/tests/transforms/test_normalize_intensityd.py index b8e4c7bca8..e11cce1c62 100644 --- a/tests/transforms/test_normalize_intensityd.py +++ b/tests/transforms/test_normalize_intensityd.py @@ -14,8 +14,10 @@ import unittest import numpy as np +import torch from parameterized import parameterized +from monai.data import MetaTensor, set_track_meta from monai.transforms import NormalizeIntensityd from tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose @@ -76,6 +78,17 @@ def test_channel_wise(self, im_type): expected = np.array([[0.0, -1.0, 0.0, 1.0], [0.0, -1.0, 0.0, 1.0]]) assert_allclose(normalized, im_type(expected), type_test="tensor") + @parameterized.expand([["global", {}], ["channelwise", {"channel_wise": True}]]) + def test_inverse(self, _, args): + set_track_meta(True) + key = "img" + normalizer = NormalizeIntensityd(keys=key, **args) + data = {key: MetaTensor(torch.randn(3, 6, 6) * 4 + 1)} + original = data[key].clone() + out = normalizer(dict(data)) + inv = normalizer.inverse(out) + assert_allclose(inv[key], original, type_test=False, rtol=1e-4, atol=1e-4) + if __name__ == "__main__": unittest.main()