diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index a85eb95c20..515d51f1f2 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -1549,10 +1549,26 @@ def get_data(self, img: NrrdImage | list[NrrdImage]) -> tuple[np.ndarray, dict]: header[MetaKeys.SPATIAL_SHAPE] = header["sizes"].copy() [header.pop(k) for k in ("sizes", "space origin", "space directions")] # rm duplicated data in header - if self.channel_dim is None: # default to "no_channel" or -1 - header[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( - float("nan") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else 0 - ) + if self.channel_dim is None: # default to "no_channel" or 0 + # Use the NRRD 'kinds' field to detect non-spatial (channel) axes. + # Spatial kinds are 'domain' and 'space'; anything else (e.g. 'list', + # 'vector') marks a channel axis. + _SPATIAL_KINDS = {"domain", "space"} + ch_axes = [ + idx + for idx, k in enumerate(header.get("kinds", [])) + if k.lower() not in _SPATIAL_KINDS + ] + if ch_axes: + ch_ax = ch_axes[0] + header[MetaKeys.ORIGINAL_CHANNEL_DIM] = ch_ax + sp_shape = list(header[MetaKeys.SPATIAL_SHAPE]) + sp_shape.pop(ch_ax) + header[MetaKeys.SPATIAL_SHAPE] = np.array(sp_shape) + else: + header[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( + float("nan") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else 0 + ) else: header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim _copy_compatible_dict(header, compatible_meta) @@ -1571,6 +1587,11 @@ def _get_affine(self, header: dict) -> np.ndarray: direction = header["space directions"] origin = header["space origin"] + # pynrrd represents non-spatial axes (e.g. 'list' kind in 4-D NRRD files) as rows + # where every element is NaN. Filter them out so the affine only encodes spatial axes. + valid = ~np.all(np.isnan(direction.astype(float)), axis=1) + direction = direction[valid] + x, y = direction.shape affine_diam = min(x, y) + 1 affine: np.ndarray = np.eye(affine_diam) @@ -1609,4 +1630,6 @@ def _convert_f_to_c_order(self, header: dict) -> dict: header["space directions"] = np.rot90(np.flip(header["space directions"], 0)) header["space origin"] = header["space origin"][::-1] header["sizes"] = header["sizes"][::-1] + if "kinds" in header: + header["kinds"] = header["kinds"][::-1] return header diff --git a/tests/data/test_nrrd_reader.py b/tests/data/test_nrrd_reader.py index 5bf958e970..b2de33f400 100644 --- a/tests/data/test_nrrd_reader.py +++ b/tests/data/test_nrrd_reader.py @@ -44,6 +44,23 @@ "space origin": [1.0, 5.0, 20.0], }, ] +# 4-D NRRD with an explicit 'list' channel axis (kinds: list domain domain domain). +# pynrrd stores the 'none' space direction for the channel axis as a row of NaN values. +TEST_CASE_4D_CHANNEL = [ + (3, 4, 5, 6), # (channel, H, W, D) + "test_4d_channel.nrrd", + np.float32, + { + "dimension": 4, + "space": "left-posterior-superior", + "kinds": ["list", "domain", "domain", "domain"], + "sizes": [3, 4, 5, 6], + "space directions": np.array( + [[np.nan, np.nan, np.nan], [1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0]] + ), + "space origin": np.array([10.0, 20.0, 30.0]), + }, +] @skipUnless(has_nrrd, "nrrd required") @@ -128,6 +145,32 @@ def test_read_with_header_index_order_c(self, data_shape, filename, expected_sha self.assertTupleEqual(image_array.shape, expected_shape[::-1]) self.assertTupleEqual(image_array.shape, tuple(image_header["spatial_shape"])) + @parameterized.expand([TEST_CASE_4D_CHANNEL]) + def test_read_4d_channel(self, data_shape, filename, dtype, reference_header): + """4-D NRRD with a 'list' channel axis must not crash in _get_affine and must + set ORIGINAL_CHANNEL_DIM / spatial_shape correctly.""" + test_image = np.random.rand(*data_shape).astype(dtype) + with tempfile.TemporaryDirectory() as tempdir: + filepath = os.path.join(tempdir, filename) + nrrd.write(filepath, test_image, header=reference_header) + reader = NrrdReader() + image_array, image_header = reader.get_data(reader.read(filepath)) + self.assertIsInstance(image_array, np.ndarray) + self.assertEqual(image_array.dtype, dtype) + self.assertTupleEqual(image_array.shape, data_shape) + # spatial_shape must exclude the channel axis + self.assertTupleEqual(tuple(image_header["spatial_shape"]), data_shape[1:]) + # channel dim 0 must be identified + self.assertEqual(image_header["original_channel_dim"], 0) + # affine must be a valid 4×4 matrix (3 spatial dims → 4×4) + self.assertTupleEqual(image_header["affine"].shape, (4, 4)) + np.testing.assert_allclose( + image_header["affine"], + np.array( + [[-1.0, 0.0, 0.0, -10.0], [0.0, -2.0, 0.0, -20.0], [0.0, 0.0, 3.0, 30.0], [0.0, 0.0, 0.0, 1.0]] + ), + ) + if __name__ == "__main__": unittest.main()