diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py
index 9467b4ef..c683a529 100644
--- a/src/torchjd/autojac/_jac_to_grad.py
+++ b/src/torchjd/autojac/_jac_to_grad.py
@@ -1,21 +1,43 @@
 from collections.abc import Iterable
+from typing import overload
 
 import torch
 from torch import Tensor
 
 from torchjd.aggregation import Aggregator
+from torchjd.aggregation._aggregator_bases import WeightedAggregator
 
 from ._accumulation import TensorWithJac, accumulate_grads, is_tensor_with_jac
 from ._utils import check_consistent_first_dimension
 
 
+@overload
+def jac_to_grad(
+    tensors: Iterable[Tensor],
+    /,
+    aggregator: WeightedAggregator,
+    *,
+    retain_jac: bool = False,
+) -> Tensor: ...
+
+
+@overload
+def jac_to_grad(
+    tensors: Iterable[Tensor],
+    /,
+    aggregator: Aggregator,  # Not a WeightedAggregator, because overloads are checked in order
+    *,
+    retain_jac: bool = False,
+) -> None: ...
+
+
 def jac_to_grad(
     tensors: Iterable[Tensor],
     /,
     aggregator: Aggregator,
     *,
     retain_jac: bool = False,
-) -> None:
+) -> Tensor | None:
     r"""
     Aggregates the Jacobians stored in the ``.jac`` fields of ``tensors`` and accumulates the
     result into their ``.grad`` fields.
@@ -25,6 +47,9 @@ def jac_to_grad(
     :param aggregator: The aggregator used to reduce the Jacobians into gradients.
     :param retain_jac: Whether to preserve the ``.jac`` fields of the tensors after they have been
         used. Defaults to ``False``.
+    :returns: If ``aggregator`` is based on a :class:`Weighting` to combine the rows of the
+        Jacobians, returns the weights used for the aggregation; otherwise, returns
+        ``None``.
 
     .. note::
         This function starts by "flattening" the ``.jac`` fields into matrices (i.e. flattening all
@@ -48,12 +73,15 @@ def jac_to_grad(
         >>> y2 = (param ** 2).sum()
         >>>
         >>> backward([y1, y2])  # param now has a .jac field
-        >>> jac_to_grad([param], aggregator=UPGrad())  # param now has a .grad field
+        >>> weights = jac_to_grad([param], UPGrad())  # param now has a .grad field
         >>> param.grad
-        tensor([-1., 1.])
+        tensor([0.5000, 2.5000])
+        >>> weights
+        tensor([0.5000, 0.5000])
 
     The ``.grad`` field of ``param`` now contains the aggregation (by UPGrad) of the Jacobian of
-    :math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``.
+    :math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``. In this case, the
+    weights used to combine the rows of the Jacobian are equal because there was no conflict.
     """
 
     tensors_ = list[TensorWithJac]()
@@ -66,7 +94,7 @@ def jac_to_grad(
         tensors_.append(t)
 
     if len(tensors_) == 0:
-        return
+        raise ValueError("The parameter `tensors` cannot be empty.")
 
     jacobians = [t.jac for t in tensors_]
 
@@ -76,9 +104,15 @@ def jac_to_grad(
         _free_jacs(tensors_)
 
     jacobian_matrix = _unite_jacobians(jacobians)
-    gradient_vector = aggregator(jacobian_matrix)
+    if isinstance(aggregator, WeightedAggregator):
+        weights = aggregator.weighting(jacobian_matrix)
+        gradient_vector = weights @ jacobian_matrix
+    else:
+        weights = None
+        gradient_vector = aggregator(jacobian_matrix)
     gradients = _disunite_gradient(gradient_vector, tensors_)
     accumulate_grads(tensors_, gradients)
+    return weights
 
 
 def _unite_jacobians(jacobians: list[Tensor]) -> Tensor:
diff --git a/tests/doc/test_jac_to_grad.py b/tests/doc/test_jac_to_grad.py
index 04ca3ac2..21f06830 100644
--- a/tests/doc/test_jac_to_grad.py
+++ b/tests/doc/test_jac_to_grad.py
@@ -3,6 +3,7 @@
 the obtained `.grad` field.
""" +from torch.testing import assert_close from utils.asserts import assert_grad_close @@ -17,6 +18,7 @@ def test_jac_to_grad() -> None: y1 = torch.tensor([-1.0, 1.0]) @ param y2 = (param**2).sum() backward([y1, y2]) # param now has a .jac field - jac_to_grad([param], aggregator=UPGrad()) # param now has a .grad field + weights = jac_to_grad([param], UPGrad()) # param now has a .grad field assert_grad_close(param, torch.tensor([0.5000, 2.5000]), rtol=0.0, atol=1e-04) + assert_close(weights, torch.tensor([0.5, 0.5]), rtol=0.0, atol=0.0) diff --git a/tests/unit/autojac/test_jac_to_grad.py b/tests/unit/autojac/test_jac_to_grad.py index b8ea5c6c..0456e5a5 100644 --- a/tests/unit/autojac/test_jac_to_grad.py +++ b/tests/unit/autojac/test_jac_to_grad.py @@ -1,14 +1,25 @@ from pytest import mark, raises +from torch.testing import assert_close from utils.asserts import assert_grad_close, assert_has_jac, assert_has_no_jac from utils.tensors import tensor_ -from torchjd.aggregation import Aggregator, Mean, PCGrad, UPGrad +from torchjd.aggregation import ( + Aggregator, + ConFIG, + Mean, + PCGrad, + UPGrad, +) +from torchjd.aggregation._aggregator_bases import WeightedAggregator from torchjd.autojac._jac_to_grad import jac_to_grad -@mark.parametrize("aggregator", [Mean(), UPGrad(), PCGrad()]) +@mark.parametrize("aggregator", [Mean(), UPGrad(), PCGrad(), ConFIG()]) def test_various_aggregators(aggregator: Aggregator) -> None: - """Tests that jac_to_grad works for various aggregators.""" + """ + Tests that jac_to_grad works for various aggregators. For those that are weighted, the weights + should also be returned. For the others, None should be returned. + """ t1 = tensor_(1.0, requires_grad=True) t2 = tensor_([2.0, 3.0], requires_grad=True) @@ -19,11 +30,18 @@ def test_various_aggregators(aggregator: Aggregator) -> None: g1 = expected_grad[0] g2 = expected_grad[1:] - jac_to_grad([t1, t2], aggregator) + optional_weights = jac_to_grad([t1, t2], aggregator) assert_grad_close(t1, g1) assert_grad_close(t2, g2) + if isinstance(aggregator, WeightedAggregator): + assert optional_weights is not None + expected_weights = aggregator.weighting(jac) + assert_close(optional_weights, expected_weights) + else: + assert optional_weights is None + def test_single_tensor() -> None: """Tests that jac_to_grad works when a single tensor is provided.""" @@ -82,7 +100,8 @@ def test_row_mismatch() -> None: def test_no_tensors() -> None: """Tests that jac_to_grad correctly does nothing when an empty list of tensors is provided.""" - jac_to_grad([], aggregator=UPGrad()) + with raises(ValueError): + jac_to_grad([], UPGrad()) @mark.parametrize("retain_jac", [True, False])