72 changes: 64 additions & 8 deletions src/torchjd/autojac/_jac_to_grad.py
@@ -1,30 +1,57 @@
from collections.abc import Iterable
from typing import overload

import torch
from torch import Tensor

from torchjd.aggregation import Aggregator
from torchjd.aggregation import Aggregator, Weighting

from ._accumulation import TensorWithJac, accumulate_grads, is_tensor_with_jac
from ._utils import check_consistent_first_dimension


@overload
def jac_to_grad(
tensors: Iterable[Tensor],
/,
aggregator: Aggregator,
/,
*,
retain_jac: bool = False,
) -> None: ...


@overload
def jac_to_grad(
tensors: Iterable[Tensor],
weighting: Weighting,
/,
*,
retain_jac: bool = False,
) -> None:
) -> Tensor: ...


def jac_to_grad(
tensors: Iterable[Tensor],
method: Aggregator | Weighting,
/,
*,
retain_jac: bool = False,
) -> Tensor | None:
r"""
Aggregates the Jacobians stored in the ``.jac`` fields of ``tensors`` and accumulates the result
into their ``.grad`` fields.

:param tensors: The tensors whose ``.jac`` fields should be aggregated. All Jacobians must
have the same first dimension (e.g. number of losses).
:param aggregator: The aggregator used to reduce the Jacobians into gradients.
:param method: The method used to reduce the Jacobians into gradients. Can be an
:class:`Aggregator <torchjd.aggregation._aggregator_bases.Aggregator>` or a
:class:`Weighting <torchjd.aggregation._aggregator_bases.Weighting>`, in which case
``jac_to_grad`` also returns the weights used to aggregate the Jacobians.
:param retain_jac: Whether to preserve the ``.jac`` fields of the tensors after they have been
used. Defaults to ``False``.
:returns: If ``method`` is a
:class:`Weighting <torchjd.aggregation._aggregator_bases.Weighting>`, returns the weights
used to aggregate the Jacobians, otherwise ``None``.

.. note::
This function starts by "flattening" the ``.jac`` fields into matrices (i.e. flattening all
@@ -48,9 +75,32 @@ def jac_to_grad(
>>> y2 = (param ** 2).sum()
>>>
>>> backward([y1, y2]) # param now has a .jac field
>>> jac_to_grad([param], aggregator=UPGrad()) # param now has a .grad field
>>> jac_to_grad([param], UPGrad()) # param now has a .grad field
>>> param.grad
tensor([-1., 1.])
tensor([0.5000, 2.5000])

The ``.grad`` field of ``param`` now contains the aggregation (by UPGrad) of the Jacobian of
:math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``.

.. admonition::
Example

This is the same example as before, except that we also obtain the weights used to aggregate the Jacobian.

>>> import torch
>>>
>>> from torchjd.aggregation import UPGradWeighting
>>> from torchjd.autojac import backward, jac_to_grad
>>>
>>> param = torch.tensor([1., 2.], requires_grad=True)
>>> # Compute arbitrary quantities that are functions of param
>>> y1 = torch.tensor([-1., 1.]) @ param
>>> y2 = (param ** 2).sum()
>>>
>>> backward([y1, y2])
>>> weights = jac_to_grad([param], UPGradWeighting())
>>> weights
tensor([0.5000, 0.5000])

The ``.grad`` field of ``param`` again contains the aggregation (by UPGrad) of the Jacobian of
:math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``, and ``weights`` contains
the weights used for this aggregation.
@@ -66,7 +116,7 @@ def jac_to_grad(
tensors_.append(t)

if len(tensors_) == 0:
return
return None

jacobians = [t.jac for t in tensors_]

@@ -76,9 +126,15 @@ def jac_to_grad(
_free_jacs(tensors_)

jacobian_matrix = _unite_jacobians(jacobians)
gradient_vector = aggregator(jacobian_matrix)
if isinstance(method, Weighting):
weights = method(jacobian_matrix)
gradient_vector = weights @ jacobian_matrix
else:
weights = None
gradient_vector = method(jacobian_matrix)
gradients = _disunite_gradient(gradient_vector, tensors_)
accumulate_grads(tensors_, gradients)
return weights


def _unite_jacobians(jacobians: list[Tensor]) -> Tensor:
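To make the new ``Weighting`` branch above concrete: applying the returned weights to the united Jacobian matrix should reproduce what the corresponding ``Aggregator`` computes, which is exactly the ``gradient_vector = weights @ jacobian_matrix`` line. A minimal editorial sketch, not part of this diff, assuming ``UPGrad`` is the aggregator built around ``UPGradWeighting`` (as the naming suggests) and reusing the Jacobian from the unit tests below:

import torch

from torchjd.aggregation import UPGrad, UPGradWeighting

# United Jacobian matrix with two conflicting rows (same values as in the unit tests).
J = torch.tensor([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]])

weights = UPGradWeighting()(J)      # what jac_to_grad now returns when given a Weighting
grad_from_weights = weights @ J     # what jac_to_grad accumulates into the .grad fields
grad_from_aggregator = UPGrad()(J)  # the pre-existing Aggregator path

assert torch.allclose(grad_from_weights, grad_from_aggregator)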
19 changes: 18 additions & 1 deletion tests/doc/test_jac_to_grad.py
@@ -3,6 +3,7 @@
the obtained `.grad` field.
"""

from torch.testing import assert_close
from utils.asserts import assert_grad_close


@@ -17,6 +18,22 @@ def test_jac_to_grad():
y1 = torch.tensor([-1.0, 1.0]) @ param
y2 = (param**2).sum()
backward([y1, y2]) # param now has a .jac field
jac_to_grad([param], aggregator=UPGrad()) # param now has a .grad field
jac_to_grad([param], UPGrad()) # param now has a .grad field

assert_grad_close(param, torch.tensor([0.5000, 2.5000]), rtol=0.0, atol=1e-04)


def test_jac_to_grad2():
import torch

from torchjd.aggregation import UPGradWeighting
from torchjd.autojac import backward, jac_to_grad

param = torch.tensor([1.0, 2.0], requires_grad=True)
# Compute arbitrary quantities that are functions of param
y1 = torch.tensor([-1.0, 1.0]) @ param
y2 = (param**2).sum()

backward([y1, y2])
weights = jac_to_grad([param], UPGradWeighting())
assert_close(weights, torch.tensor([0.5000, 0.5000]), rtol=0.0, atol=1e-04)
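A quick hand check of the expected values above (editorial, not part of this diff): with ``param = [1., 2.]``, the Jacobian of ``[y1, y2]`` has rows ``[-1., 1.]`` and ``[2., 4.]``. Their dot product is positive, so the rows do not conflict, UPGrad's projections are identities, and its weighting falls back to the plain mean of the rows:

import torch

# Jacobian of [y1, y2] with respect to param = [1., 2.].
jac = torch.tensor([[-1.0, 1.0], [2.0, 4.0]])

# Non-conflicting rows: the UPGrad weighting reduces to uniform averaging.
weights = torch.tensor([0.5, 0.5])
print(weights @ jac)  # tensor([0.5000, 2.5000]), matching the asserted .grad above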
35 changes: 33 additions & 2 deletions tests/unit/autojac/test_jac_to_grad.py
@@ -1,8 +1,18 @@
from pytest import mark, raises
from torch.testing import assert_close
from utils.asserts import assert_grad_close, assert_has_jac, assert_has_no_jac
from utils.tensors import tensor_

from torchjd.aggregation import Aggregator, Mean, PCGrad, UPGrad
from torchjd.aggregation import (
Aggregator,
Mean,
MeanWeighting,
PCGrad,
PCGradWeighting,
UPGrad,
UPGradWeighting,
Weighting,
)
from torchjd.autojac._jac_to_grad import jac_to_grad


@@ -25,6 +35,27 @@ def test_various_aggregators(aggregator: Aggregator):
assert_grad_close(t2, g2)


@mark.parametrize("weighting", [MeanWeighting(), UPGradWeighting(), PCGradWeighting()])
def test_various_weightings(weighting: Weighting):
"""Tests that jac_to_grad works for various aggregators."""

t1 = tensor_(1.0, requires_grad=True)
t2 = tensor_([2.0, 3.0], requires_grad=True)
jac = tensor_([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]])
t1.__setattr__("jac", jac[:, 0])
t2.__setattr__("jac", jac[:, 1:])
expected_weights = weighting(jac)
expected_grad = expected_weights @ jac
g1 = expected_grad[0]
g2 = expected_grad[1:]

weights = jac_to_grad([t1, t2], weighting)

assert_close(weights, expected_weights)
assert_grad_close(t1, g1)
assert_grad_close(t2, g2)


def test_single_tensor():
"""Tests that jac_to_grad works when a single tensor is provided."""

@@ -82,7 +113,7 @@ def test_row_mismatch():
def test_no_tensors():
"""Tests that jac_to_grad correctly does nothing when an empty list of tensors is provided."""

jac_to_grad([], aggregator=UPGrad())
jac_to_grad([], UPGrad())


@mark.parametrize("retain_jac", [True, False])
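Finally, a usage sketch of how ``retain_jac`` combines with the new return value (editorial, not part of this diff; it assumes that clearing ``.grad`` to ``None`` before the second call simply lets the accumulation re-initialize it): keeping the ``.jac`` fields lets the same Jacobians be reduced twice, once through an ``Aggregator`` and once through the matching ``Weighting``.

import torch

from torchjd.aggregation import UPGrad, UPGradWeighting
from torchjd.autojac import backward, jac_to_grad

param = torch.tensor([1.0, 2.0], requires_grad=True)
y1 = torch.tensor([-1.0, 1.0]) @ param
y2 = (param**2).sum()
backward([y1, y2])  # param now has a .jac field

jac_to_grad([param], UPGrad(), retain_jac=True)  # .jac is preserved
grad_from_aggregator = param.grad.clone()

param.grad = None  # clear .grad so the second call does not accumulate on top of it
weights = jac_to_grad([param], UPGradWeighting())  # .jac is freed here

assert torch.allclose(param.grad, grad_from_aggregator)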