diff --git a/monai/metrics/active_learning_metrics.py b/monai/metrics/active_learning_metrics.py index 7a1654191e..5c51d262ed 100644 --- a/monai/metrics/active_learning_metrics.py +++ b/monai/metrics/active_learning_metrics.py @@ -129,9 +129,7 @@ def compute_variance( y_pred = y_pred.float() if not include_background: - y = y_pred - # TODO If this utils is made to be optional for 'y' it would be nice - y_pred, y = ignore_background(y_pred=y_pred, y=y) + y_pred, _ = ignore_background(y_pred=y_pred) # Set any values below 0 to threshold y_pred[y_pred <= 0] = threshold diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index 3921c220db..3070764e06 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -15,7 +15,7 @@ from collections.abc import Iterable, Sequence from functools import cache, partial from types import ModuleType -from typing import Any +from typing import Any, overload import numpy as np import torch @@ -55,7 +55,17 @@ ] -def ignore_background(y_pred: NdarrayTensor, y: NdarrayTensor) -> tuple[NdarrayTensor, NdarrayTensor]: +@overload +def ignore_background(y_pred: NdarrayTensor, y: NdarrayTensor) -> tuple[NdarrayTensor, NdarrayTensor]: ... + + +@overload +def ignore_background(y_pred: NdarrayTensor, y: None = ...) -> tuple[NdarrayTensor, None]: ... + + +def ignore_background( + y_pred: NdarrayTensor, y: NdarrayTensor | None = None +) -> tuple[NdarrayTensor, NdarrayTensor | None]: """ This function is used to remove background (the first channel) for `y_pred` and `y`. @@ -63,11 +73,12 @@ def ignore_background(y_pred: NdarrayTensor, y: NdarrayTensor) -> tuple[NdarrayT y_pred: predictions. As for classification tasks, `y_pred` should has the shape [BN] where N is larger than 1. As for segmentation tasks, the shape should be [BNHW] or [BNHWD]. - y: ground truth, the first dim is batch. + y: optional ground truth, the first dim is batch. """ - y = y[:, 1:] if y.shape[1] > 1 else y # type: ignore[assignment] + if y is not None: + y = y[:, 1:] if y.shape[1] > 1 else y # type: ignore[assignment] y_pred = y_pred[:, 1:] if y_pred.shape[1] > 1 else y_pred # type: ignore[assignment] return y_pred, y