diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index 02975039b3..47cc94838a 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -122,15 +122,15 @@ class mean median max 5percentile 95percentile notnans # add the average value of all classes to v if class_labels is None: - class_labels = ["class" + str(i) for i in range(v.shape[1])] + labels = ["class" + str(i) for i in range(v.shape[1])] else: - class_labels = [str(i) for i in class_labels] # ensure to have a list of str + labels = [str(i) for i in class_labels] # ensure to have a list of str - class_labels += ["mean"] + labels += ["mean"] v = np.concatenate([v, np.nanmean(v, axis=1, keepdims=True)], axis=1) with open(os.path.join(save_dir, f"{k}_raw.csv"), "w") as f: - f.write(f"filename{deli}{deli.join(class_labels)}\n") + f.write(f"filename{deli}{deli.join(labels)}\n") for i, b in enumerate(v): f.write( f"{images[i] if images is not None else str(i)}{deli}" @@ -164,7 +164,7 @@ def _compute_op(op: str, d: np.ndarray) -> Any: with open(os.path.join(save_dir, f"{k}_summary.csv"), "w") as f: f.write(f"class{deli}{deli.join(ops)}\n") for i, c in enumerate(np.transpose(v)): - f.write(f"{class_labels[i]}{deli}{deli.join([f'{_compute_op(k, c):.4f}' for k in ops])}\n") + f.write(f"{labels[i]}{deli}{deli.join([f'{_compute_op(k, c):.4f}' for k in ops])}\n") def from_engine(keys: KeysCollection, first: bool = False) -> Callable: diff --git a/tests/handlers/test_write_metrics_reports.py b/tests/handlers/test_write_metrics_reports.py index 1013f15d85..07cf46c122 100644 --- a/tests/handlers/test_write_metrics_reports.py +++ b/tests/handlers/test_write_metrics_reports.py @@ -63,6 +63,28 @@ def test_content(self): self.assertTrue(os.path.exists(os.path.join(tempdir, "metric4_raw.csv"))) self.assertTrue(os.path.exists(os.path.join(tempdir, "metric4_summary.csv"))) + def test_multi_metric_details_headers(self): + with tempfile.TemporaryDirectory() as tempdir: + write_metrics_reports( + save_dir=Path(tempdir), + images=["img1", "img2"], + metrics=None, + metric_details={ + "m1": torch.tensor([[1, 2, 3], [4, 5, 6]]), + "m2": torch.tensor([[7, 8], [9, 10]]), + "m3": torch.tensor([[11, 12, 13, 14], [15, 16, 17, 18]]), + }, + summary_ops=None, + deli=",", + output_type="csv", + ) + for name, nclass in [("m1", 3), ("m2", 2), ("m3", 4)]: + path = os.path.join(tempdir, f"{name}_raw.csv") + self.assertTrue(os.path.exists(path)) + with open(path) as f: + header = f.readline().strip().split(",") + self.assertEqual(header, ["filename"] + [f"class{i}" for i in range(nclass)] + ["mean"]) + if __name__ == "__main__": unittest.main()