From f09a485cf5cfcbdc3728c8df893a48af4d6aeb0d Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Fri, 20 Feb 2026 10:36:17 +0100 Subject: [PATCH 1/7] feat(autojac): Make `jac_to_grad` return optional weights. * Change `aggregator: Aggregator` to `method: Aggregator | Weighting` and return type to optional `Tensor`. * Make `method` positional only. * Add overloads to rename `method` to `aggregator` or `weighting` and link it to output type. * Compute the weights if we provide a weighting and return them. * Update the doc and add a usage example --- src/torchjd/autojac/_jac_to_grad.py | 72 +++++++++++++++++++++++--- tests/doc/test_jac_to_grad.py | 19 ++++++- tests/unit/autojac/test_jac_to_grad.py | 2 +- 3 files changed, 83 insertions(+), 10 deletions(-) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 9467b4ef..d2c15ab5 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -1,30 +1,57 @@ from collections.abc import Iterable +from typing import overload import torch from torch import Tensor -from torchjd.aggregation import Aggregator +from torchjd.aggregation import Aggregator, Weighting from ._accumulation import TensorWithJac, accumulate_grads, is_tensor_with_jac from ._utils import check_consistent_first_dimension +@overload def jac_to_grad( tensors: Iterable[Tensor], - /, aggregator: Aggregator, + /, + *, + retain_jac: bool = False, +) -> None: ... + + +@overload +def jac_to_grad( + tensors: Iterable[Tensor], + weighting: Weighting, + /, *, retain_jac: bool = False, -) -> None: +) -> Tensor: ... + + +def jac_to_grad( + tensors: Iterable[Tensor], + method: Aggregator | Weighting, + /, + *, + retain_jac: bool = False, +) -> Tensor | None: r""" Aggregates the Jacobians stored in the ``.jac`` fields of ``tensors`` and accumulates the result into their ``.grad`` fields. :param tensors: The tensors whose ``.jac`` fields should be aggregated. All Jacobians must have the same first dimension (e.g. number of losses). - :param aggregator: The aggregator used to reduce the Jacobians into gradients. + :param method: The method used to reduce the Jacobians into gradients. Can be an + :class:`Aggregator ` or a + :class:`Weighting ` in which case + ``jac_to_grad`` also returns the weights used to aggregator the Jacobians. :param retain_jac: Whether to preserve the ``.jac`` fields of the tensors after they have been used. Defaults to ``False``. + :returns: If ``method`` is a + :class:`Weighting `, returns the weights + used to aggregate the Jacobians, otherwise ``None``. .. note:: This function starts by "flattening" the ``.jac`` fields into matrices (i.e. flattening all @@ -48,9 +75,32 @@ def jac_to_grad( >>> y2 = (param ** 2).sum() >>> >>> backward([y1, y2]) # param now has a .jac field - >>> jac_to_grad([param], aggregator=UPGrad()) # param now has a .grad field + >>> jac_to_grad([param], UPGrad()) # param now has a .grad field >>> param.grad - tensor([-1., 1.]) + tensor([0.5000, 2.5000]) + + The ``.grad`` field of ``param`` now contains the aggregation (by UPGrad) of the Jacobian of + :math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``. + + .. 
admonition:: + Example + + This is the same example as before except that we also obtain the weights + + >>> import torch + >>> + >>> from torchjd.aggregation import UPGradWeighting + >>> from torchjd.autojac import backward, jac_to_grad + >>> + >>> param = torch.tensor([1., 2.], requires_grad=True) + >>> # Compute arbitrary quantities that are function of param + >>> y1 = torch.tensor([-1., 1.]) @ param + >>> y2 = (param ** 2).sum() + >>> + >>> backward([y1, y2]) + >>> weights = jac_to_grad([param], UPGradWeighting()) + >>> weights + tensor([1., 1.]) The ``.grad`` field of ``param`` now contains the aggregation (by UPGrad) of the Jacobian of :math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``. @@ -66,7 +116,7 @@ def jac_to_grad( tensors_.append(t) if len(tensors_) == 0: - return + return None jacobians = [t.jac for t in tensors_] @@ -76,9 +126,15 @@ def jac_to_grad( _free_jacs(tensors_) jacobian_matrix = _unite_jacobians(jacobians) - gradient_vector = aggregator(jacobian_matrix) + if isinstance(method, Weighting): + weights = method(jacobian_matrix) + gradient_vector = weights @ jacobian_matrix + else: + weights = None + gradient_vector = method(jacobian_matrix) gradients = _disunite_gradient(gradient_vector, tensors_) accumulate_grads(tensors_, gradients) + return weights def _unite_jacobians(jacobians: list[Tensor]) -> Tensor: diff --git a/tests/doc/test_jac_to_grad.py b/tests/doc/test_jac_to_grad.py index 1f064a6c..dde6cff0 100644 --- a/tests/doc/test_jac_to_grad.py +++ b/tests/doc/test_jac_to_grad.py @@ -3,6 +3,7 @@ the obtained `.grad` field. """ +from torch.testing import assert_close from utils.asserts import assert_grad_close @@ -17,6 +18,22 @@ def test_jac_to_grad(): y1 = torch.tensor([-1.0, 1.0]) @ param y2 = (param**2).sum() backward([y1, y2]) # param now has a .jac field - jac_to_grad([param], aggregator=UPGrad()) # param now has a .grad field + jac_to_grad([param], UPGrad()) # param now has a .grad field assert_grad_close(param, torch.tensor([0.5000, 2.5000]), rtol=0.0, atol=1e-04) + + +def test_jac_to_grad2(): + import torch + + from torchjd.aggregation import UPGradWeighting + from torchjd.autojac import backward, jac_to_grad + + param = torch.tensor([1.0, 2.0], requires_grad=True) + # Compute arbitrary quantities that are function of param + y1 = torch.tensor([-1.0, 1.0]) @ param + y2 = (param**2).sum() + + backward([y1, y2]) + weights = jac_to_grad([param], UPGradWeighting()) + assert_close(weights, torch.tensor([1.0, 1.0]), rtol=0.0, atol=1e-04) diff --git a/tests/unit/autojac/test_jac_to_grad.py b/tests/unit/autojac/test_jac_to_grad.py index a3f83097..aac60060 100644 --- a/tests/unit/autojac/test_jac_to_grad.py +++ b/tests/unit/autojac/test_jac_to_grad.py @@ -82,7 +82,7 @@ def test_row_mismatch(): def test_no_tensors(): """Tests that jac_to_grad correctly does nothing when an empty list of tensors is provided.""" - jac_to_grad([], aggregator=UPGrad()) + jac_to_grad([], UPGrad()) @mark.parametrize("retain_jac", [True, False]) From 13523cc5cd94526eefb43006b51a53d9e6d0be6d Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Fri, 20 Feb 2026 10:46:27 +0100 Subject: [PATCH 2/7] Add test with `Weighting` --- tests/unit/autojac/test_jac_to_grad.py | 33 +++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/tests/unit/autojac/test_jac_to_grad.py b/tests/unit/autojac/test_jac_to_grad.py index aac60060..5185a228 100644 --- a/tests/unit/autojac/test_jac_to_grad.py +++ b/tests/unit/autojac/test_jac_to_grad.py @@ -1,8 +1,18 @@ 
from pytest import mark, raises +from torch.testing import assert_close from utils.asserts import assert_grad_close, assert_has_jac, assert_has_no_jac from utils.tensors import tensor_ -from torchjd.aggregation import Aggregator, Mean, PCGrad, UPGrad +from torchjd.aggregation import ( + Aggregator, + Mean, + MeanWeighting, + PCGrad, + PCGradWeighting, + UPGrad, + UPGradWeighting, + Weighting, +) from torchjd.autojac._jac_to_grad import jac_to_grad @@ -25,6 +35,27 @@ def test_various_aggregators(aggregator: Aggregator): assert_grad_close(t2, g2) +@mark.parametrize("weighting", [MeanWeighting(), UPGradWeighting(), PCGradWeighting()]) +def test_various_weightings(weighting: Weighting): + """Tests that jac_to_grad works for various aggregators.""" + + t1 = tensor_(1.0, requires_grad=True) + t2 = tensor_([2.0, 3.0], requires_grad=True) + jac = tensor_([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]]) + t1.__setattr__("jac", jac[:, 0]) + t2.__setattr__("jac", jac[:, 1:]) + expected_weights = weighting(jac) + expected_grad = expected_weights @ jac + g1 = expected_grad[0] + g2 = expected_grad[1:] + + weights = jac_to_grad([t1, t2], weighting) + + assert_close(weights, expected_weights) + assert_grad_close(t1, g1) + assert_grad_close(t2, g2) + + def test_single_tensor(): """Tests that jac_to_grad works when a single tensor is provided.""" From cad89718550460d03173f405f5fc43c77f1ea209 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 20 Feb 2026 16:54:22 +0100 Subject: [PATCH 3/7] Make jac_to_grad only take aggregator --- src/torchjd/autojac/_jac_to_grad.py | 60 +++++++++----------------- tests/doc/test_jac_to_grad.py | 22 ++-------- tests/unit/autojac/test_jac_to_grad.py | 41 ++++++------------ 3 files changed, 38 insertions(+), 85 deletions(-) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index d2c15ab5..978df6c1 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -4,7 +4,8 @@ import torch from torch import Tensor -from torchjd.aggregation import Aggregator, Weighting +from torchjd.aggregation import Aggregator +from torchjd.aggregation._aggregator_bases import WeightedAggregator from ._accumulation import TensorWithJac, accumulate_grads, is_tensor_with_jac from ._utils import check_consistent_first_dimension @@ -13,26 +14,26 @@ @overload def jac_to_grad( tensors: Iterable[Tensor], - aggregator: Aggregator, + aggregator: WeightedAggregator, /, *, retain_jac: bool = False, -) -> None: ... +) -> Tensor: ... @overload def jac_to_grad( tensors: Iterable[Tensor], - weighting: Weighting, + aggregator: Aggregator, # Not a WeightedAggregator, because overloads are checked in order /, *, retain_jac: bool = False, -) -> Tensor: ... +) -> None: ... def jac_to_grad( tensors: Iterable[Tensor], - method: Aggregator | Weighting, + aggregator: Aggregator, /, *, retain_jac: bool = False, @@ -43,15 +44,14 @@ def jac_to_grad( :param tensors: The tensors whose ``.jac`` fields should be aggregated. All Jacobians must have the same first dimension (e.g. number of losses). - :param method: The method used to reduce the Jacobians into gradients. Can be an - :class:`Aggregator ` or a - :class:`Weighting ` in which case - ``jac_to_grad`` also returns the weights used to aggregator the Jacobians. + :param aggregator: The aggregator used to reduce the Jacobians into gradients. If a + :class:`WeightedAggregator ` is + provided, ``jac_to_grad`` will also return the weights used to aggregate the Jacobians. 
:param retain_jac: Whether to preserve the ``.jac`` fields of the tensors after they have been used. Defaults to ``False``. - :returns: If ``method`` is a - :class:`Weighting `, returns the weights - used to aggregate the Jacobians, otherwise ``None``. + :returns: If ``aggregator`` is a + :class:`WeightedAggregator `, + returns the weights used to aggregate the Jacobians, otherwise ``None``. .. note:: This function starts by "flattening" the ``.jac`` fields into matrices (i.e. flattening all @@ -75,35 +75,15 @@ def jac_to_grad( >>> y2 = (param ** 2).sum() >>> >>> backward([y1, y2]) # param now has a .jac field - >>> jac_to_grad([param], UPGrad()) # param now has a .grad field + >>> weights = jac_to_grad([param], UPGrad()) # param now has a .grad field >>> param.grad tensor([0.5000, 2.5000]) - - The ``.grad`` field of ``param`` now contains the aggregation (by UPGrad) of the Jacobian of - :math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``. - - .. admonition:: - Example - - This is the same example as before except that we also obtain the weights - - >>> import torch - >>> - >>> from torchjd.aggregation import UPGradWeighting - >>> from torchjd.autojac import backward, jac_to_grad - >>> - >>> param = torch.tensor([1., 2.], requires_grad=True) - >>> # Compute arbitrary quantities that are function of param - >>> y1 = torch.tensor([-1., 1.]) @ param - >>> y2 = (param ** 2).sum() - >>> - >>> backward([y1, y2]) - >>> weights = jac_to_grad([param], UPGradWeighting()) >>> weights - tensor([1., 1.]) + tensor([0.5], 0.5]) The ``.grad`` field of ``param`` now contains the aggregation (by UPGrad) of the Jacobian of - :math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``. + :math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``. In this case, the + weights used to combine the Jacobian are equal because there was no conflict. 
""" tensors_ = list[TensorWithJac]() @@ -126,12 +106,12 @@ def jac_to_grad( _free_jacs(tensors_) jacobian_matrix = _unite_jacobians(jacobians) - if isinstance(method, Weighting): - weights = method(jacobian_matrix) + if isinstance(aggregator, WeightedAggregator): + weights = aggregator.weighting(jacobian_matrix) gradient_vector = weights @ jacobian_matrix else: weights = None - gradient_vector = method(jacobian_matrix) + gradient_vector = aggregator(jacobian_matrix) gradients = _disunite_gradient(gradient_vector, tensors_) accumulate_grads(tensors_, gradients) return weights diff --git a/tests/doc/test_jac_to_grad.py b/tests/doc/test_jac_to_grad.py index dde6cff0..afef6445 100644 --- a/tests/doc/test_jac_to_grad.py +++ b/tests/doc/test_jac_to_grad.py @@ -6,11 +6,12 @@ from torch.testing import assert_close from utils.asserts import assert_grad_close +from torchjd.aggregation import UPGrad + def test_jac_to_grad(): import torch - from torchjd.aggregation import UPGrad from torchjd.autojac import backward, jac_to_grad param = torch.tensor([1.0, 2.0], requires_grad=True) @@ -18,22 +19,7 @@ def test_jac_to_grad(): y1 = torch.tensor([-1.0, 1.0]) @ param y2 = (param**2).sum() backward([y1, y2]) # param now has a .jac field - jac_to_grad([param], UPGrad()) # param now has a .grad field + weights = jac_to_grad([param], UPGrad()) # param now has a .grad field assert_grad_close(param, torch.tensor([0.5000, 2.5000]), rtol=0.0, atol=1e-04) - - -def test_jac_to_grad2(): - import torch - - from torchjd.aggregation import UPGradWeighting - from torchjd.autojac import backward, jac_to_grad - - param = torch.tensor([1.0, 2.0], requires_grad=True) - # Compute arbitrary quantities that are function of param - y1 = torch.tensor([-1.0, 1.0]) @ param - y2 = (param**2).sum() - - backward([y1, y2]) - weights = jac_to_grad([param], UPGradWeighting()) - assert_close(weights, torch.tensor([1.0, 1.0]), rtol=0.0, atol=1e-04) + assert_close(weights, torch.tensor([0.5, 0.5]), rtol=0.0, atol=0.0) diff --git a/tests/unit/autojac/test_jac_to_grad.py b/tests/unit/autojac/test_jac_to_grad.py index 5185a228..8e8d7813 100644 --- a/tests/unit/autojac/test_jac_to_grad.py +++ b/tests/unit/autojac/test_jac_to_grad.py @@ -5,20 +5,21 @@ from torchjd.aggregation import ( Aggregator, + ConFIG, Mean, - MeanWeighting, PCGrad, - PCGradWeighting, UPGrad, - UPGradWeighting, - Weighting, ) +from torchjd.aggregation._aggregator_bases import WeightedAggregator from torchjd.autojac._jac_to_grad import jac_to_grad -@mark.parametrize("aggregator", [Mean(), UPGrad(), PCGrad()]) +@mark.parametrize("aggregator", [Mean(), UPGrad(), PCGrad(), ConFIG()]) def test_various_aggregators(aggregator: Aggregator): - """Tests that jac_to_grad works for various aggregators.""" + """ + Tests that jac_to_grad works for various aggregators. For those that are weighted, the weights + should also be returned. For the others, None should be returned. 
+ """ t1 = tensor_(1.0, requires_grad=True) t2 = tensor_([2.0, 3.0], requires_grad=True) @@ -29,31 +30,17 @@ def test_various_aggregators(aggregator: Aggregator): g1 = expected_grad[0] g2 = expected_grad[1:] - jac_to_grad([t1, t2], aggregator) + optional_weights = jac_to_grad([t1, t2], aggregator) assert_grad_close(t1, g1) assert_grad_close(t2, g2) - -@mark.parametrize("weighting", [MeanWeighting(), UPGradWeighting(), PCGradWeighting()]) -def test_various_weightings(weighting: Weighting): - """Tests that jac_to_grad works for various aggregators.""" - - t1 = tensor_(1.0, requires_grad=True) - t2 = tensor_([2.0, 3.0], requires_grad=True) - jac = tensor_([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]]) - t1.__setattr__("jac", jac[:, 0]) - t2.__setattr__("jac", jac[:, 1:]) - expected_weights = weighting(jac) - expected_grad = expected_weights @ jac - g1 = expected_grad[0] - g2 = expected_grad[1:] - - weights = jac_to_grad([t1, t2], weighting) - - assert_close(weights, expected_weights) - assert_grad_close(t1, g1) - assert_grad_close(t2, g2) + if isinstance(aggregator, WeightedAggregator): + assert optional_weights is not None + expected_weights = aggregator.weighting(jac) + assert_close(optional_weights, expected_weights) + else: + assert optional_weights is None def test_single_tensor(): From f9b728daad5cf7f0b9426c0c91c17e178a15cc96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 20 Feb 2026 16:55:10 +0100 Subject: [PATCH 4/7] Make the aggregator parameter positional & keyword --- src/torchjd/autojac/_jac_to_grad.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 978df6c1..9a16ddff 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -14,8 +14,8 @@ @overload def jac_to_grad( tensors: Iterable[Tensor], - aggregator: WeightedAggregator, /, + aggregator: WeightedAggregator, *, retain_jac: bool = False, ) -> Tensor: ... @@ -24,8 +24,8 @@ def jac_to_grad( @overload def jac_to_grad( tensors: Iterable[Tensor], - aggregator: Aggregator, # Not a WeightedAggregator, because overloads are checked in order /, + aggregator: Aggregator, # Not a WeightedAggregator, because overloads are checked in order *, retain_jac: bool = False, ) -> None: ... @@ -33,8 +33,8 @@ def jac_to_grad( def jac_to_grad( tensors: Iterable[Tensor], - aggregator: Aggregator, /, + aggregator: Aggregator, *, retain_jac: bool = False, ) -> Tensor | None: From b701c592da626f5e975080ea68dbaa3268aa41e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 20 Feb 2026 17:23:53 +0100 Subject: [PATCH 5/7] Stop mentioning WeightedAggregator in jac_to_grad docstring --- src/torchjd/autojac/_jac_to_grad.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 9a16ddff..59b7b94d 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -44,14 +44,14 @@ def jac_to_grad( :param tensors: The tensors whose ``.jac`` fields should be aggregated. All Jacobians must have the same first dimension (e.g. number of losses). - :param aggregator: The aggregator used to reduce the Jacobians into gradients. If a - :class:`WeightedAggregator ` is - provided, ``jac_to_grad`` will also return the weights used to aggregate the Jacobians. + :param aggregator: The aggregator used to reduce the Jacobians into gradients. 
If it uses a + :class:`Weighting ` to combine the rows of + the Jacobians, ``jac_to_grad`` will also return the computed weights. :param retain_jac: Whether to preserve the ``.jac`` fields of the tensors after they have been used. Defaults to ``False``. - :returns: If ``aggregator`` is a - :class:`WeightedAggregator `, - returns the weights used to aggregate the Jacobians, otherwise ``None``. + :returns: If ``aggregator`` uses a + :class:`Weighting ` returns the weights used + to combine the rows of the Jacobians, otherwise ``None``. .. note:: This function starts by "flattening" the ``.jac`` fields into matrices (i.e. flattening all From 5801380ff26a3fdccd9e1fb571401bd3856bbff2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 20 Feb 2026 17:27:46 +0100 Subject: [PATCH 6/7] Remove :returns: - We never use it anywhere else in the library, and I think it's quite redundant here --- src/torchjd/autojac/_jac_to_grad.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 59b7b94d..eb953a4e 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -49,9 +49,6 @@ def jac_to_grad( the Jacobians, ``jac_to_grad`` will also return the computed weights. :param retain_jac: Whether to preserve the ``.jac`` fields of the tensors after they have been used. Defaults to ``False``. - :returns: If ``aggregator`` uses a - :class:`Weighting ` returns the weights used - to combine the rows of the Jacobians, otherwise ``None``. .. note:: This function starts by "flattening" the ``.jac`` fields into matrices (i.e. flattening all From d24c56a8135bc0c02e354966d46028f353ae3527 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= <31951177+ValerianRey@users.noreply.github.com> Date: Sat, 21 Feb 2026 01:11:55 +0100 Subject: [PATCH 7/7] fix typo Co-authored-by: Pierre Quinton --- src/torchjd/autojac/_jac_to_grad.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index eb953a4e..e76c9c87 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -76,7 +76,7 @@ def jac_to_grad( >>> param.grad tensor([0.5000, 2.5000]) >>> weights - tensor([0.5], 0.5]) + tensor([0.5, 0.5]) The ``.grad`` field of ``param`` now contains the aggregation (by UPGrad) of the Jacobian of :math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``. In this case, the
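
For reference, below is a minimal end-to-end sketch of the API as it stands at the end of this series, assembled from the doctest in src/torchjd/autojac/_jac_to_grad.py and from tests/doc/test_jac_to_grad.py above. It is an illustrative usage sketch, not part of the patches; the expected values in the comments are the ones asserted by the doc test (gradient close to [0.5, 2.5] and weights close to [0.5, 0.5], within atol=1e-4), and the remark about aggregators that are not WeightedAggregator returning None follows the unit test added in PATCH 3.

import torch

from torchjd.aggregation import UPGrad
from torchjd.autojac import backward, jac_to_grad

param = torch.tensor([1.0, 2.0], requires_grad=True)

# Two losses that are functions of param.
y1 = torch.tensor([-1.0, 1.0]) @ param
y2 = (param**2).sum()

backward([y1, y2])                        # param now has a .jac field
weights = jac_to_grad([param], UPGrad())  # param now has a .grad field

print(param.grad)  # ~ tensor([0.5000, 2.5000]), per the doc test
print(weights)     # ~ tensor([0.5000, 0.5000]); UPGrad combines the Jacobian rows with a
                   # Weighting, so the weights are returned. An aggregator that does not
                   # (i.e. one that is not a WeightedAggregator) makes jac_to_grad return None.

Because the weights are computed once and the gradient is obtained as weights @ jacobian_matrix, callers can log or inspect the per-loss weights without re-running the aggregation.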