From ac7efc5c343f72d30f6936f503787ee1af513c60 Mon Sep 17 00:00:00 2001 From: ppraneth Date: Tue, 26 May 2026 19:03:56 +0530 Subject: [PATCH 1/9] Scaffold scalarization Signed-off-by: ppraneth --- src/torchjd/scalarization/__init__.py | 16 ++++++++++ src/torchjd/scalarization/_constant.py | 28 +++++++++++++++++ src/torchjd/scalarization/_mean.py | 12 ++++++++ src/torchjd/scalarization/_random.py | 19 ++++++++++++ src/torchjd/scalarization/_scalarizer_base.py | 30 +++++++++++++++++++ src/torchjd/scalarization/_sum.py | 12 ++++++++ 6 files changed, 117 insertions(+) create mode 100644 src/torchjd/scalarization/__init__.py create mode 100644 src/torchjd/scalarization/_constant.py create mode 100644 src/torchjd/scalarization/_mean.py create mode 100644 src/torchjd/scalarization/_random.py create mode 100644 src/torchjd/scalarization/_scalarizer_base.py create mode 100644 src/torchjd/scalarization/_sum.py diff --git a/src/torchjd/scalarization/__init__.py b/src/torchjd/scalarization/__init__.py new file mode 100644 index 00000000..fdc30c3e --- /dev/null +++ b/src/torchjd/scalarization/__init__.py @@ -0,0 +1,16 @@ +# When a stateful scalarizer is added, move `Stateful` from `torchjd.aggregation._mixins` to +# `torchjd._mixins` so both packages can share it (see issue #666). + +from ._constant import Constant +from ._mean import Mean +from ._random import Random +from ._scalarizer_base import Scalarizer +from ._sum import Sum + +__all__ = [ + "Constant", + "Mean", + "Random", + "Scalarizer", + "Sum", +] diff --git a/src/torchjd/scalarization/_constant.py b/src/torchjd/scalarization/_constant.py new file mode 100644 index 00000000..2a61c832 --- /dev/null +++ b/src/torchjd/scalarization/_constant.py @@ -0,0 +1,28 @@ +from torch import Tensor + +from ._scalarizer_base import Scalarizer + + +class Constant(Scalarizer): + """ + :class:`~torchjd.scalarization.Scalarizer` that combines the input tensor of losses with + constant, pre-determined weights. 
+ + :param weights: The weights to apply to the losses. Must have the same shape as the losses + passed at call time. + """ + + def __init__(self, weights: Tensor) -> None: + super().__init__() + self.weights = weights + + def forward(self, losses: Tensor, /) -> Tensor: + if losses.shape != self.weights.shape: + raise ValueError( + f"Parameter `losses` should have shape {tuple(self.weights.shape)} (matching the " + f"shape of the weights). Found `losses.shape = {tuple(losses.shape)}`.", + ) + return (self.weights * losses).sum() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(weights={repr(self.weights)})" diff --git a/src/torchjd/scalarization/_mean.py b/src/torchjd/scalarization/_mean.py new file mode 100644 index 00000000..4addaed2 --- /dev/null +++ b/src/torchjd/scalarization/_mean.py @@ -0,0 +1,12 @@ +from torch import Tensor + +from ._scalarizer_base import Scalarizer + + +class Mean(Scalarizer): + """ + :class:`~torchjd.scalarization.Scalarizer` that returns the mean of the input tensor of losses. + """ + + def forward(self, losses: Tensor, /) -> Tensor: + return losses.mean() diff --git a/src/torchjd/scalarization/_random.py b/src/torchjd/scalarization/_random.py new file mode 100644 index 00000000..09b1cf5a --- /dev/null +++ b/src/torchjd/scalarization/_random.py @@ -0,0 +1,19 @@ +import torch +from torch import Tensor +from torch.nn import functional as F + +from ._scalarizer_base import Scalarizer + + +class Random(Scalarizer): + """ + :class:`~torchjd.scalarization.Scalarizer` that combines the input tensor of losses with + positive random weights summing to 1, as defined in Algorithm 2 of `Reasonable Effectiveness of + Random Weighting: A Litmus Test for Multi-Task Learning + <https://arxiv.org/pdf/2111.10603.pdf>`_.
+ """ + + def forward(self, losses: Tensor, /) -> Tensor: + flat = torch.randn(losses.numel(), device=losses.device, dtype=losses.dtype) + weights = F.softmax(flat, dim=-1).reshape(losses.shape) + return (weights * losses).sum() diff --git a/src/torchjd/scalarization/_scalarizer_base.py b/src/torchjd/scalarization/_scalarizer_base.py new file mode 100644 index 00000000..9b2decd8 --- /dev/null +++ b/src/torchjd/scalarization/_scalarizer_base.py @@ -0,0 +1,30 @@ +from abc import ABC, abstractmethod + +from torch import Tensor, nn + + +class Scalarizer(nn.Module, ABC): + """ + Abstract base class for all scalarizers. Reduces a tensor of losses of any shape into a single + scalar loss that can be passed to :meth:`~torch.Tensor.backward`. + """ + + def __init__(self) -> None: + super().__init__() + + @abstractmethod + def forward(self, losses: Tensor, /) -> Tensor: ... + + def __call__(self, losses: Tensor, /) -> Tensor: + """ + Computes the scalar loss from the input tensor of losses and applies all registered hooks. + + param losses: The tensor of losses to scalarize. May be of any shape. + """ + return super().__call__(losses) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" + + def __str__(self) -> str: + return f"{self.__class__.__name__}" diff --git a/src/torchjd/scalarization/_sum.py b/src/torchjd/scalarization/_sum.py new file mode 100644 index 00000000..824b8323 --- /dev/null +++ b/src/torchjd/scalarization/_sum.py @@ -0,0 +1,12 @@ +from torch import Tensor + +from ._scalarizer_base import Scalarizer + + +class Sum(Scalarizer): + """ + class:`~torchjd.scalarization.Scalarizer` that returns the sum of the input tensor of losses. 
+ """ + + def forward(self, losses: Tensor, /) -> Tensor: + return losses.sum() From d273b51952d5d82dfe2475bb177363315d287978 Mon Sep 17 00:00:00 2001 From: ppraneth Date: Tue, 26 May 2026 19:53:49 +0530 Subject: [PATCH 2/9] add test cases scalarization Signed-off-by: ppraneth --- tests/unit/scalarization/__init__.py | 0 tests/unit/scalarization/_asserts.py | 27 +++++++++ tests/unit/scalarization/_inputs.py | 10 ++++ tests/unit/scalarization/test_constant.py | 58 +++++++++++++++++++ tests/unit/scalarization/test_mean.py | 39 +++++++++++++ tests/unit/scalarization/test_random.py | 42 ++++++++++++++ .../scalarization/test_scalarizer_base.py | 20 +++++++ tests/unit/scalarization/test_sum.py | 39 +++++++++++++ 8 files changed, 235 insertions(+) create mode 100644 tests/unit/scalarization/__init__.py create mode 100644 tests/unit/scalarization/_asserts.py create mode 100644 tests/unit/scalarization/_inputs.py create mode 100644 tests/unit/scalarization/test_constant.py create mode 100644 tests/unit/scalarization/test_mean.py create mode 100644 tests/unit/scalarization/test_random.py create mode 100644 tests/unit/scalarization/test_scalarizer_base.py create mode 100644 tests/unit/scalarization/test_sum.py diff --git a/tests/unit/scalarization/__init__.py b/tests/unit/scalarization/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/scalarization/_asserts.py b/tests/unit/scalarization/_asserts.py new file mode 100644 index 00000000..8b6d5fc7 --- /dev/null +++ b/tests/unit/scalarization/_asserts.py @@ -0,0 +1,27 @@ +import torch +from torch import Tensor +from utils.tensors import randperm_ + +from torchjd.scalarization import Scalarizer + + +def assert_returns_scalar(scalarizer: Scalarizer, losses: Tensor) -> None: + out = scalarizer(losses) + assert out.dim() == 0 + assert out.isfinite() + + +def assert_grad_flow(scalarizer: Scalarizer, losses: Tensor) -> None: + leaf = losses.detach().requires_grad_() + out = scalarizer(leaf) + 
out.backward() + assert leaf.grad is not None + assert leaf.grad.isfinite().all() + + +def assert_permutation_invariant(scalarizer: Scalarizer, losses: Tensor) -> None: + out = scalarizer(losses) + flat = losses.flatten() + permuted = flat[randperm_(flat.numel())].reshape(losses.shape) + out_permuted = scalarizer(permuted) + torch.testing.assert_close(out, out_permuted) diff --git a/tests/unit/scalarization/_inputs.py b/tests/unit/scalarization/_inputs.py new file mode 100644 index 00000000..8bb61984 --- /dev/null +++ b/tests/unit/scalarization/_inputs.py @@ -0,0 +1,10 @@ +from torch import Tensor +from utils.tensors import randn_, tensor_ + +scalar_input: Tensor = tensor_(7.0) +vector_input: Tensor = randn_(5) +matrix_input: Tensor = randn_(3, 4) +tensor_3d_input: Tensor = randn_(2, 3, 4) + +typical_inputs: list[Tensor] = [vector_input, matrix_input, tensor_3d_input] +all_inputs: list[Tensor] = [scalar_input, *typical_inputs] diff --git a/tests/unit/scalarization/test_constant.py b/tests/unit/scalarization/test_constant.py new file mode 100644 index 00000000..85fd649e --- /dev/null +++ b/tests/unit/scalarization/test_constant.py @@ -0,0 +1,58 @@ +from contextlib import nullcontext as does_not_raise + +import torch +from pytest import mark, raises +from utils.contexts import ExceptionContext +from utils.tensors import ones_, tensor_ + +from torchjd.scalarization import Constant + +from ._asserts import assert_grad_flow, assert_returns_scalar + + +def test_value() -> None: + losses = tensor_([1.0, 2.0, 3.0, 4.0]) + weights = tensor_([0.1, 0.2, 0.3, 0.4]) + torch.testing.assert_close(Constant(weights)(losses), tensor_(3.0)) + + +@mark.parametrize("shape", [(5,), (3, 4), (2, 3, 4)]) +def test_expected_structure(shape: tuple[int, ...]) -> None: + losses = ones_(shape) + weights = ones_(shape) + assert_returns_scalar(Constant(weights), losses) + + +@mark.parametrize("shape", [(5,), (3, 4), (2, 3, 4)]) +def test_grad_flow(shape: tuple[int, ...]) -> None: + losses = 
ones_(shape) + weights = ones_(shape) / losses.numel() + assert_grad_flow(Constant(weights), losses) + + +@mark.parametrize( + ["weights_shape", "losses_shape", "expectation"], + [ + ((5,), (5,), does_not_raise()), + ((3, 4), (3, 4), does_not_raise()), + ((), (), does_not_raise()), + ((5,), (4,), raises(ValueError)), + ((5,), (5, 1), raises(ValueError)), + ((3, 4), (4, 3), raises(ValueError)), + ], +) +def test_shape_check( + weights_shape: tuple[int, ...], + losses_shape: tuple[int, ...], + expectation: ExceptionContext, +) -> None: + weights = ones_(weights_shape) + losses = ones_(losses_shape) + with expectation: + _ = Constant(weights)(losses) + + +def test_representations() -> None: + s = Constant(weights=torch.tensor([1.0, 2.0], device="cpu")) + assert repr(s) == "Constant(weights=tensor([1., 2.]))" + assert str(s) == "Constant" diff --git a/tests/unit/scalarization/test_mean.py b/tests/unit/scalarization/test_mean.py new file mode 100644 index 00000000..c2fec147 --- /dev/null +++ b/tests/unit/scalarization/test_mean.py @@ -0,0 +1,39 @@ +import torch +from pytest import mark +from torch import Tensor +from utils.tensors import tensor_ + +from torchjd.scalarization import Mean + +from ._asserts import ( + assert_grad_flow, + assert_permutation_invariant, + assert_returns_scalar, +) +from ._inputs import all_inputs, typical_inputs + + +def test_value() -> None: + losses = tensor_([1.0, 2.0, 3.0]) + torch.testing.assert_close(Mean()(losses), tensor_(2.0)) + + +@mark.parametrize("losses", all_inputs) +def test_expected_structure(losses: Tensor) -> None: + assert_returns_scalar(Mean(), losses) + + +@mark.parametrize("losses", typical_inputs) +def test_grad_flow(losses: Tensor) -> None: + assert_grad_flow(Mean(), losses) + + +@mark.parametrize("losses", typical_inputs) +def test_permutation_invariant(losses: Tensor) -> None: + assert_permutation_invariant(Mean(), losses) + + +def test_representations() -> None: + s = Mean() + assert repr(s) == "Mean()" + assert 
str(s) == "Mean" diff --git a/tests/unit/scalarization/test_random.py b/tests/unit/scalarization/test_random.py new file mode 100644 index 00000000..76c3b123 --- /dev/null +++ b/tests/unit/scalarization/test_random.py @@ -0,0 +1,42 @@ +import torch +from pytest import mark +from torch import Tensor +from utils.contexts import fork_rng +from utils.tensors import ones_, tensor_ + +from torchjd.scalarization import Random + +from ._asserts import assert_grad_flow, assert_returns_scalar +from ._inputs import typical_inputs + + +@mark.parametrize("losses", typical_inputs) +def test_expected_structure(losses: Tensor) -> None: + assert_returns_scalar(Random(), losses) + + +@mark.parametrize("losses", typical_inputs) +def test_grad_flow(losses: Tensor) -> None: + assert_grad_flow(Random(), losses) + + +def test_deterministic_under_seed() -> None: + losses = tensor_([1.0, 2.0, 3.0, 4.0]) + scalarizer = Random() + with fork_rng(seed=0): + a = scalarizer(losses) + with fork_rng(seed=0): + b = scalarizer(losses) + torch.testing.assert_close(a, b) + + +def test_weights_sum_to_one() -> None: + # If all losses equal c, then sum(weights * losses) == c when weights sum to 1. 
+ losses = ones_((5,)) * 3.0 + torch.testing.assert_close(Random()(losses), tensor_(3.0)) + + +def test_representations() -> None: + s = Random() + assert repr(s) == "Random()" + assert str(s) == "Random" diff --git a/tests/unit/scalarization/test_scalarizer_base.py b/tests/unit/scalarization/test_scalarizer_base.py new file mode 100644 index 00000000..fbf48b76 --- /dev/null +++ b/tests/unit/scalarization/test_scalarizer_base.py @@ -0,0 +1,20 @@ +from pytest import raises +from torch import Tensor + +from torchjd.scalarization import Scalarizer + + +def test_cannot_instantiate_abstract_base() -> None: + with raises(TypeError): + Scalarizer() # type: ignore[abstract] + + +class _Identity(Scalarizer): + def forward(self, losses: Tensor, /) -> Tensor: + return losses.sum() + + +def test_default_representations() -> None: + s = _Identity() + assert repr(s) == "_Identity()" + assert str(s) == "_Identity" diff --git a/tests/unit/scalarization/test_sum.py b/tests/unit/scalarization/test_sum.py new file mode 100644 index 00000000..4706d008 --- /dev/null +++ b/tests/unit/scalarization/test_sum.py @@ -0,0 +1,39 @@ +import torch +from pytest import mark +from torch import Tensor +from utils.tensors import tensor_ + +from torchjd.scalarization import Sum + +from ._asserts import ( + assert_grad_flow, + assert_permutation_invariant, + assert_returns_scalar, +) +from ._inputs import all_inputs, typical_inputs + + +def test_value() -> None: + losses = tensor_([1.0, 2.0, 3.0]) + torch.testing.assert_close(Sum()(losses), tensor_(6.0)) + + +@mark.parametrize("losses", all_inputs) +def test_expected_structure(losses: Tensor) -> None: + assert_returns_scalar(Sum(), losses) + + +@mark.parametrize("losses", typical_inputs) +def test_grad_flow(losses: Tensor) -> None: + assert_grad_flow(Sum(), losses) + + +@mark.parametrize("losses", typical_inputs) +def test_permutation_invariant(losses: Tensor) -> None: + assert_permutation_invariant(Sum(), losses) + + +def test_representations() -> 
None: + s = Sum() + assert repr(s) == "Sum()" + assert str(s) == "Sum" From 8e3068d736c7142d468fc7ad45db1dd206c69491 Mon Sep 17 00:00:00 2001 From: ppraneth Date: Tue, 26 May 2026 20:02:13 +0530 Subject: [PATCH 3/9] minor edit fixes Signed-off-by: ppraneth --- src/torchjd/scalarization/_constant.py | 4 ++-- src/torchjd/scalarization/_mean.py | 2 +- src/torchjd/scalarization/_random.py | 2 +- tests/unit/scalarization/test_scalarizer_base.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/torchjd/scalarization/_constant.py b/src/torchjd/scalarization/_constant.py index 2a61c832..24f61a45 100644 --- a/src/torchjd/scalarization/_constant.py +++ b/src/torchjd/scalarization/_constant.py @@ -5,10 +5,10 @@ class Constant(Scalarizer): """ - :class:`~torchjd.scalarization.Scalarizer` that combines the input tensor of losses with + class:`~torchjd.scalarization.Scalarizer` that combines the input tensor of losses with constant, pre-determined weights. - :param weights: The weights to apply to the losses. Must have the same shape as the losses + param weights: The weights to apply to the losses. Must have the same shape as the losses passed at call time. """ diff --git a/src/torchjd/scalarization/_mean.py b/src/torchjd/scalarization/_mean.py index 4addaed2..ebdd6c96 100644 --- a/src/torchjd/scalarization/_mean.py +++ b/src/torchjd/scalarization/_mean.py @@ -5,7 +5,7 @@ class Mean(Scalarizer): """ - :class:`~torchjd.scalarization.Scalarizer` that returns the mean of the input tensor of losses. + class:`~torchjd.scalarization.Scalarizer` that returns the mean of the input tensor of losses. 
""" def forward(self, losses: Tensor, /) -> Tensor: diff --git a/src/torchjd/scalarization/_random.py b/src/torchjd/scalarization/_random.py index 09b1cf5a..3fb2d87a 100644 --- a/src/torchjd/scalarization/_random.py +++ b/src/torchjd/scalarization/_random.py @@ -7,7 +7,7 @@ class Random(Scalarizer): """ - :class:`~torchjd.scalarization.Scalarizer` that combines the input tensor of losses with + class:`~torchjd.scalarization.Scalarizer` that combines the input tensor of losses with positive random weights summing to 1, as defined in Algorithm 2 of `Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-Task Learning `_. diff --git a/tests/unit/scalarization/test_scalarizer_base.py b/tests/unit/scalarization/test_scalarizer_base.py index fbf48b76..fed5dce1 100644 --- a/tests/unit/scalarization/test_scalarizer_base.py +++ b/tests/unit/scalarization/test_scalarizer_base.py @@ -6,7 +6,7 @@ def test_cannot_instantiate_abstract_base() -> None: with raises(TypeError): - Scalarizer() # type: ignore[abstract] + Scalarizer() class _Identity(Scalarizer): From 632c795b015adde48fa78e8239176c2d4f9e01a7 Mon Sep 17 00:00:00 2001 From: Praneth Paruchuri Date: Wed, 27 May 2026 09:57:27 +0530 Subject: [PATCH 4/9] Update src/torchjd/scalarization/_random.py Co-authored-by: Pierre Quinton --- src/torchjd/scalarization/_random.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/scalarization/_random.py b/src/torchjd/scalarization/_random.py index 3fb2d87a..e6403a41 100644 --- a/src/torchjd/scalarization/_random.py +++ b/src/torchjd/scalarization/_random.py @@ -1,6 +1,6 @@ import torch from torch import Tensor -from torch.nn import functional as F +from torch.nn.functional import softmax from ._scalarizer_base import Scalarizer From cb4f4e5a667c7baebd47cad48df42883f21bca72 Mon Sep 17 00:00:00 2001 From: Praneth Paruchuri Date: Wed, 27 May 2026 09:57:39 +0530 Subject: [PATCH 5/9] Update src/torchjd/scalarization/_random.py Co-authored-by: Pierre 
Quinton --- src/torchjd/scalarization/_random.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/scalarization/_random.py b/src/torchjd/scalarization/_random.py index e6403a41..9ec1c249 100644 --- a/src/torchjd/scalarization/_random.py +++ b/src/torchjd/scalarization/_random.py @@ -15,5 +15,5 @@ class Random(Scalarizer): def forward(self, losses: Tensor, /) -> Tensor: flat = torch.randn(losses.numel(), device=losses.device, dtype=losses.dtype) - weights = F.softmax(flat, dim=-1).reshape(losses.shape) + weights = softmax(flat, dim=-1).reshape(losses.shape) return (weights * losses).sum() From 6921068d779cba37fdf25d9521871296f90b13db Mon Sep 17 00:00:00 2001 From: ppraneth Date: Wed, 27 May 2026 10:09:10 +0530 Subject: [PATCH 6/9] feedback changes Signed-off-by: ppraneth --- src/torchjd/scalarization/__init__.py | 3 --- src/torchjd/scalarization/_scalarizer_base.py | 3 ++- tests/unit/scalarization/test_constant.py | 16 +++++++-------- tests/unit/scalarization/test_random.py | 4 ++-- .../scalarization/test_scalarizer_base.py | 20 ------------------- 5 files changed, 12 insertions(+), 34 deletions(-) delete mode 100644 tests/unit/scalarization/test_scalarizer_base.py diff --git a/src/torchjd/scalarization/__init__.py b/src/torchjd/scalarization/__init__.py index fdc30c3e..b7dcb15d 100644 --- a/src/torchjd/scalarization/__init__.py +++ b/src/torchjd/scalarization/__init__.py @@ -1,6 +1,3 @@ -# When a stateful scalarizer is added, move `Stateful` from `torchjd.aggregation._mixins` to -# `torchjd._mixins` so both packages can share it (see issue #666). 
- from ._constant import Constant from ._mean import Mean from ._random import Random diff --git a/src/torchjd/scalarization/_scalarizer_base.py b/src/torchjd/scalarization/_scalarizer_base.py index 9b2decd8..4d5ca7b1 100644 --- a/src/torchjd/scalarization/_scalarizer_base.py +++ b/src/torchjd/scalarization/_scalarizer_base.py @@ -13,7 +13,8 @@ def __init__(self) -> None: super().__init__() @abstractmethod - def forward(self, losses: Tensor, /) -> Tensor: ... + def forward(self, losses: Tensor, /) -> Tensor: + """Computes the scalarization from input tensor.""" def __call__(self, losses: Tensor, /) -> Tensor: """ diff --git a/tests/unit/scalarization/test_constant.py b/tests/unit/scalarization/test_constant.py index 85fd649e..0ea15415 100644 --- a/tests/unit/scalarization/test_constant.py +++ b/tests/unit/scalarization/test_constant.py @@ -2,12 +2,14 @@ import torch from pytest import mark, raises +from torch import Tensor from utils.contexts import ExceptionContext from utils.tensors import ones_, tensor_ from torchjd.scalarization import Constant from ._asserts import assert_grad_flow, assert_returns_scalar +from ._inputs import all_inputs def test_value() -> None: @@ -16,17 +18,15 @@ def test_value() -> None: torch.testing.assert_close(Constant(weights)(losses), tensor_(3.0)) -@mark.parametrize("shape", [(5,), (3, 4), (2, 3, 4)]) -def test_expected_structure(shape: tuple[int, ...]) -> None: - losses = ones_(shape) - weights = ones_(shape) +@mark.parametrize("losses", all_inputs) +def test_expected_structure(losses: Tensor) -> None: + weights = ones_(losses.shape) assert_returns_scalar(Constant(weights), losses) -@mark.parametrize("shape", [(5,), (3, 4), (2, 3, 4)]) -def test_grad_flow(shape: tuple[int, ...]) -> None: - losses = ones_(shape) - weights = ones_(shape) / losses.numel() +@mark.parametrize("losses", all_inputs) +def test_grad_flow(losses: Tensor) -> None: + weights = ones_(losses.shape) assert_grad_flow(Constant(weights), losses) diff --git 
a/tests/unit/scalarization/test_random.py b/tests/unit/scalarization/test_random.py index 76c3b123..02071874 100644 --- a/tests/unit/scalarization/test_random.py +++ b/tests/unit/scalarization/test_random.py @@ -31,9 +31,9 @@ def test_deterministic_under_seed() -> None: def test_weights_sum_to_one() -> None: - # If all losses equal c, then sum(weights * losses) == c when weights sum to 1. + # If all losses equal 1, then sum(weights * losses) == 1 when weights sum to 1. losses = ones_((5,)) * 3.0 - torch.testing.assert_close(Random()(losses), tensor_(3.0)) + torch.testing.assert_close(Random()(losses), tensor_(1.0)) def test_representations() -> None: diff --git a/tests/unit/scalarization/test_scalarizer_base.py b/tests/unit/scalarization/test_scalarizer_base.py deleted file mode 100644 index fed5dce1..00000000 --- a/tests/unit/scalarization/test_scalarizer_base.py +++ /dev/null @@ -1,20 +0,0 @@ -from pytest import raises -from torch import Tensor - -from torchjd.scalarization import Scalarizer - - -def test_cannot_instantiate_abstract_base() -> None: - with raises(TypeError): - Scalarizer() - - -class _Identity(Scalarizer): - def forward(self, losses: Tensor, /) -> Tensor: - return losses.sum() - - -def test_default_representations() -> None: - s = _Identity() - assert repr(s) == "_Identity()" - assert str(s) == "_Identity" From ec4b13731f2397036be9bd3c01ea8492839d432a Mon Sep 17 00:00:00 2001 From: ppraneth Date: Wed, 27 May 2026 10:14:48 +0530 Subject: [PATCH 7/9] minor fix Signed-off-by: ppraneth --- tests/unit/scalarization/test_random.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/scalarization/test_random.py b/tests/unit/scalarization/test_random.py index 02071874..5f621c39 100644 --- a/tests/unit/scalarization/test_random.py +++ b/tests/unit/scalarization/test_random.py @@ -32,7 +32,7 @@ def test_deterministic_under_seed() -> None: def test_weights_sum_to_one() -> None: # If all losses equal 1, then sum(weights * losses) == 1 
when weights sum to 1. - losses = ones_((5,)) * 3.0 + losses = ones_((5,)) torch.testing.assert_close(Random()(losses), tensor_(1.0)) From 586840cdf7b1fa886bac0809aab0d4a32cdade4f Mon Sep 17 00:00:00 2001 From: ppraneth Date: Wed, 27 May 2026 12:39:33 +0530 Subject: [PATCH 8/9] docs add Signed-off-by: ppraneth --- docs/source/docs/scalarization/constant.rst | 7 +++++++ docs/source/docs/scalarization/index.rst | 21 +++++++++++++++++++++ docs/source/docs/scalarization/mean.rst | 7 +++++++ docs/source/docs/scalarization/random.rst | 7 +++++++ docs/source/docs/scalarization/sum.rst | 7 +++++++ docs/source/index.rst | 5 +++++ 6 files changed, 54 insertions(+) create mode 100644 docs/source/docs/scalarization/constant.rst create mode 100644 docs/source/docs/scalarization/index.rst create mode 100644 docs/source/docs/scalarization/mean.rst create mode 100644 docs/source/docs/scalarization/random.rst create mode 100644 docs/source/docs/scalarization/sum.rst diff --git a/docs/source/docs/scalarization/constant.rst b/docs/source/docs/scalarization/constant.rst new file mode 100644 index 00000000..bcbf0217 --- /dev/null +++ b/docs/source/docs/scalarization/constant.rst @@ -0,0 +1,7 @@ +:hide-toc: + +Constant +======== + +.. autoclass:: torchjd.scalarization.Constant + :members: __call__ diff --git a/docs/source/docs/scalarization/index.rst b/docs/source/docs/scalarization/index.rst new file mode 100644 index 00000000..38ff2bf3 --- /dev/null +++ b/docs/source/docs/scalarization/index.rst @@ -0,0 +1,21 @@ +scalarization +============= + +.. automodule:: torchjd.scalarization + :no-members: + +Abstract base class +------------------- + +.. autoclass:: torchjd.scalarization.Scalarizer + :members: __call__ + + +.. 
toctree:: + :hidden: + :maxdepth: 1 + + mean.rst + sum.rst + constant.rst + random.rst diff --git a/docs/source/docs/scalarization/mean.rst b/docs/source/docs/scalarization/mean.rst new file mode 100644 index 00000000..5a435b98 --- /dev/null +++ b/docs/source/docs/scalarization/mean.rst @@ -0,0 +1,7 @@ +:hide-toc: + +Mean +==== + +.. autoclass:: torchjd.scalarization.Mean + :members: __call__ diff --git a/docs/source/docs/scalarization/random.rst b/docs/source/docs/scalarization/random.rst new file mode 100644 index 00000000..0fffdc0e --- /dev/null +++ b/docs/source/docs/scalarization/random.rst @@ -0,0 +1,7 @@ +:hide-toc: + +Random +====== + +.. autoclass:: torchjd.scalarization.Random + :members: __call__ diff --git a/docs/source/docs/scalarization/sum.rst b/docs/source/docs/scalarization/sum.rst new file mode 100644 index 00000000..8f89702c --- /dev/null +++ b/docs/source/docs/scalarization/sum.rst @@ -0,0 +1,7 @@ +:hide-toc: + +Sum +=== + +.. autoclass:: torchjd.scalarization.Sum + :members: __call__ diff --git a/docs/source/index.rst b/docs/source/index.rst index d8b14f83..20d0b6db 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -31,6 +31,10 @@ Jacobian descent is the aggregator, which maps the Jacobian to an optimization s :doc:`Aggregation <docs/aggregation/index>`, we provide an overview of the various aggregators available in TorchJD, and their corresponding weightings. +For comparison against simple baselines, the :doc:`Scalarization <docs/scalarization/index>` +package provides scalarizers that combine a tensor of losses into a single scalar loss, allowing +standard gradient descent to be used. + A straightforward application of Jacobian descent is multi-task learning, in which the vector of per-task losses has to be minimized. To start using TorchJD for multi-task learning, follow our :doc:`MTL example <examples/mtl>`. @@ -70,4 +74,5 @@ TorchJD is open-source, under MIT License.
The source code is available on docs/autogram/index.rst docs/autojac/index.rst docs/aggregation/index.rst + docs/scalarization/index.rst docs/linalg/index.rst From d274c9326ac76e6660a0bedad331508d14cbacce Mon Sep 17 00:00:00 2001 From: ppraneth Date: Wed, 27 May 2026 12:40:08 +0530 Subject: [PATCH 9/9] docs add Signed-off-by: ppraneth --- src/torchjd/scalarization/__init__.py | 20 +++++++++++++++++++ src/torchjd/scalarization/_constant.py | 4 ++-- src/torchjd/scalarization/_mean.py | 2 +- src/torchjd/scalarization/_random.py | 2 +- src/torchjd/scalarization/_scalarizer_base.py | 2 +- src/torchjd/scalarization/_sum.py | 2 +- 6 files changed, 26 insertions(+), 6 deletions(-) diff --git a/src/torchjd/scalarization/__init__.py b/src/torchjd/scalarization/__init__.py index b7dcb15d..1c51efce 100644 --- a/src/torchjd/scalarization/__init__.py +++ b/src/torchjd/scalarization/__init__.py @@ -1,3 +1,23 @@ +""" +A :class:`~torchjd.scalarization.Scalarizer` reduces a tensor of losses of any shape into a single +scalar loss that can be optimized with standard gradient descent. This is the simple baseline +against which :class:`Aggregators ` are compared: instead of +combining the per-loss gradients via the Jacobian or its Gramian, a +:class:`~torchjd.scalarization.Scalarizer` combines the losses directly, and a standard call to +:meth:`~torch.Tensor.backward` produces the gradient. + +The following example shows how to use :class:`~torchjd.scalarization.Mean` to combine a vector of +losses into a single scalar loss. + +>>> from torch import tensor +>>> from torchjd.scalarization import Mean +>>> +>>> scalarizer = Mean() +>>> losses = tensor([1.0, 2.0, 3.0]) +>>> scalarizer(losses) +tensor(2.) 
+""" + from ._constant import Constant from ._mean import Mean from ._random import Random diff --git a/src/torchjd/scalarization/_constant.py b/src/torchjd/scalarization/_constant.py index 24f61a45..2a61c832 100644 --- a/src/torchjd/scalarization/_constant.py +++ b/src/torchjd/scalarization/_constant.py @@ -5,10 +5,10 @@ class Constant(Scalarizer): """ - class:`~torchjd.scalarization.Scalarizer` that combines the input tensor of losses with + :class:`~torchjd.scalarization.Scalarizer` that combines the input tensor of losses with constant, pre-determined weights. - param weights: The weights to apply to the losses. Must have the same shape as the losses + :param weights: The weights to apply to the losses. Must have the same shape as the losses passed at call time. """ diff --git a/src/torchjd/scalarization/_mean.py b/src/torchjd/scalarization/_mean.py index ebdd6c96..4addaed2 100644 --- a/src/torchjd/scalarization/_mean.py +++ b/src/torchjd/scalarization/_mean.py @@ -5,7 +5,7 @@ class Mean(Scalarizer): """ - class:`~torchjd.scalarization.Scalarizer` that returns the mean of the input tensor of losses. + :class:`~torchjd.scalarization.Scalarizer` that returns the mean of the input tensor of losses. """ def forward(self, losses: Tensor, /) -> Tensor: diff --git a/src/torchjd/scalarization/_random.py b/src/torchjd/scalarization/_random.py index 9ec1c249..65984f38 100644 --- a/src/torchjd/scalarization/_random.py +++ b/src/torchjd/scalarization/_random.py @@ -7,7 +7,7 @@ class Random(Scalarizer): """ - class:`~torchjd.scalarization.Scalarizer` that combines the input tensor of losses with + :class:`~torchjd.scalarization.Scalarizer` that combines the input tensor of losses with positive random weights summing to 1, as defined in Algorithm 2 of `Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-Task Learning `_. 
diff --git a/src/torchjd/scalarization/_scalarizer_base.py b/src/torchjd/scalarization/_scalarizer_base.py index 4d5ca7b1..7c6cc7b5 100644 --- a/src/torchjd/scalarization/_scalarizer_base.py +++ b/src/torchjd/scalarization/_scalarizer_base.py @@ -20,7 +20,7 @@ def __call__(self, losses: Tensor, /) -> Tensor: """ Computes the scalar loss from the input tensor of losses and applies all registered hooks. - param losses: The tensor of losses to scalarize. May be of any shape. + :param losses: The tensor of losses to scalarize. May be of any shape. """ return super().__call__(losses) diff --git a/src/torchjd/scalarization/_sum.py b/src/torchjd/scalarization/_sum.py index 824b8323..e5c76467 100644 --- a/src/torchjd/scalarization/_sum.py +++ b/src/torchjd/scalarization/_sum.py @@ -5,7 +5,7 @@ class Sum(Scalarizer): """ - class:`~torchjd.scalarization.Scalarizer` that returns the sum of the input tensor of losses. + :class:`~torchjd.scalarization.Scalarizer` that returns the sum of the input tensor of losses. """ def forward(self, losses: Tensor, /) -> Tensor: