Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/source/docs/scalarization/constant.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
:hide-toc:

Constant
========

.. autoclass:: torchjd.scalarization.Constant
    :members: __call__
21 changes: 21 additions & 0 deletions docs/source/docs/scalarization/index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
scalarization
=============

.. automodule:: torchjd.scalarization
    :no-members:

Abstract base class
-------------------

.. autoclass:: torchjd.scalarization.Scalarizer
    :members: __call__


.. toctree::
    :hidden:
    :maxdepth: 1

    mean.rst
    sum.rst
    constant.rst
    random.rst
Comment on lines +18 to +21
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's list them in alphabetical order (similar to what we did for aggregators).

Suggested change
mean.rst
sum.rst
constant.rst
random.rst
constant.rst
mean.rst
random.rst
sum.rst

7 changes: 7 additions & 0 deletions docs/source/docs/scalarization/mean.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
:hide-toc:

Mean
====

.. autoclass:: torchjd.scalarization.Mean
    :members: __call__
7 changes: 7 additions & 0 deletions docs/source/docs/scalarization/random.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
:hide-toc:

Random
======

.. autoclass:: torchjd.scalarization.Random
    :members: __call__
7 changes: 7 additions & 0 deletions docs/source/docs/scalarization/sum.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
:hide-toc:

Sum
===

.. autoclass:: torchjd.scalarization.Sum
    :members: __call__
5 changes: 5 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ Jacobian descent is the aggregator, which maps the Jacobian to an optimization s
:doc:`Aggregation <docs/aggregation/index>`, we provide an overview of the various aggregators
available in TorchJD, and their corresponding weightings.

For comparison against simple baselines, the :doc:`Scalarization <docs/scalarization/index>`
package provides scalarizers that combine a tensor of losses into a single scalar loss, allowing
standard gradient descent to be used.

A straightforward application of Jacobian descent is multi-task learning, in which the vector of
per-task losses has to be minimized. To start using TorchJD for multi-task learning, follow our
:doc:`MTL example <examples/mtl>`.
Expand Down Expand Up @@ -70,4 +74,5 @@ TorchJD is open-source, under MIT License. The source code is available on
docs/autogram/index.rst
docs/autojac/index.rst
docs/aggregation/index.rst
docs/scalarization/index.rst
docs/linalg/index.rst
33 changes: 33 additions & 0 deletions src/torchjd/scalarization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""
A :class:`~torchjd.scalarization.Scalarizer` reduces a tensor of losses of any shape into a single
scalar loss that can be optimized with standard gradient descent. This is the simple baseline
against which :class:`Aggregators <torchjd.aggregation.Aggregator>` are compared: instead of
combining the per-loss gradients via the Jacobian or its Gramian, a
:class:`~torchjd.scalarization.Scalarizer` combines the losses directly, and a standard call to
:meth:`~torch.Tensor.backward` produces the gradient.

The following example shows how to use :class:`~torchjd.scalarization.Mean` to combine a vector of
losses into a single scalar loss.

>>> from torch import tensor
>>> from torchjd.scalarization import Mean
>>>
>>> scalarizer = Mean()
>>> losses = tensor([1.0, 2.0, 3.0])
>>> scalarizer(losses)
tensor(2.)
Comment on lines +17 to +18
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would even name the output to make this even more explicit.

Suggested change
>>> scalarizer(losses)
tensor(2.)
>>> loss = scalarizer(losses)
>>> loss
tensor(2.)

"""

from ._constant import Constant
from ._mean import Mean
from ._random import Random
from ._scalarizer_base import Scalarizer
from ._sum import Sum

# Public API of the package: the abstract base class and its concrete scalarizers.
__all__ = [
    "Constant",
    "Mean",
    "Random",
    "Scalarizer",
    "Sum",
]
28 changes: 28 additions & 0 deletions src/torchjd/scalarization/_constant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from torch import Tensor

from ._scalarizer_base import Scalarizer


class Constant(Scalarizer):
    """
    :class:`~torchjd.scalarization.Scalarizer` computing a weighted sum of the input tensor of
    losses, using fixed weights chosen at construction time.

    :param weights: The weight given to each loss. Its shape must be identical to the shape of the
        losses provided at call time.
    """

    def __init__(self, weights: Tensor) -> None:
        super().__init__()
        self.weights = weights

    def forward(self, losses: Tensor, /) -> Tensor:
        # An exact shape match is required: shapes that differ are rejected even when they would
        # be broadcastable with the weights.
        if losses.shape == self.weights.shape:
            return (self.weights * losses).sum()

        raise ValueError(
            f"Parameter `losses` should have shape {tuple(self.weights.shape)} (matching the "
            f"shape of the weights). Found `losses.shape = {tuple(losses.shape)}`.",
        )

    def __repr__(self) -> str:
        return f"{type(self).__name__}(weights={self.weights!r})"
12 changes: 12 additions & 0 deletions src/torchjd/scalarization/_mean.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from torch import Tensor

from ._scalarizer_base import Scalarizer


class Mean(Scalarizer):
    """
    :class:`~torchjd.scalarization.Scalarizer` that averages all the entries of the input tensor
    of losses.
    """

    def forward(self, losses: Tensor, /) -> Tensor:
        # No `dim` argument: the mean is taken over every entry, whatever the shape of `losses`.
        return losses.mean()
19 changes: 19 additions & 0 deletions src/torchjd/scalarization/_random.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import torch
from torch import Tensor
from torch.nn.functional import softmax

from ._scalarizer_base import Scalarizer


class Random(Scalarizer):
    """
    :class:`~torchjd.scalarization.Scalarizer` computing a weighted sum of the input tensor of
    losses, where the weights are positive, sum to 1 and are drawn at random on every call. This
    follows Algorithm 2 of `Reasonable Effectiveness of Random Weighting: A Litmus Test for
    Multi-Task Learning <https://arxiv.org/pdf/2111.10603.pdf>`_.
    """

    def forward(self, losses: Tensor, /) -> Tensor:
        # One standard-normal logit per loss, created with the device and dtype of the losses.
        logits = torch.randn(losses.numel(), device=losses.device, dtype=losses.dtype)
        # The softmax makes the weights positive and summing to 1; they are then given the shape
        # of the losses so that the product below is element-wise.
        weights = softmax(logits, dim=-1).reshape(losses.shape)
        weighted_losses = weights * losses
        return weighted_losses.sum()
31 changes: 31 additions & 0 deletions src/torchjd/scalarization/_scalarizer_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from abc import ABC, abstractmethod

from torch import Tensor, nn


class Scalarizer(nn.Module, ABC):
"""
Abstract base class for all scalarizers. Reduces a tensor of losses of any shape into a single
scalar loss that can be passed to :meth:`~torch.Tensor.backward`.
"""
Comment on lines +7 to +10
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ValerianRey I think I would abstract away from losses and differentiation, and instead talk about tensors/inputs and turning them into scalars. What do you think?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. I think the names `values` (for the input) and `value` (for the output) would be more appropriate than the name `losses` (or `loss`) in the parameters and docstrings of the scalarizers.

Something like:

  • losses => values
  • "tensor of losses" => "tensor of values"
  • "scalar loss" => "scalar value"
  • etc.

I would also remove the "that can be passed to :meth:`~torch.Tensor.backward`."

This will make Scalarizer a bit more independent of the concept of losses and optimization, so that scalarizers are strictly about scalarizing tensors.

However, when we describe the usage of scalarizers, we should still talk about losses and optimization (e.g. docs/source/index.rst is fine like this IMO, and the same goes for src/torchjd/scalarization/__init__.py, except for its very first sentence, which should be made more general: talk about values and don't talk about optimization there).


def __init__(self) -> None:
    # Only delegates to the parent initializers; no attribute is defined here.
    super().__init__()

@abstractmethod
def forward(self, losses: Tensor, /) -> Tensor:
    """
    Computes the scalarization of the input tensor. Must be implemented by every subclass.

    :param losses: The tensor of losses to scalarize. May be of any shape.
    """

def __call__(self, losses: Tensor, /) -> Tensor:
    """
    Computes the scalar loss from the input tensor of losses and applies all registered hooks.

    :param losses: The tensor of losses to scalarize. May be of any shape.
    """
    # The override only forwards to the parent `__call__`: it exists to give the method a precise
    # signature and a docstring (the documentation pages list `__call__` as the documented member).
    return super().__call__(losses)

def __repr__(self) -> str:
    # Rendered as a call to the class with no argument, e.g. ``Mean()``.
    cls_name = type(self).__name__
    return f"{cls_name}()"

def __str__(self) -> str:
    # The short form is just the class name, e.g. ``Mean``.
    return type(self).__name__
12 changes: 12 additions & 0 deletions src/torchjd/scalarization/_sum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from torch import Tensor

from ._scalarizer_base import Scalarizer


class Sum(Scalarizer):
    """
    :class:`~torchjd.scalarization.Scalarizer` that adds up all the entries of the input tensor of
    losses.
    """

    def forward(self, losses: Tensor, /) -> Tensor:
        # No `dim` argument: the sum is taken over every entry, whatever the shape of `losses`.
        return losses.sum()
Empty file.
27 changes: 27 additions & 0 deletions tests/unit/scalarization/_asserts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import torch
from torch import Tensor
from utils.tensors import randperm_

from torchjd.scalarization import Scalarizer


def assert_returns_scalar(scalarizer: Scalarizer, losses: Tensor) -> None:
    """Checks that scalarizing ``losses`` yields a finite tensor with zero dimensions."""
    result = scalarizer(losses)
    assert result.ndim == 0
    assert result.isfinite()


def assert_grad_flow(scalarizer: Scalarizer, losses: Tensor) -> None:
    """Checks that the scalarized value back-propagates a finite gradient to its input."""
    # Work on a fresh leaf tensor so that the gradient ends up in its `.grad` field.
    leaf = losses.detach().requires_grad_()
    scalarizer(leaf).backward()

    gradient = leaf.grad
    assert gradient is not None
    assert gradient.isfinite().all()


def assert_permutation_invariant(scalarizer: Scalarizer, losses: Tensor) -> None:
    """Checks that re-ordering the entries of ``losses`` leaves the scalarized value unchanged."""
    reference = scalarizer(losses)

    # Re-order the entries in flattened form with the indices given by `randperm_`, then restore
    # the original shape.
    flattened = losses.flatten()
    shuffled = flattened[randperm_(flattened.numel())].reshape(losses.shape)

    torch.testing.assert_close(reference, scalarizer(shuffled))
10 changes: 10 additions & 0 deletions tests/unit/scalarization/_inputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from torch import Tensor
from utils.tensors import randn_, tensor_

# One input per number of dimensions, from 0 (scalar) to 3.
scalar_input: Tensor = tensor_(7.0)
vector_input: Tensor = randn_(5)
matrix_input: Tensor = randn_(3, 4)
tensor_3d_input: Tensor = randn_(2, 3, 4)

# Inputs with at least one dimension.
typical_inputs: list[Tensor] = [vector_input, matrix_input, tensor_3d_input]
# The same inputs, preceded by the 0-dimensional one.
all_inputs: list[Tensor] = [scalar_input, *typical_inputs]
Comment on lines +4 to +10
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
scalar_input: Tensor = tensor_(7.0)
vector_input: Tensor = randn_(5)
matrix_input: Tensor = randn_(3, 4)
tensor_3d_input: Tensor = randn_(2, 3, 4)
typical_inputs: list[Tensor] = [vector_input, matrix_input, tensor_3d_input]
all_inputs: list[Tensor] = [scalar_input, *typical_inputs]
scalar_inputs: list[Tensor] = [randn_([]) for _ in range(3)]
vector_inputs: list[Tensor] = [randn_([5]) for _ in range(3)]
matrix_inputs: list[Tensor] = [randn_([3, 4]) for _ in range(3)]
tensor_3d_inputs: list[Tensor] = [randn_([2, 3, 4]) for _ in range(3)]
typical_inputs: list[Tensor] = vector_inputs + matrix_inputs + tensor_3d_inputs
all_inputs: list[Tensor] = scalar_inputs + typical_inputs

Maybe we should even have many shapes instead, what do you think @ValerianRey ?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would rename typical_inputs to non_scalar_inputs.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should even have many shapes instead, what do you think @ValerianRey ?

In a future PR I would say. What you suggested is good enough.

58 changes: 58 additions & 0 deletions tests/unit/scalarization/test_constant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from contextlib import nullcontext as does_not_raise

import torch
from pytest import mark, raises
from torch import Tensor
from utils.contexts import ExceptionContext
from utils.tensors import ones_, tensor_

from torchjd.scalarization import Constant

from ._asserts import assert_grad_flow, assert_returns_scalar
from ._inputs import all_inputs


def test_value() -> None:
    """Checks the weighted sum on a small hand-computed example."""
    losses = tensor_([1.0, 2.0, 3.0, 4.0])
    weights = tensor_([0.1, 0.2, 0.3, 0.4])
    scalarizer = Constant(weights)

    # 0.1 * 1 + 0.2 * 2 + 0.3 * 3 + 0.4 * 4 = 3
    result = scalarizer(losses)
    torch.testing.assert_close(result, tensor_(3.0))


@mark.parametrize("losses", all_inputs)
def test_expected_structure(losses: Tensor) -> None:
    """Checks that the output is a finite scalar for every input shape, with all-ones weights."""
    scalarizer = Constant(ones_(losses.shape))
    assert_returns_scalar(scalarizer, losses)


@mark.parametrize("losses", all_inputs)
def test_grad_flow(losses: Tensor) -> None:
    """Checks that a finite gradient reaches the input for every shape, with all-ones weights."""
    scalarizer = Constant(ones_(losses.shape))
    assert_grad_flow(scalarizer, losses)


@mark.parametrize(
    ["weights_shape", "losses_shape", "expectation"],
    [
        # Identical shapes are accepted.
        ((5,), (5,), does_not_raise()),
        ((3, 4), (3, 4), does_not_raise()),
        ((), (), does_not_raise()),
        # Any difference is rejected, even when the shapes are broadcastable or hold the same
        # number of elements.
        ((5,), (4,), raises(ValueError)),
        ((5,), (5, 1), raises(ValueError)),
        ((3, 4), (4, 3), raises(ValueError)),
    ],
)
def test_shape_check(
    weights_shape: tuple[int, ...],
    losses_shape: tuple[int, ...],
    expectation: ExceptionContext,
) -> None:
    """Checks that the call fails exactly when the shapes of the losses and weights differ."""
    scalarizer = Constant(ones_(weights_shape))
    losses = ones_(losses_shape)

    with expectation:
        scalarizer(losses)


def test_representations() -> None:
    """Checks the output of ``repr`` and ``str`` on a Constant."""
    # The weights are explicitly created on CPU, presumably to keep any device suffix out of the
    # expected repr.
    scalarizer = Constant(weights=torch.tensor([1.0, 2.0], device="cpu"))
    assert str(scalarizer) == "Constant"
    assert repr(scalarizer) == "Constant(weights=tensor([1., 2.]))"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think for Constant aggregator, the str is: Constant([1., 2.]) or something. It would be nice to have the same thing here. To do that, you should use the pref_vector_to_str_suffix function, defined in TorchJD/src/torchjd/aggregation/_utils/pref_vector.py.

Since it won't be specific to aggregation anymore, you should move this function and the helper vector_to_str function to a new file TorchJD/src/torchjd/_vector_str.py. We could then entirely remove the file TorchJD/src/torchjd/aggregation/_utils/str.py, which will be empty.

39 changes: 39 additions & 0 deletions tests/unit/scalarization/test_mean.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import torch
from pytest import mark
from torch import Tensor
from utils.tensors import tensor_

from torchjd.scalarization import Mean

from ._asserts import (
assert_grad_flow,
assert_permutation_invariant,
assert_returns_scalar,
)
from ._inputs import all_inputs, typical_inputs


def test_value() -> None:
    """Checks the mean on a small hand-computed example."""
    scalarizer = Mean()
    losses = tensor_([1.0, 2.0, 3.0])

    # (1 + 2 + 3) / 3 = 2
    result = scalarizer(losses)
    torch.testing.assert_close(result, tensor_(2.0))


@mark.parametrize("losses", all_inputs)
def test_expected_structure(losses: Tensor) -> None:
    """Checks that the output is a finite scalar for every input shape."""
    scalarizer = Mean()
    assert_returns_scalar(scalarizer, losses)


@mark.parametrize("losses", typical_inputs)
def test_grad_flow(losses: Tensor) -> None:
    """Checks that a finite gradient reaches the input for every typical input."""
    scalarizer = Mean()
    assert_grad_flow(scalarizer, losses)


@mark.parametrize("losses", typical_inputs)
def test_permutation_invariant(losses: Tensor) -> None:
    """Checks that the mean does not depend on the order of the entries."""
    scalarizer = Mean()
    assert_permutation_invariant(scalarizer, losses)


def test_representations() -> None:
    """Checks the output of ``repr`` and ``str`` on a Mean."""
    scalarizer = Mean()
    assert str(scalarizer) == "Mean"
    assert repr(scalarizer) == "Mean()"
42 changes: 42 additions & 0 deletions tests/unit/scalarization/test_random.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import torch
from pytest import mark
from torch import Tensor
from utils.contexts import fork_rng
from utils.tensors import ones_, tensor_

from torchjd.scalarization import Random

from ._asserts import assert_grad_flow, assert_returns_scalar
from ._inputs import typical_inputs


@mark.parametrize("losses", typical_inputs)
def test_expected_structure(losses: Tensor) -> None:
    """Checks that the output is a finite scalar for every typical input."""
    scalarizer = Random()
    assert_returns_scalar(scalarizer, losses)


@mark.parametrize("losses", typical_inputs)
def test_grad_flow(losses: Tensor) -> None:
    """Checks that a finite gradient reaches the input for every typical input."""
    scalarizer = Random()
    assert_grad_flow(scalarizer, losses)


def test_deterministic_under_seed() -> None:
    """Checks that two calls made under the same seed return the same value."""
    losses = tensor_([1.0, 2.0, 3.0, 4.0])
    scalarizer = Random()

    # Run the same call twice, each time inside its own context seeded with 0.
    outputs = []
    for _ in range(2):
        with fork_rng(seed=0):
            outputs.append(scalarizer(losses))

    first, second = outputs
    torch.testing.assert_close(first, second)


def test_weights_sum_to_one() -> None:
    """Checks indirectly that the random weights sum to 1."""
    # With all losses equal to 1, sum(weights * losses) is simply sum(weights).
    unit_losses = ones_((5,))
    result = Random()(unit_losses)
    torch.testing.assert_close(result, tensor_(1.0))


def test_representations() -> None:
    """Checks the output of ``repr`` and ``str`` on a Random."""
    scalarizer = Random()
    assert str(scalarizer) == "Random"
    assert repr(scalarizer) == "Random()"
Loading
Loading