From f09a485cf5cfcbdc3728c8df893a48af4d6aeb0d Mon Sep 17 00:00:00 2001
From: Pierre Quinton
Date: Fri, 20 Feb 2026 10:36:17 +0100
Subject: [PATCH 1/2] feat(autojac): Make `jac_to_grad` return optional weights.

* Change `aggregator: Aggregator` to `method: Aggregator | Weighting` and the
  return type to an optional `Tensor`.
* Make `method` positional-only.
* Add overloads renaming `method` to `aggregator` or `weighting` and linking
  each name to the corresponding return type.
* Compute and return the weights when a weighting is provided.
* Update the doc and add a usage example.
---
 src/torchjd/autojac/_jac_to_grad.py    | 72 +++++++++++++++++++++++---
 tests/doc/test_jac_to_grad.py          | 19 ++++-
 tests/unit/autojac/test_jac_to_grad.py |  2 +-
 3 files changed, 83 insertions(+), 10 deletions(-)

diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py
index 9467b4ef..d2c15ab5 100644
--- a/src/torchjd/autojac/_jac_to_grad.py
+++ b/src/torchjd/autojac/_jac_to_grad.py
@@ -1,30 +1,57 @@
 from collections.abc import Iterable
+from typing import overload
 
 import torch
 from torch import Tensor
 
-from torchjd.aggregation import Aggregator
+from torchjd.aggregation import Aggregator, Weighting
 
 from ._accumulation import TensorWithJac, accumulate_grads, is_tensor_with_jac
 from ._utils import check_consistent_first_dimension
 
 
+@overload
 def jac_to_grad(
     tensors: Iterable[Tensor],
-    /,
     aggregator: Aggregator,
+    /,
+    *,
+    retain_jac: bool = False,
+) -> None: ...
+
+
+@overload
+def jac_to_grad(
+    tensors: Iterable[Tensor],
+    weighting: Weighting,
+    /,
     *,
     retain_jac: bool = False,
-) -> None:
+) -> Tensor: ...
+
+
+def jac_to_grad(
+    tensors: Iterable[Tensor],
+    method: Aggregator | Weighting,
+    /,
+    *,
+    retain_jac: bool = False,
+) -> Tensor | None:
     r"""
     Aggregates the Jacobians stored in the ``.jac`` fields of ``tensors`` and accumulates the
     result into their ``.grad`` fields.
 
     :param tensors: The tensors whose ``.jac`` fields should be aggregated. All Jacobians must
         have the same first dimension (e.g. number of losses).
-    :param aggregator: The aggregator used to reduce the Jacobians into gradients.
+    :param method: The method used to reduce the Jacobians into gradients. Can be an
+        :class:`Aggregator ` or a
+        :class:`Weighting `, in which case
+        ``jac_to_grad`` also returns the weights used to aggregate the Jacobians.
     :param retain_jac: Whether to preserve the ``.jac`` fields of the tensors after they have
         been used. Defaults to ``False``.
+    :returns: If ``method`` is a
+        :class:`Weighting `, returns the weights
+        used to aggregate the Jacobians, otherwise ``None``.
 
     .. note::
         This function starts by "flattening" the ``.jac`` fields into matrices (i.e. flattening all
@@ -48,9 +75,32 @@ def jac_to_grad(
         >>> y2 = (param ** 2).sum()
         >>>
         >>> backward([y1, y2])  # param now has a .jac field
-        >>> jac_to_grad([param], aggregator=UPGrad())  # param now has a .grad field
+        >>> jac_to_grad([param], UPGrad())  # param now has a .grad field
         >>> param.grad
-        tensor([-1., 1.])
+        tensor([0.5000, 2.5000])
+
+    The ``.grad`` field of ``param`` now contains the aggregation (by UPGrad) of the Jacobian of
+    :math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``.
+
+    .. admonition::
+        Example
+
+        This is the same example as before, except that we also obtain the weights:
+
+        >>> import torch
+        >>>
+        >>> from torchjd.aggregation import UPGradWeighting
+        >>> from torchjd.autojac import backward, jac_to_grad
+        >>>
+        >>> param = torch.tensor([1., 2.], requires_grad=True)
+        >>> # Compute arbitrary quantities that are function of param
+        >>> y1 = torch.tensor([-1., 1.]) @ param
+        >>> y2 = (param ** 2).sum()
+        >>>
+        >>> backward([y1, y2])
+        >>> weights = jac_to_grad([param], UPGradWeighting())
+        >>> weights
+        tensor([1., 1.])
 
     The ``.grad`` field of ``param`` now contains the aggregation (by UPGrad) of the Jacobian of
     :math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``.
@@ -66,7 +116,7 @@ def jac_to_grad(
             tensors_.append(t)
 
     if len(tensors_) == 0:
-        return
+        return None
 
     jacobians = [t.jac for t in tensors_]
 
@@ -76,9 +126,15 @@ def jac_to_grad(
         _free_jacs(tensors_)
 
     jacobian_matrix = _unite_jacobians(jacobians)
-    gradient_vector = aggregator(jacobian_matrix)
+    if isinstance(method, Weighting):
+        weights = method(jacobian_matrix)
+        gradient_vector = weights @ jacobian_matrix
+    else:
+        weights = None
+        gradient_vector = method(jacobian_matrix)
     gradients = _disunite_gradient(gradient_vector, tensors_)
     accumulate_grads(tensors_, gradients)
+    return weights
 
 
 def _unite_jacobians(jacobians: list[Tensor]) -> Tensor:
diff --git a/tests/doc/test_jac_to_grad.py b/tests/doc/test_jac_to_grad.py
index 1f064a6c..dde6cff0 100644
--- a/tests/doc/test_jac_to_grad.py
+++ b/tests/doc/test_jac_to_grad.py
@@ -3,6 +3,7 @@
 the obtained `.grad` field.
 """
 
+from torch.testing import assert_close
 from utils.asserts import assert_grad_close
 
 
@@ -17,6 +18,22 @@ def test_jac_to_grad():
     y1 = torch.tensor([-1.0, 1.0]) @ param
     y2 = (param**2).sum()
 
     backward([y1, y2])  # param now has a .jac field
-    jac_to_grad([param], aggregator=UPGrad())  # param now has a .grad field
+    jac_to_grad([param], UPGrad())  # param now has a .grad field
     assert_grad_close(param, torch.tensor([0.5000, 2.5000]), rtol=0.0, atol=1e-04)
+
+
+def test_jac_to_grad2():
+    import torch
+
+    from torchjd.aggregation import UPGradWeighting
+    from torchjd.autojac import backward, jac_to_grad
+
+    param = torch.tensor([1.0, 2.0], requires_grad=True)
+    # Compute arbitrary quantities that are function of param
+    y1 = torch.tensor([-1.0, 1.0]) @ param
+    y2 = (param**2).sum()
+
+    backward([y1, y2])
+    weights = jac_to_grad([param], UPGradWeighting())
+    assert_close(weights, torch.tensor([1.0, 1.0]), rtol=0.0, atol=1e-04)
diff --git a/tests/unit/autojac/test_jac_to_grad.py b/tests/unit/autojac/test_jac_to_grad.py
index a3f83097..aac60060 100644
--- a/tests/unit/autojac/test_jac_to_grad.py
+++ b/tests/unit/autojac/test_jac_to_grad.py
@@ -82,7 +82,7 @@ def test_row_mismatch():
 
 def test_no_tensors():
     """Tests that jac_to_grad correctly does nothing when an empty list of tensors is provided."""
-    jac_to_grad([], aggregator=UPGrad())
+    jac_to_grad([], UPGrad())
 
 
 @mark.parametrize("retain_jac", [True, False])

From 13523cc5cd94526eefb43006b51a53d9e6d0be6d Mon Sep 17 00:00:00 2001
From: Pierre Quinton
Date: Fri, 20 Feb 2026 10:46:27 +0100
Subject: [PATCH 2/2] Add test with `Weighting`

---
 tests/unit/autojac/test_jac_to_grad.py | 33 +++++++++++++++++++++++++-
 1 file changed, 32 insertions(+), 1 deletion(-)

diff --git a/tests/unit/autojac/test_jac_to_grad.py b/tests/unit/autojac/test_jac_to_grad.py
index aac60060..5185a228 100644
--- a/tests/unit/autojac/test_jac_to_grad.py
+++ b/tests/unit/autojac/test_jac_to_grad.py
@@ -1,8 +1,18 @@
 from pytest import mark, raises
+from torch.testing import assert_close
 from utils.asserts import assert_grad_close, assert_has_jac, assert_has_no_jac
 from utils.tensors import tensor_
 
-from torchjd.aggregation import Aggregator, Mean, PCGrad, UPGrad
+from torchjd.aggregation import (
+    Aggregator,
+    Mean,
+    MeanWeighting,
+    PCGrad,
+    PCGradWeighting,
+    UPGrad,
+    UPGradWeighting,
+    Weighting,
+)
 from torchjd.autojac._jac_to_grad import jac_to_grad
 
 
@@ -25,6 +35,27 @@ def test_various_aggregators(aggregator: Aggregator):
     assert_grad_close(t2, g2)
 
 
+@mark.parametrize("weighting", [MeanWeighting(), UPGradWeighting(), PCGradWeighting()])
+def test_various_weightings(weighting: Weighting):
+    """Tests that jac_to_grad works for various weightings."""
+
+    t1 = tensor_(1.0, requires_grad=True)
+    t2 = tensor_([2.0, 3.0], requires_grad=True)
+    jac = tensor_([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]])
+    t1.__setattr__("jac", jac[:, 0])
+    t2.__setattr__("jac", jac[:, 1:])
+    expected_weights = weighting(jac)
+    expected_grad = expected_weights @ jac
+    g1 = expected_grad[0]
+    g2 = expected_grad[1:]
+
+    weights = jac_to_grad([t1, t2], weighting)
+
+    assert_close(weights, expected_weights)
+    assert_grad_close(t1, g1)
+    assert_grad_close(t2, g2)
+
+
 def test_single_tensor():
     """Tests that jac_to_grad works when a single tensor is provided."""
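
Usage sketch (illustrative, not applied by either patch): a minimal example of the new calling
convention, mirroring the doctest added in PATCH 1/2 and assuming the public names it uses
(`backward`, `jac_to_grad`, `UPGradWeighting`). Passing an `Aggregator` keeps the previous
behaviour and returns `None`; passing a `Weighting` additionally returns the weights applied to
the Jacobian. The printed weight values are the ones asserted in that doctest.

    import torch

    from torchjd.aggregation import UPGradWeighting
    from torchjd.autojac import backward, jac_to_grad

    param = torch.tensor([1.0, 2.0], requires_grad=True)
    y1 = torch.tensor([-1.0, 1.0]) @ param  # first objective
    y2 = (param**2).sum()  # second objective

    backward([y1, y2])  # param now has a .jac field holding the 2-row Jacobian
    # With a Weighting, jac_to_grad accumulates weights @ Jacobian into param.grad and
    # returns the weights; with an Aggregator (e.g. UPGrad()), it would return None.
    weights = jac_to_grad([param], UPGradWeighting())
    print(weights)  # tensor([1., 1.]) per the doctest above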