47 changes: 40 additions & 7 deletions src/torchjd/autojac/_jac_to_grad.py
@@ -1,28 +1,52 @@
from collections.abc import Iterable
from typing import overload

import torch
from torch import Tensor

from torchjd.aggregation import Aggregator
from torchjd.aggregation._aggregator_bases import WeightedAggregator

from ._accumulation import TensorWithJac, accumulate_grads, is_tensor_with_jac
from ._utils import check_consistent_first_dimension


@overload
def jac_to_grad(
tensors: Iterable[Tensor],
/,
aggregator: WeightedAggregator,
*,
retain_jac: bool = False,
) -> Tensor: ...


@overload
def jac_to_grad(
tensors: Iterable[Tensor],
/,
aggregator: Aggregator, # Not a WeightedAggregator, because overloads are checked in order
*,
retain_jac: bool = False,
) -> None: ...


def jac_to_grad(
tensors: Iterable[Tensor],
/,
aggregator: Aggregator,
*,
retain_jac: bool = False,
) -> None:
) -> Tensor | None:
r"""
Aggregates the Jacobians stored in the ``.jac`` fields of ``tensors`` and accumulates the result
into their ``.grad`` fields.
:param tensors: The tensors whose ``.jac`` fields should be aggregated. All Jacobians must
have the same first dimension (e.g. number of losses).
:param aggregator: The aggregator used to reduce the Jacobians into gradients.
:param aggregator: The aggregator used to reduce the Jacobians into gradients. If it uses a
:class:`Weighting <torchjd.aggregation._weighting_bases.Weighting>` to combine the rows of
the Jacobians, ``jac_to_grad`` will also return the computed weights.
:param retain_jac: Whether to preserve the ``.jac`` fields of the tensors after they have been
used. Defaults to ``False``.
@@ -48,12 +72,15 @@ def jac_to_grad(
>>> y2 = (param ** 2).sum()
>>>
>>> backward([y1, y2]) # param now has a .jac field
>>> jac_to_grad([param], aggregator=UPGrad()) # param now has a .grad field
>>> weights = jac_to_grad([param], UPGrad()) # param now has a .grad field
>>> param.grad
tensor([-1., 1.])
tensor([0.5000, 2.5000])
>>> weights
tensor([0.5, 0.5])
The ``.grad`` field of ``param`` now contains the aggregation (by UPGrad) of the Jacobian of
:math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``.
:math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``. In this case, the
weights used to combine the rows of the Jacobian are equal because there was no conflict.
"""

tensors_ = list[TensorWithJac]()
@@ -66,7 +93,7 @@ def jac_to_grad(
tensors_.append(t)

if len(tensors_) == 0:
return
return None
Contributor Author:

This should return an empty tensor of weights when the aggregator is weighted.

Contributor Author:

Or maybe not. Empty weights would correspond to Jacobians with 0 rows, but here we have 0 tensors.

ValerianRey (Contributor Author), Feb 21, 2026:

We could also raise an error in that case. Worst case would be that some users would have to do

if len(tensors) > 0:
    jac_to_grad(tensors, A)

So that they avoid getting this error in the niche cases where the number of elements of tensors varies between iterations and may be 0.
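For reference, a minimal sketch of the empty-weights alternative discussed in this thread (not part of the diff; it reuses the WeightedAggregator check that already appears later in the function and assumes an empty float tensor would be an acceptable degenerate result):

if len(tensors_) == 0:
    # Hypothetical alternative (not in the PR): when the aggregator is
    # weighting-based, return an empty weights tensor instead of None,
    # since nothing was aggregated.
    if isinstance(aggregator, WeightedAggregator):
        return torch.empty(0)
    return None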


jacobians = [t.jac for t in tensors_]

@@ -76,9 +103,15 @@ def jac_to_grad(
_free_jacs(tensors_)

jacobian_matrix = _unite_jacobians(jacobians)
gradient_vector = aggregator(jacobian_matrix)
if isinstance(aggregator, WeightedAggregator):
weights = aggregator.weighting(jacobian_matrix)
gradient_vector = weights @ jacobian_matrix
else:
weights = None
gradient_vector = aggregator(jacobian_matrix)
gradients = _disunite_gradient(gradient_vector, tensors_)
accumulate_grads(tensors_, gradients)
return weights


def _unite_jacobians(jacobians: list[Tensor]) -> Tensor:
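For context on the new return value, here is the docstring example above as a self-contained script (outputs and the weighting arithmetic are taken directly from the doctest; nothing beyond the API shown in this diff is assumed):

import torch

from torchjd.aggregation import UPGrad
from torchjd.autojac import backward, jac_to_grad

param = torch.tensor([1.0, 2.0], requires_grad=True)
y1 = torch.tensor([-1.0, 1.0]) @ param  # dy1/dparam = [-1, 1]
y2 = (param ** 2).sum()                 # dy2/dparam = 2 * param = [2, 4]

backward([y1, y2])                        # param now has a .jac field
weights = jac_to_grad([param], UPGrad())  # param now has a .grad field

# With weights [0.5, 0.5]: 0.5 * [-1, 1] + 0.5 * [2, 4] = [0.5, 2.5]
print(param.grad)  # tensor([0.5000, 2.5000])
print(weights)     # tensor([0.5, 0.5])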
7 changes: 5 additions & 2 deletions tests/doc/test_jac_to_grad.py
@@ -3,20 +3,23 @@
the obtained `.grad` field.
"""

from torch.testing import assert_close
from utils.asserts import assert_grad_close

from torchjd.aggregation import UPGrad
Contributor Author:

I accidentally moved this out of the function.



def test_jac_to_grad() -> None:
import torch

from torchjd.aggregation import UPGrad
from torchjd.autojac import backward, jac_to_grad

param = torch.tensor([1.0, 2.0], requires_grad=True)
# Compute arbitrary quantities that are functions of param
y1 = torch.tensor([-1.0, 1.0]) @ param
y2 = (param**2).sum()
backward([y1, y2]) # param now has a .jac field
jac_to_grad([param], aggregator=UPGrad()) # param now has a .grad field
weights = jac_to_grad([param], UPGrad()) # param now has a .grad field

assert_grad_close(param, torch.tensor([0.5000, 2.5000]), rtol=0.0, atol=1e-04)
assert_close(weights, torch.tensor([0.5, 0.5]), rtol=0.0, atol=0.0)
28 changes: 23 additions & 5 deletions tests/unit/autojac/test_jac_to_grad.py
@@ -1,14 +1,25 @@
from pytest import mark, raises
from torch.testing import assert_close
from utils.asserts import assert_grad_close, assert_has_jac, assert_has_no_jac
from utils.tensors import tensor_

from torchjd.aggregation import Aggregator, Mean, PCGrad, UPGrad
from torchjd.aggregation import (
Aggregator,
ConFIG,
Mean,
PCGrad,
UPGrad,
)
from torchjd.aggregation._aggregator_bases import WeightedAggregator
from torchjd.autojac._jac_to_grad import jac_to_grad


@mark.parametrize("aggregator", [Mean(), UPGrad(), PCGrad()])
@mark.parametrize("aggregator", [Mean(), UPGrad(), PCGrad(), ConFIG()])
def test_various_aggregators(aggregator: Aggregator) -> None:
"""Tests that jac_to_grad works for various aggregators."""
"""
Tests that jac_to_grad works for various aggregators. For those that are weighted, the weights
should also be returned. For the others, None should be returned.
"""

t1 = tensor_(1.0, requires_grad=True)
t2 = tensor_([2.0, 3.0], requires_grad=True)
@@ -19,11 +30,18 @@ def test_various_aggregators(aggregator: Aggregator) -> None:
g1 = expected_grad[0]
g2 = expected_grad[1:]

jac_to_grad([t1, t2], aggregator)
optional_weights = jac_to_grad([t1, t2], aggregator)

assert_grad_close(t1, g1)
assert_grad_close(t2, g2)

if isinstance(aggregator, WeightedAggregator):
assert optional_weights is not None
expected_weights = aggregator.weighting(jac)
assert_close(optional_weights, expected_weights)
else:
assert optional_weights is None


def test_single_tensor() -> None:
"""Tests that jac_to_grad works when a single tensor is provided."""
@@ -82,7 +100,7 @@ def test_row_mismatch() -> None:
def test_no_tensors() -> None:
"""Tests that jac_to_grad correctly does nothing when an empty list of tensors is provided."""

jac_to_grad([], aggregator=UPGrad())
jac_to_grad([], UPGrad())


@mark.parametrize("retain_jac", [True, False])
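As a closing note on the API change, a sketch of how generic calling code might handle both cases at runtime (the helper name and the logging are illustrative only, not part of torchjd):

from collections.abc import Iterable

from torch import Tensor

from torchjd.aggregation import Aggregator
from torchjd.autojac import jac_to_grad


def aggregate_and_log(params: Iterable[Tensor], aggregator: Aggregator) -> None:
    # Mirrors the check in test_various_aggregators: weighted aggregators return
    # the weights they used, other aggregators return None.
    optional_weights = jac_to_grad(params, aggregator)
    if optional_weights is not None:
        print(f"aggregation weights: {optional_weights.tolist()}")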