-
Notifications
You must be signed in to change notification settings - Fork 17
feat(scalarization): Add scalarization package #701
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
ac7efc5
d273b51
8e3068d
632c795
cb4f4e5
6921068
ec4b137
586840c
d274c93
80257e8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| :hide-toc: | ||
|
|
||
| Constant | ||
| ======== | ||
|
|
||
| .. autoclass:: torchjd.scalarization.Constant | ||
| :members: __call__ |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| scalarization | ||
| ============= | ||
|
|
||
| .. automodule:: torchjd.scalarization | ||
| :no-members: | ||
|
|
||
| Abstract base class | ||
| ------------------- | ||
|
|
||
| .. autoclass:: torchjd.scalarization.Scalarizer | ||
| :members: __call__ | ||
|
|
||
|
|
||
| .. toctree:: | ||
| :hidden: | ||
| :maxdepth: 1 | ||
|
|
||
| mean.rst | ||
| sum.rst | ||
| constant.rst | ||
| random.rst | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| :hide-toc: | ||
|
|
||
| Mean | ||
| ==== | ||
|
|
||
| .. autoclass:: torchjd.scalarization.Mean | ||
| :members: __call__ |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| :hide-toc: | ||
|
|
||
| Random | ||
| ====== | ||
|
|
||
| .. autoclass:: torchjd.scalarization.Random | ||
| :members: __call__ |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| :hide-toc: | ||
|
|
||
| Sum | ||
| === | ||
|
|
||
| .. autoclass:: torchjd.scalarization.Sum | ||
| :members: __call__ |
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,33 @@ | ||||||||||||
| """ | ||||||||||||
| A :class:`~torchjd.scalarization.Scalarizer` reduces a tensor of losses of any shape into a single | ||||||||||||
| scalar loss that can be optimized with standard gradient descent. This is the simple baseline | ||||||||||||
| against which :class:`Aggregators <torchjd.aggregation.Aggregator>` are compared: instead of | ||||||||||||
| combining the per-loss gradients via the Jacobian or its Gramian, a | ||||||||||||
| :class:`~torchjd.scalarization.Scalarizer` combines the losses directly, and a standard call to | ||||||||||||
| :meth:`~torch.Tensor.backward` produces the gradient. | ||||||||||||
| The following example shows how to use :class:`~torchjd.scalarization.Mean` to combine a vector of | ||||||||||||
| losses into a single scalar loss. | ||||||||||||
| >>> from torch import tensor | ||||||||||||
| >>> from torchjd.scalarization import Mean | ||||||||||||
| >>> | ||||||||||||
| >>> scalarizer = Mean() | ||||||||||||
| >>> losses = tensor([1.0, 2.0, 3.0]) | ||||||||||||
| >>> scalarizer(losses) | ||||||||||||
| tensor(2.) | ||||||||||||
|
Comment on lines
+17
to
+18
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would even name the output to make this even more explicit.
Suggested change
|
||||||||||||
| """ | ||||||||||||
|
|
||||||||||||
| from ._constant import Constant | ||||||||||||
| from ._mean import Mean | ||||||||||||
| from ._random import Random | ||||||||||||
| from ._scalarizer_base import Scalarizer | ||||||||||||
| from ._sum import Sum | ||||||||||||
|
|
||||||||||||
| __all__ = [ | ||||||||||||
| "Constant", | ||||||||||||
| "Mean", | ||||||||||||
| "Random", | ||||||||||||
| "Scalarizer", | ||||||||||||
| "Sum", | ||||||||||||
| ] | ||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| from torch import Tensor | ||
|
|
||
| from ._scalarizer_base import Scalarizer | ||
|
|
||
|
|
||
| class Constant(Scalarizer): | ||
| """ | ||
| :class:`~torchjd.scalarization.Scalarizer` that combines the input tensor of losses with | ||
| constant, pre-determined weights. | ||
|
|
||
| :param weights: The weights to apply to the losses. Must have the same shape as the losses | ||
| passed at call time. | ||
| """ | ||
|
|
||
| def __init__(self, weights: Tensor) -> None: | ||
| super().__init__() | ||
| self.weights = weights | ||
|
|
||
| def forward(self, losses: Tensor, /) -> Tensor: | ||
| if losses.shape != self.weights.shape: | ||
| raise ValueError( | ||
| f"Parameter `losses` should have shape {tuple(self.weights.shape)} (matching the " | ||
| f"shape of the weights). Found `losses.shape = {tuple(losses.shape)}`.", | ||
| ) | ||
| return (self.weights * losses).sum() | ||
|
|
||
| def __repr__(self) -> str: | ||
| return f"{self.__class__.__name__}(weights={repr(self.weights)})" |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,12 @@ | ||
| from torch import Tensor | ||
|
|
||
| from ._scalarizer_base import Scalarizer | ||
|
|
||
|
|
||
| class Mean(Scalarizer): | ||
| """ | ||
| :class:`~torchjd.scalarization.Scalarizer` that returns the mean of the input tensor of losses. | ||
| """ | ||
|
|
||
| def forward(self, losses: Tensor, /) -> Tensor: | ||
| return losses.mean() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| import torch | ||
| from torch import Tensor | ||
| from torch.nn.functional import softmax | ||
|
|
||
| from ._scalarizer_base import Scalarizer | ||
|
|
||
|
|
||
| class Random(Scalarizer): | ||
| """ | ||
| :class:`~torchjd.scalarization.Scalarizer` that combines the input tensor of losses with | ||
| positive random weights summing to 1, as defined in Algorithm 2 of `Reasonable Effectiveness of | ||
| Random Weighting: A Litmus Test for Multi-Task Learning | ||
| <https://arxiv.org/pdf/2111.10603.pdf>`_. | ||
| """ | ||
|
|
||
| def forward(self, losses: Tensor, /) -> Tensor: | ||
| flat = torch.randn(losses.numel(), device=losses.device, dtype=losses.dtype) | ||
| weights = softmax(flat, dim=-1).reshape(losses.shape) | ||
| return (weights * losses).sum() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,31 @@ | ||
| from abc import ABC, abstractmethod | ||
|
|
||
| from torch import Tensor, nn | ||
|
|
||
|
|
||
| class Scalarizer(nn.Module, ABC): | ||
| """ | ||
| Abstract base class for all scalarizers. Reduces a tensor of losses of any shape into a single | ||
| scalar loss that can be passed to :meth:`~torch.Tensor.backward`. | ||
| """ | ||
|
Comment on lines
+7
to
+10
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @ValerianRey I think I would abstract away from
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree. I think the names Something like:
I would also remove the "that can be passed to :meth: This will make Scalarizer a bit independent of the concept of losses and optimization, to make them strictly about scalarizing tensors. However, when we describe the usage of scalarizers, we should still talk about losses and optimization (e.g. docs/sources/index.rst is fine like this IMO, and same for the |
||
|
|
||
| def __init__(self) -> None: | ||
| super().__init__() | ||
|
|
||
| @abstractmethod | ||
| def forward(self, losses: Tensor, /) -> Tensor: | ||
| """Computes the scalarization from input tensor.""" | ||
|
|
||
| def __call__(self, losses: Tensor, /) -> Tensor: | ||
| """ | ||
| Computes the scalar loss from the input tensor of losses and applies all registered hooks. | ||
|
|
||
| :param losses: The tensor of losses to scalarize. May be of any shape. | ||
| """ | ||
| return super().__call__(losses) | ||
|
|
||
| def __repr__(self) -> str: | ||
| return f"{self.__class__.__name__}()" | ||
|
|
||
| def __str__(self) -> str: | ||
| return f"{self.__class__.__name__}" | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,12 @@ | ||
| from torch import Tensor | ||
|
|
||
| from ._scalarizer_base import Scalarizer | ||
|
|
||
|
|
||
| class Sum(Scalarizer): | ||
| """ | ||
| :class:`~torchjd.scalarization.Scalarizer` that returns the sum of the input tensor of losses. | ||
| """ | ||
|
|
||
| def forward(self, losses: Tensor, /) -> Tensor: | ||
| return losses.sum() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,27 @@ | ||
| import torch | ||
| from torch import Tensor | ||
| from utils.tensors import randperm_ | ||
|
|
||
| from torchjd.scalarization import Scalarizer | ||
|
|
||
|
|
||
| def assert_returns_scalar(scalarizer: Scalarizer, losses: Tensor) -> None: | ||
| out = scalarizer(losses) | ||
| assert out.dim() == 0 | ||
| assert out.isfinite() | ||
|
|
||
|
|
||
| def assert_grad_flow(scalarizer: Scalarizer, losses: Tensor) -> None: | ||
| leaf = losses.detach().requires_grad_() | ||
| out = scalarizer(leaf) | ||
| out.backward() | ||
| assert leaf.grad is not None | ||
| assert leaf.grad.isfinite().all() | ||
|
|
||
|
|
||
| def assert_permutation_invariant(scalarizer: Scalarizer, losses: Tensor) -> None: | ||
| out = scalarizer(losses) | ||
| flat = losses.flatten() | ||
| permuted = flat[randperm_(flat.numel())].reshape(losses.shape) | ||
| out_permuted = scalarizer(permuted) | ||
| torch.testing.assert_close(out, out_permuted) |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,10 @@ | ||||||||||||||||||||||||||||||
| from torch import Tensor | ||||||||||||||||||||||||||||||
| from utils.tensors import randn_, tensor_ | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| scalar_input: Tensor = tensor_(7.0) | ||||||||||||||||||||||||||||||
| vector_input: Tensor = randn_(5) | ||||||||||||||||||||||||||||||
| matrix_input: Tensor = randn_(3, 4) | ||||||||||||||||||||||||||||||
| tensor_3d_input: Tensor = randn_(2, 3, 4) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| typical_inputs: list[Tensor] = [vector_input, matrix_input, tensor_3d_input] | ||||||||||||||||||||||||||||||
| all_inputs: list[Tensor] = [scalar_input, *typical_inputs] | ||||||||||||||||||||||||||||||
|
Comment on lines
+4
to
+10
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Maybe we should even have many shapes instead, what do you think @ValerianRey ?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would rename
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
In a future PR I would say. What you suggested is good enough. |
||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,58 @@ | ||
| from contextlib import nullcontext as does_not_raise | ||
|
|
||
| import torch | ||
| from pytest import mark, raises | ||
| from torch import Tensor | ||
| from utils.contexts import ExceptionContext | ||
| from utils.tensors import ones_, tensor_ | ||
|
|
||
| from torchjd.scalarization import Constant | ||
|
|
||
| from ._asserts import assert_grad_flow, assert_returns_scalar | ||
| from ._inputs import all_inputs | ||
|
|
||
|
|
||
| def test_value() -> None: | ||
| losses = tensor_([1.0, 2.0, 3.0, 4.0]) | ||
| weights = tensor_([0.1, 0.2, 0.3, 0.4]) | ||
| torch.testing.assert_close(Constant(weights)(losses), tensor_(3.0)) | ||
|
|
||
|
|
||
| @mark.parametrize("losses", all_inputs) | ||
| def test_expected_structure(losses: Tensor) -> None: | ||
| weights = ones_(losses.shape) | ||
| assert_returns_scalar(Constant(weights), losses) | ||
|
|
||
|
|
||
| @mark.parametrize("losses", all_inputs) | ||
| def test_grad_flow(losses: Tensor) -> None: | ||
| weights = ones_(losses.shape) | ||
| assert_grad_flow(Constant(weights), losses) | ||
|
|
||
|
|
||
| @mark.parametrize( | ||
| ["weights_shape", "losses_shape", "expectation"], | ||
| [ | ||
| ((5,), (5,), does_not_raise()), | ||
| ((3, 4), (3, 4), does_not_raise()), | ||
| ((), (), does_not_raise()), | ||
| ((5,), (4,), raises(ValueError)), | ||
| ((5,), (5, 1), raises(ValueError)), | ||
| ((3, 4), (4, 3), raises(ValueError)), | ||
| ], | ||
| ) | ||
| def test_shape_check( | ||
| weights_shape: tuple[int, ...], | ||
| losses_shape: tuple[int, ...], | ||
| expectation: ExceptionContext, | ||
| ) -> None: | ||
| weights = ones_(weights_shape) | ||
| losses = ones_(losses_shape) | ||
| with expectation: | ||
| _ = Constant(weights)(losses) | ||
|
|
||
|
|
||
| def test_representations() -> None: | ||
| s = Constant(weights=torch.tensor([1.0, 2.0], device="cpu")) | ||
| assert repr(s) == "Constant(weights=tensor([1., 2.]))" | ||
| assert str(s) == "Constant" | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think for Constant aggregator, the str is: Constant([1., 2.]) or something. It would be nice to have the same thing here. To do that, you should use the Since it won't be specific to aggregation anymore, you should move this function and the helper |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| import torch | ||
| from pytest import mark | ||
| from torch import Tensor | ||
| from utils.tensors import tensor_ | ||
|
|
||
| from torchjd.scalarization import Mean | ||
|
|
||
| from ._asserts import ( | ||
| assert_grad_flow, | ||
| assert_permutation_invariant, | ||
| assert_returns_scalar, | ||
| ) | ||
| from ._inputs import all_inputs, typical_inputs | ||
|
|
||
|
|
||
| def test_value() -> None: | ||
| losses = tensor_([1.0, 2.0, 3.0]) | ||
| torch.testing.assert_close(Mean()(losses), tensor_(2.0)) | ||
|
|
||
|
|
||
| @mark.parametrize("losses", all_inputs) | ||
| def test_expected_structure(losses: Tensor) -> None: | ||
| assert_returns_scalar(Mean(), losses) | ||
|
|
||
|
|
||
| @mark.parametrize("losses", typical_inputs) | ||
| def test_grad_flow(losses: Tensor) -> None: | ||
| assert_grad_flow(Mean(), losses) | ||
|
|
||
|
|
||
| @mark.parametrize("losses", typical_inputs) | ||
| def test_permutation_invariant(losses: Tensor) -> None: | ||
| assert_permutation_invariant(Mean(), losses) | ||
|
|
||
|
|
||
| def test_representations() -> None: | ||
| s = Mean() | ||
| assert repr(s) == "Mean()" | ||
| assert str(s) == "Mean" |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,42 @@ | ||
| import torch | ||
| from pytest import mark | ||
| from torch import Tensor | ||
| from utils.contexts import fork_rng | ||
| from utils.tensors import ones_, tensor_ | ||
|
|
||
| from torchjd.scalarization import Random | ||
|
|
||
| from ._asserts import assert_grad_flow, assert_returns_scalar | ||
| from ._inputs import typical_inputs | ||
|
|
||
|
|
||
| @mark.parametrize("losses", typical_inputs) | ||
| def test_expected_structure(losses: Tensor) -> None: | ||
| assert_returns_scalar(Random(), losses) | ||
|
|
||
|
|
||
| @mark.parametrize("losses", typical_inputs) | ||
| def test_grad_flow(losses: Tensor) -> None: | ||
| assert_grad_flow(Random(), losses) | ||
|
|
||
|
|
||
| def test_deterministic_under_seed() -> None: | ||
| losses = tensor_([1.0, 2.0, 3.0, 4.0]) | ||
| scalarizer = Random() | ||
| with fork_rng(seed=0): | ||
| a = scalarizer(losses) | ||
| with fork_rng(seed=0): | ||
| b = scalarizer(losses) | ||
| torch.testing.assert_close(a, b) | ||
|
|
||
|
|
||
| def test_weights_sum_to_one() -> None: | ||
| # If all losses equal 1, then sum(weights * losses) == 1 when weights sum to 1. | ||
| losses = ones_((5,)) | ||
| torch.testing.assert_close(Random()(losses), tensor_(1.0)) | ||
|
|
||
|
|
||
| def test_representations() -> None: | ||
| s = Random() | ||
| assert repr(s) == "Random()" | ||
| assert str(s) == "Random" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's list them in alphabetical order (similar to what we did for aggregators).