diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py index 745513fec0..2b6212fb94 100644 --- a/monai/losses/unified_focal_loss.py +++ b/monai/losses/unified_focal_loss.py @@ -162,6 +162,8 @@ def __init__( gamma: float = 0.5, delta: float = 0.7, reduction: LossReduction | str = LossReduction.MEAN, + use_softmax: bool = False, + use_sigmoid: bool = False, ): """ Args: @@ -170,8 +172,14 @@ def __init__( weight : weight for each loss function. Defaults to 0.5. gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.5. delta : weight of the background. Defaults to 0.7. - - + reduction: {``"none"``, ``"mean"``, ``"sum"``} + Specifies the reduction to apply to the output. Defaults to ``"mean"``. + use_softmax: if True, use softmax to transform the input logits into probabilities. + Defaults to False. Mutually exclusive with ``use_sigmoid``. + use_sigmoid: if True, use sigmoid to transform the input logits into probabilities. + Defaults to False. Mutually exclusive with ``use_softmax``. + When both ``use_softmax`` and ``use_sigmoid`` are False, the input is assumed + to already be probabilities. Example: >>> import torch @@ -182,22 +190,25 @@ def __init__( >>> fl(pred, grnd) """ super().__init__(reduction=LossReduction(reduction).value) + if use_softmax and use_sigmoid: + raise ValueError("use_softmax and use_sigmoid are mutually exclusive.") self.to_onehot_y = to_onehot_y self.num_classes = num_classes self.gamma = gamma self.delta = delta self.weight: float = weight + self.use_softmax = use_softmax + self.use_sigmoid = use_sigmoid self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta) self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta) - # TODO: Implement this function to support multiple classes segmentation def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: """ Args: y_pred : the shape should be BNH[WD], where N is the number of classes. It only supports binary segmentation. - The input should be the original logits since it will be transformed by - a sigmoid in the forward function. + The input can be raw logits or probabilities depending on ``use_softmax`` + and ``use_sigmoid`` settings. y_true : the shape should be BNH[WD], where N is the number of classes. It only supports binary segmentation. @@ -213,6 +224,13 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: if len(y_pred.shape) != 4 and len(y_pred.shape) != 5: raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}") + # Apply activation BEFORE one_hot encoding, since one_hot uses + # values as scatter indices and raw logits would cause index errors. + if self.use_softmax: + y_pred = torch.softmax(y_pred, dim=1) + elif self.use_sigmoid: + y_pred = torch.sigmoid(y_pred) + if y_pred.shape[1] == 1: y_pred = one_hot(y_pred, num_classes=self.num_classes) y_true = one_hot(y_true, num_classes=self.num_classes) diff --git a/tests/losses/test_unified_focal_loss.py b/tests/losses/test_unified_focal_loss.py index 3b868a560e..3fa7354cf2 100644 --- a/tests/losses/test_unified_focal_loss.py +++ b/tests/losses/test_unified_focal_loss.py @@ -61,6 +61,24 @@ def test_with_cuda(self): print(output) np.testing.assert_allclose(output.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4) + def test_use_sigmoid(self): + loss = AsymmetricUnifiedFocalLoss(use_sigmoid=True) + y_pred = torch.tensor([[[[10.0, -10], [-10, 10.0]]], [[[10.0, -10], [-10, 10.0]]]]) + y_true = torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]) + result = loss(y_pred, y_true) + self.assertTrue(result.item() >= 0) + + def test_use_softmax(self): + loss = AsymmetricUnifiedFocalLoss(use_softmax=True) + y_pred = torch.tensor([[[[10.0, -10], [-10, 10.0]]], [[[10.0, -10], [-10, 10.0]]]]) + y_true = torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]) + result = loss(y_pred, y_true) + self.assertTrue(result.item() >= 0) + + def test_mutually_exclusive(self): + with self.assertRaises(ValueError): + AsymmetricUnifiedFocalLoss(use_softmax=True, use_sigmoid=True) + if __name__ == "__main__": unittest.main()