diff --git a/CHANGELOG.md b/CHANGELOG.md index 1d1f0147..ee74f58f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -68,10 +68,8 @@ changelog does not include internal changes that do not affect the user. - `GeneralizedWeighting.__call__`: The `generalized_gramian` parameter is now positional-only. Suggested change: `generalized_weighting(generalized_gramian=generalized_gramian)` => `generalized_weighting(generalized_gramian)`. -- Removed an unnecessary memory duplication. This should significantly improve the memory efficiency - of `autojac`. -- Removed an unnecessary internal cloning of gradient. This should slightly improve the memory - efficiency of `autojac`. +- Removed several unnecessary memory duplications. This should significantly improve the memory + efficiency and speed of `autojac`. - Increased the lower bounds of the torch (from 2.0.0 to 2.3.0) and numpy (from 1.21.0 to 1.21.2) dependencies to reflect what really works with torchjd. We now also run torchjd's tests with the dependency lower-bounds specified in `pyproject.toml`, so we should now always accurately diff --git a/src/torchjd/_linalg/_gramian.py b/src/torchjd/_linalg/_gramian.py index 58a2af82..7eb9acd3 100644 --- a/src/torchjd/_linalg/_gramian.py +++ b/src/torchjd/_linalg/_gramian.py @@ -35,11 +35,20 @@ def compute_gramian(t: Tensor, contracted_dims: int = -1) -> PSDTensor: first dimension). """ - contracted_dims = contracted_dims if contracted_dims >= 0 else contracted_dims + t.ndim - indices_source = list(range(t.ndim - contracted_dims)) - indices_dest = list(range(t.ndim - 1, contracted_dims - 1, -1)) - transposed = t.movedim(indices_source, indices_dest) - gramian = torch.tensordot(t, transposed, dims=contracted_dims) + # Optimization: it's faster to do that than moving dims and using tensordot, and this case + # happens very often, sometimes hundreds of times for a single jac_to_grad. + if contracted_dims == -1: + matrix = t.unsqueeze(1) if t.ndim == 1 else t.flatten(start_dim=1) + + gramian = matrix @ matrix.T + + else: + contracted_dims = contracted_dims if contracted_dims >= 0 else contracted_dims + t.ndim + indices_source = list(range(t.ndim - contracted_dims)) + indices_dest = list(range(t.ndim - 1, contracted_dims - 1, -1)) + transposed = t.movedim(indices_source, indices_dest) + gramian = torch.tensordot(t, transposed, dims=contracted_dims) + return cast(PSDTensor, gramian) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 56eadde9..1aee6ee1 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -1,12 +1,13 @@ +from collections import deque from collections.abc import Iterable -from typing import overload +from typing import TypeGuard, cast, overload import torch -from torch import Tensor +from torch import Tensor, nn -from torchjd._linalg import Matrix +from torchjd._linalg import Matrix, PSDMatrix, compute_gramian from torchjd.aggregation import Aggregator, Weighting -from torchjd.aggregation._aggregator_bases import WeightedAggregator +from torchjd.aggregation._aggregator_bases import GramianWeightedAggregator, WeightedAggregator from ._accumulation import TensorWithJac, accumulate_grads, is_tensor_with_jac from ._utils import check_consistent_first_dimension @@ -16,7 +17,18 @@ def jac_to_grad( tensors: Iterable[Tensor], /, - aggregator: WeightedAggregator, + aggregator: GramianWeightedAggregator, + *, + retain_jac: bool = False, + optimize_gramian_computation: bool = False, +) -> Tensor: ... + + +@overload +def jac_to_grad( + tensors: Iterable[Tensor], + /, + aggregator: WeightedAggregator, # Not a GramianWA, because overloads are checked in order *, retain_jac: bool = False, ) -> Tensor: ... @@ -38,6 +50,7 @@ def jac_to_grad( aggregator: Aggregator, *, retain_jac: bool = False, + optimize_gramian_computation: bool = False, ) -> Tensor | None: r""" Aggregates the Jacobians stored in the ``.jac`` fields of ``tensors`` and accumulates the result @@ -50,6 +63,11 @@ def jac_to_grad( the Jacobians, ``jac_to_grad`` will also return the computed weights. :param retain_jac: Whether to preserve the ``.jac`` fields of the tensors after they have been used. Defaults to ``False``. + :param optimize_gramian_computation: When the ``aggregator`` computes weights based on the + Gramian of the Jacobian, it's possible to skip the concatenation of the Jacobians and to + instead compute the Gramian as the sum of the Gramians of the individual Jacobians. This + saves memory (up to 50% memory saving) but can be slightly slower (up to 15%) on CUDA. We + advise to try this optimization if memory is an issue for you. Defaults to ``False``. .. note:: This function starts by "flattening" the ``.jac`` fields into matrices (i.e. flattening all @@ -96,13 +114,46 @@ def jac_to_grad( if len(tensors_) == 0: raise ValueError("The `tensors` parameter cannot be empty.") - jacobians = [t.jac for t in tensors_] - + jacobians = deque(t.jac for t in tensors_) check_consistent_first_dimension(jacobians, "tensors.jac") if not retain_jac: _free_jacs(tensors_) + if optimize_gramian_computation: + if not _can_skip_jacobian_combination(aggregator): + raise ValueError( + "In order to use `jac_to_grad` with `optimize_gramian_computation=True`, you must " + "provide a `GramianWeightedAggregator` that doesn't have any forward hooks attached" + " to it." + ) + + gradients, weights = _gramian_based(aggregator, jacobians) + else: + gradients, weights = _jacobian_based(aggregator, jacobians, tensors_) + accumulate_grads(tensors_, gradients) + + return weights + + +def _can_skip_jacobian_combination(aggregator: Aggregator) -> TypeGuard[GramianWeightedAggregator]: + return ( + isinstance(aggregator, GramianWeightedAggregator) + and not _has_forward_hook(aggregator) + and not _has_forward_hook(aggregator.weighting) + ) + + +def _has_forward_hook(module: nn.Module) -> bool: + """Return whether the module has any forward hook registered.""" + return len(module._forward_hooks) > 0 or len(module._forward_pre_hooks) > 0 + + +def _jacobian_based( + aggregator: Aggregator, + jacobians: deque[Tensor], + tensors: list[TensorWithJac], +) -> tuple[list[Tensor], Tensor | None]: jacobian_matrix = _unite_jacobians(jacobians) weights: Tensor | None = None @@ -124,13 +175,36 @@ def capture_hook(_m: Weighting[Matrix], _i: tuple[Tensor], output: Tensor) -> No handle.remove() else: gradient_vector = aggregator(jacobian_matrix) - gradients = _disunite_gradient(gradient_vector, tensors_) - accumulate_grads(tensors_, gradients) - return weights + gradients = _disunite_gradient(gradient_vector, tensors) + return gradients, weights + + +def _gramian_based( + aggregator: GramianWeightedAggregator, + jacobians: deque[Tensor], +) -> tuple[list[Tensor], Tensor]: + weighting = aggregator.gramian_weighting + gramian = _compute_gramian_sum(jacobians) + weights = weighting(gramian) + + gradients = list[Tensor]() + while jacobians: + jacobian = jacobians.popleft() # get jacobian + dereference it to free memory asap + gradients.append(torch.tensordot(weights, jacobian, dims=1)) + + return gradients, weights + + +def _compute_gramian_sum(jacobians: deque[Tensor]) -> PSDMatrix: + gramian = sum([compute_gramian(matrix) for matrix in jacobians]) + return cast(PSDMatrix, gramian) -def _unite_jacobians(jacobians: list[Tensor]) -> Tensor: - jacobian_matrices = [jacobian.reshape(jacobian.shape[0], -1) for jacobian in jacobians] +def _unite_jacobians(jacobians: deque[Tensor]) -> Tensor: + jacobian_matrices = list[Tensor]() + while jacobians: + jacobian = jacobians.popleft() # get jacobian + dereference it to free memory asap + jacobian_matrices.append(jacobian.reshape(jacobian.shape[0], -1)) jacobian_matrix = torch.concat(jacobian_matrices, dim=1) return jacobian_matrix diff --git a/src/torchjd/autojac/_utils.py b/src/torchjd/autojac/_utils.py index d3285559..fdcc7ce1 100644 --- a/src/torchjd/autojac/_utils.py +++ b/src/torchjd/autojac/_utils.py @@ -113,8 +113,9 @@ def check_consistent_first_dimension( :param jacobians: Sequence of Jacobian tensors to validate. :param variable_name: Name of the variable to include in the error message. """ + if len(jacobians) > 0 and not all( - jacobian.shape[0] == jacobians[0].shape[0] for jacobian in jacobians[1:] + jacobian.shape[0] == jacobians[0].shape[0] for jacobian in jacobians ): raise ValueError(f"All Jacobians in `{variable_name}` should have the same number of rows.") diff --git a/tests/unit/autojac/test_jac_to_grad.py b/tests/unit/autojac/test_jac_to_grad.py index 9731c4af..2db0b7f3 100644 --- a/tests/unit/autojac/test_jac_to_grad.py +++ b/tests/unit/autojac/test_jac_to_grad.py @@ -7,18 +7,35 @@ from utils.tensors import tensor_ from torchjd.aggregation import ( + IMTLG, + MGDA, Aggregator, + AlignedMTL, ConFIG, + Constant, + DualProj, + GradDrop, + Krum, Mean, PCGrad, + Random, + Sum, + TrimmedMean, UPGrad, ) -from torchjd.aggregation._aggregator_bases import WeightedAggregator -from torchjd.autojac._jac_to_grad import jac_to_grad +from torchjd.aggregation._aggregator_bases import GramianWeightedAggregator, WeightedAggregator +from torchjd.autojac._jac_to_grad import ( + _can_skip_jacobian_combination, + _has_forward_hook, + jac_to_grad, +) -@mark.parametrize("aggregator", [Mean(), UPGrad(), PCGrad(), ConFIG()]) -def test_various_aggregators(aggregator: Aggregator) -> None: +@mark.parametrize( + ["aggregator", "optimize"], + [(Mean(), False), (UPGrad(), True), (UPGrad(), False), (PCGrad(), True), (ConFIG(), False)], +) +def test_various_aggregators(aggregator: Aggregator, optimize: bool) -> None: """ Tests that jac_to_grad works for various aggregators. For those that are weighted, the weights should also be returned. For the others, None should be returned. @@ -33,7 +50,11 @@ def test_various_aggregators(aggregator: Aggregator) -> None: g1 = expected_grad[0] g2 = expected_grad[1:] - optional_weights = jac_to_grad([t1, t2], aggregator) + if optimize: + assert isinstance(aggregator, GramianWeightedAggregator) + optional_weights = jac_to_grad([t1, t2], aggregator, optimize_gramian_computation=True) + else: + optional_weights = jac_to_grad([t1, t2], aggregator) assert_grad_close(t1, g1) assert_grad_close(t2, g2) @@ -125,6 +146,110 @@ def test_jacs_are_freed(retain_jac: bool) -> None: check(t2) +def test_has_forward_hook() -> None: + """Tests that _has_forward_hook correctly detects the presence of forward hooks.""" + + module = UPGrad() + + def dummy_forward_hook(_module, _input, _output) -> Tensor: + return _output + + def dummy_forward_pre_hook(_module, _input) -> Tensor: + return _input + + def dummy_backward_hook(_module, _grad_input, _grad_output) -> Tensor: + return _grad_input + + def dummy_backward_pre_hook(_module, _grad_output) -> Tensor: + return _grad_output + + # Module with no hooks or backward hooks only should return False + assert not _has_forward_hook(module) + module.register_full_backward_hook(dummy_backward_hook) + assert not _has_forward_hook(module) + module.register_full_backward_pre_hook(dummy_backward_pre_hook) + assert not _has_forward_hook(module) + + # Module with forward hook should return True + handle1 = module.register_forward_hook(dummy_forward_hook) + assert _has_forward_hook(module) + handle2 = module.register_forward_hook(dummy_forward_hook) + assert _has_forward_hook(module) + handle1.remove() + assert _has_forward_hook(module) + handle2.remove() + assert not _has_forward_hook(module) + + # Module with forward pre-hook should return True + handle3 = module.register_forward_pre_hook(dummy_forward_pre_hook) + assert _has_forward_hook(module) + handle4 = module.register_forward_pre_hook(dummy_forward_pre_hook) + assert _has_forward_hook(module) + handle3.remove() + assert _has_forward_hook(module) + handle4.remove() + assert not _has_forward_hook(module) + + +_PARAMETRIZATIONS = [ + (AlignedMTL(), True), + (DualProj(), True), + (IMTLG(), True), + (Krum(n_byzantine=1), True), + (MGDA(), True), + (PCGrad(), True), + (UPGrad(), True), + (ConFIG(), False), + (Constant(tensor_([0.5, 0.5])), False), + (GradDrop(), False), + (Mean(), False), + (Random(), False), + (Sum(), False), + (TrimmedMean(trim_number=1), False), +] + +try: + from torchjd.aggregation import CAGrad + + _PARAMETRIZATIONS.append((CAGrad(c=0.5), True)) +except ImportError: + pass + +try: + from torchjd.aggregation import NashMTL + + _PARAMETRIZATIONS.append((NashMTL(n_tasks=2), False)) +except ImportError: + pass + + +@mark.parametrize("aggregator, expected", _PARAMETRIZATIONS) +def test_can_skip_jacobian_combination(aggregator: Aggregator, expected: bool) -> None: + """ + Tests that _can_skip_jacobian_combination correctly identifies when optimization can be used. + """ + + assert _can_skip_jacobian_combination(aggregator) == expected + handle = aggregator.register_forward_hook(lambda _module, _input, output: output) + assert not _can_skip_jacobian_combination(aggregator) + handle.remove() + assert _can_skip_jacobian_combination(aggregator) == expected + handle = aggregator.register_forward_pre_hook(lambda _module, input: input) + assert not _can_skip_jacobian_combination(aggregator) + handle.remove() + assert _can_skip_jacobian_combination(aggregator) == expected + + if isinstance(aggregator, GramianWeightedAggregator): + handle = aggregator.weighting.register_forward_hook(lambda _module, _input, output: output) + assert not _can_skip_jacobian_combination(aggregator) + handle.remove() + assert _can_skip_jacobian_combination(aggregator) == expected + handle = aggregator.weighting.register_forward_pre_hook(lambda _module, input: input) + assert not _can_skip_jacobian_combination(aggregator) + handle.remove() + assert _can_skip_jacobian_combination(aggregator) == expected + + def test_noncontiguous_jac() -> None: """Tests that jac_to_grad works when the .jac field is non-contiguous.""" @@ -185,3 +310,20 @@ def hook_inner(_module: Any, _input: Any, weights: Tensor) -> Tensor: weights = jac_to_grad([t], aggregator) assert_close(weights, aggregator.weighting(jac)) + + +def test_optimize_gramian_computation_error() -> None: + """ + Tests that using optimize_gramian_computation on an incompatible aggregator raises an error. + """ + + aggregator = ConFIG() + + t1 = tensor_(1.0, requires_grad=True) + t2 = tensor_([2.0, 3.0], requires_grad=True) + jac = tensor_([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]]) + t1.__setattr__("jac", jac[:, 0]) + t2.__setattr__("jac", jac[:, 1:]) + + with raises(ValueError): + jac_to_grad([t1, t2], aggregator, optimize_gramian_computation=True) # ty:ignore[invalid-argument-type]