diff --git a/docs/source/docs/aggregation/gradvac.rst b/docs/source/docs/aggregation/gradvac.rst
new file mode 100644
index 00000000..5fd2f7bf
--- /dev/null
+++ b/docs/source/docs/aggregation/gradvac.rst
@@ -0,0 +1,22 @@
+:hide-toc:
+
+GradVac
+=======
+
+.. autodata:: torchjd.aggregation.DEFAULT_GRADVAC_EPS
+
+The constructor argument ``group_type`` (default ``0``) sets the **parameter granularity** of the
+per-block cosine statistics in GradVac:
+
+* ``0`` — **whole model** (``whole_model``): each task-gradient row forms a single block. Omit
+  ``encoder`` and ``shared_params``.
+* ``1`` — **all layer** (``all_layer``): one block per leaf submodule with parameters under
+  ``encoder`` (same traversal as ``encoder.modules()`` in the reference formulation).
+* ``2`` — **all matrix** (``all_matrix``): one block per tensor in ``shared_params``, in order. Use
+  the same tensors as for the shared-parameter Jacobian columns (e.g. the parameters you would pass
+  to a shared-gradient helper).
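+
+A minimal construction sketch (illustrative only; the small ``nn.Sequential`` encoder is a
+placeholder, not part of the API):
+
+.. code-block:: python
+
+    import torch.nn as nn
+
+    from torchjd.aggregation import GradVac
+
+    encoder = nn.Sequential(nn.Linear(2, 3), nn.ReLU(), nn.Linear(3, 1))
+
+    # group_type=0: the whole gradient row is one block.
+    a0 = GradVac()
+
+    # group_type=1: one block per parameterized leaf module of ``encoder``,
+    # here Linear(2, 3) -> 9 elements and Linear(3, 1) -> 4 elements.
+    a1 = GradVac(group_type=1, encoder=encoder)
+
+    # group_type=2: one block per tensor, here (6, 3, 3, 1) elements in
+    # ``encoder.parameters()`` order (two weights and two biases).
+    a2 = GradVac(group_type=2, shared_params=list(encoder.parameters()))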
+
+.. autoclass:: torchjd.aggregation.GradVac
+   :members:
+   :undoc-members:
+   :exclude-members: forward
diff --git a/docs/source/docs/aggregation/index.rst b/docs/source/docs/aggregation/index.rst
index c15d5980..64ba6f63 100644
--- a/docs/source/docs/aggregation/index.rst
+++ b/docs/source/docs/aggregation/index.rst
@@ -35,6 +35,7 @@ Abstract base classes
     dualproj.rst
     flattening.rst
     graddrop.rst
+    gradvac.rst
     imtl_g.rst
     krum.rst
     mean.rst
diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py
index 9eed9bf7..d6c39602 100644
--- a/src/torchjd/aggregation/__init__.py
+++ b/src/torchjd/aggregation/__init__.py
@@ -66,6 +66,7 @@
 from ._dualproj import DualProj, DualProjWeighting
 from ._flattening import Flattening
 from ._graddrop import GradDrop
+from ._gradvac import DEFAULT_GRADVAC_EPS, GradVac
 from ._imtl_g import IMTLG, IMTLGWeighting
 from ._krum import Krum, KrumWeighting
 from ._mean import Mean, MeanWeighting
@@ -87,11 +88,13 @@
     "ConFIG",
     "Constant",
     "ConstantWeighting",
+    "DEFAULT_GRADVAC_EPS",
     "DualProj",
     "DualProjWeighting",
     "Flattening",
     "GeneralizedWeighting",
     "GradDrop",
+    "GradVac",
     "IMTLG",
     "IMTLGWeighting",
     "Krum",
diff --git a/src/torchjd/aggregation/_gradvac.py b/src/torchjd/aggregation/_gradvac.py
new file mode 100644
index 00000000..3d8f0c0f
--- /dev/null
+++ b/src/torchjd/aggregation/_gradvac.py
@@ -0,0 +1,251 @@
+from __future__ import annotations
+
+from collections.abc import Iterable
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+from torchjd._linalg import Matrix
+
+from ._aggregator_bases import Aggregator
+from ._utils.non_differentiable import raise_non_differentiable_error
+
+#: Default small constant added to denominators for numerical stability.
+DEFAULT_GRADVAC_EPS = 1e-8
+
+
+def _gradvac_all_layer_group_sizes(encoder: nn.Module) -> tuple[int, ...]:
+    """
+    Block sizes per leaf submodule with parameters, matching the ``all_layer`` grouping: iterate
+    ``encoder.modules()`` and, for each module that has no child submodule and registers at least
+    one parameter, append the total number of its parameter elements.
+    """
+
+    return tuple(
+        sum(w.numel() for w in module.parameters())
+        for module in encoder.modules()
+        if next(module.children(), None) is None
+        and next(module.parameters(recurse=False), None) is not None
+    )
+
+
+def _gradvac_all_matrix_group_sizes(shared_params: Iterable[Tensor]) -> tuple[int, ...]:
+    """One block per tensor in ``shared_params`` order (``all_matrix`` / shared-parameter layout)."""
+
+    return tuple(p.numel() for p in shared_params)
+
+
+class GradVac(Aggregator):
+    r"""
+    :class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing Gradient Vaccine
+    (GradVac) from `Gradient Vaccine: Investigating and Improving Multi-task Optimization in
+    Massively Multilingual Models (ICLR 2021 Spotlight)
+    <https://arxiv.org/abs/2010.05874>`_.
+
+    The input matrix is a Jacobian :math:`G \in \mathbb{R}^{M \times D}` whose rows are per-task
+    gradients. Each task :math:`i` is processed in turn, and the other tasks :math:`j` are visited
+    in a random order. For each parameter block :math:`k`, the cosine similarity :math:`\rho_{ijk}`
+    between the (possibly already modified) gradient of task :math:`i` and the original gradient of
+    task :math:`j` on that block is compared to an EMA target :math:`\bar{\rho}_{ijk}`. When
+    :math:`\rho_{ijk} < \bar{\rho}_{ijk}`, a closed-form correction adds a scaled copy of
+    :math:`g_j` to the corresponding block of :math:`g_i^{(\mathrm{PC})}`. The EMA is then updated
+    with :math:`\bar{\rho}_{ijk} \leftarrow (1-\beta)\bar{\rho}_{ijk} + \beta \rho_{ijk}`. The
+    aggregated vector is the sum of the modified rows.
+
+    This aggregator is stateful: it keeps :math:`\bar{\rho}` across calls. The state is
+    re-initialized automatically whenever the number of tasks, the parameter dimension, the
+    grouping, the device, or the dtype changes; call :meth:`reset` to clear it manually, e.g.
+    between training runs.
+
+    **Parameter granularity** is selected by ``group_type`` (integer, default ``0``). It defines
+    how each task gradient row is partitioned into blocks :math:`k`, so that cosines and EMA
+    targets :math:`\bar{\rho}_{ijk}` are computed **per block** rather than only globally:
+
+    * ``0`` — **whole model** (``whole_model``): the full row of length :math:`D` is a single
+      block. Cosine similarity is taken between entire task gradients. Do not pass ``encoder`` or
+      ``shared_params``.
+    * ``1`` — **all layer** (``all_layer``): one block per leaf ``nn.Module`` under ``encoder``
+      that holds parameters (same rule as iterating ``encoder.modules()`` and selecting leaves
+      with parameters). Pass ``encoder``; ``shared_params`` must be omitted.
+    * ``2`` — **all matrix** (``all_matrix``): one block per tensor in ``shared_params``, in
+      iteration order. That order must match how the Jacobian columns are laid out for those
+      shared parameters. Pass ``shared_params``; ``encoder`` must be omitted.
+
+    :param beta: EMA update rate :math:`\beta` for :math:`\bar{\rho}` (paper default ``0.5``).
+    :param group_type: Granularity of parameter grouping; see **Parameter granularity** above.
+    :param encoder: Module whose subtree defines the ``all_layer`` blocks when ``group_type == 1``.
+    :param shared_params: Iterable of parameter tensors defining the ``all_matrix`` block sizes and
+        order when ``group_type == 2``. It is materialized once at construction.
+    :param eps: Small positive constant added to denominators when computing cosines and the
+        vaccine weight (default :data:`~torchjd.aggregation.DEFAULT_GRADVAC_EPS`). You may read or
+        assign the :attr:`eps` attribute between steps to tune numerical behavior.
+
+    .. note::
+        GradVac is not compatible with autogram: it needs full Jacobian rows and per-block inner
+        products, not only a Gram matrix. Only the autojac path is supported.
+
+    .. note::
+        Task-order shuffling uses the global PyTorch RNG (``torch.randperm``). Seed it with
+        ``torch.manual_seed`` if you need reproducibility.
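+
+    .. admonition::
+        Example
+
+        Use GradVac to aggregate a matrix. In this minimal sketch the two rows are orthogonal and
+        the EMA targets start at zero, so the first step applies no correction and returns the
+        plain row sum.
+
+        >>> from torch import tensor
+        >>> from torchjd.aggregation import GradVac
+        >>>
+        >>> A = GradVac()
+        >>> J = tensor([[1., 0.], [0., 1.]])
+        >>>
+        >>> A(J)
+        tensor([1., 1.])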
+    """
+
+    def __init__(
+        self,
+        beta: float = 0.5,
+        group_type: int = 0,
+        encoder: nn.Module | None = None,
+        shared_params: Iterable[Tensor] | None = None,
+        eps: float = DEFAULT_GRADVAC_EPS,
+    ) -> None:
+        super().__init__()
+        if not (0.0 <= beta <= 1.0):
+            raise ValueError(f"Parameter `beta` must be in [0, 1]. Found beta={beta!r}.")
+        if group_type not in (0, 1, 2):
+            raise ValueError(
+                "Parameter `group_type` must be 0 (whole_model), 1 (all_layer), or 2 (all_matrix). "
+                f"Found group_type={group_type!r}.",
+            )
+        params_tuple: tuple[Tensor, ...] = ()
+        fixed_block_sizes: tuple[int, ...] | None
+        if group_type == 0:
+            if encoder is not None:
+                raise ValueError("Parameter `encoder` must be None when `group_type == 0`.")
+            if shared_params is not None:
+                raise ValueError("Parameter `shared_params` must be None when `group_type == 0`.")
+            fixed_block_sizes = None
+        elif group_type == 1:
+            if encoder is None:
+                raise ValueError("Parameter `encoder` is required when `group_type == 1`.")
+            if shared_params is not None:
+                raise ValueError("Parameter `shared_params` must be None when `group_type == 1`.")
+            fixed_block_sizes = _gradvac_all_layer_group_sizes(encoder)
+            if sum(fixed_block_sizes) == 0:
+                raise ValueError("Parameter `encoder` has no parameters in any leaf module.")
+        else:
+            if shared_params is None:
+                raise ValueError("Parameter `shared_params` is required when `group_type == 2`.")
+            if encoder is not None:
+                raise ValueError("Parameter `encoder` must be None when `group_type == 2`.")
+            params_tuple = tuple(shared_params)
+            if len(params_tuple) == 0:
+                raise ValueError(
+                    "Parameter `shared_params` must be non-empty when `group_type == 2`."
+                )
+            fixed_block_sizes = _gradvac_all_matrix_group_sizes(params_tuple)
+
+        if eps <= 0.0:
+            raise ValueError(f"Parameter `eps` must be positive. Found eps={eps!r}.")
+
+        self._beta = beta
+        self._group_type = group_type
+        self._encoder = encoder
+        self._shared_params_len = len(params_tuple)
+        self._fixed_block_sizes = fixed_block_sizes
+        self._eps = float(eps)
+
+        self._rho_t: Tensor | None = None
+        self._state_key: tuple[int, int, tuple[int, ...], torch.device, torch.dtype] | None = None
+
+        self.register_full_backward_pre_hook(raise_non_differentiable_error)
+
+    @property
+    def eps(self) -> float:
+        """Small positive constant added to denominators for numerical stability."""
+
+        return self._eps
+
+    @eps.setter
+    def eps(self, value: float) -> None:
+        v = float(value)
+        if v <= 0.0:
+            raise ValueError(f"Attribute `eps` must be positive. Found eps={value!r}.")
+        self._eps = v
+
+    def reset(self) -> None:
+        """Clears EMA state so the next forward starts from zero targets."""
+
+        self._rho_t = None
+        self._state_key = None
+
+    def __repr__(self) -> str:
+        enc = "None" if self._encoder is None else f"{self._encoder.__class__.__name__}(...)"
+        sp = "None" if self._group_type != 2 else f"n_params={self._shared_params_len}"
+        return (
+            f"{self.__class__.__name__}(beta={self._beta!r}, group_type={self._group_type!r}, "
+            f"encoder={enc}, shared_params={sp}, eps={self._eps!r})"
+        )
+
+    def _resolve_segment_sizes(self, n: int) -> tuple[int, ...]:
+        if self._group_type == 0:
+            return (n,)
+        assert self._fixed_block_sizes is not None
+        sizes = self._fixed_block_sizes
+        if sum(sizes) != n:
+            raise ValueError(
+                "The Jacobian width `D` must equal the sum of the block sizes implied by "
+                f"`encoder` or `shared_params` for this `group_type`. Found D={n}, "
+                f"sum(block_sizes)={sum(sizes)}.",
+            )
+        return sizes
+
+    def _ensure_state(
+        self,
+        m: int,
+        n: int,
+        sizes: tuple[int, ...],
+        device: torch.device,
+        dtype: torch.dtype,
+    ) -> None:
+        key = (m, n, sizes, device, dtype)
+        num_groups = len(sizes)
+        if self._state_key != key or self._rho_t is None:
+            self._rho_t = torch.zeros(m, m, num_groups, device=device, dtype=dtype)
+            self._state_key = key
+
+    def forward(self, matrix: Matrix, /) -> Tensor:
+        grads = matrix
+        m, n = grads.shape
+        if m == 0 or n == 0:
+            return torch.zeros(n, dtype=grads.dtype, device=grads.device)
+
+        sizes = self._resolve_segment_sizes(n)
+        device = grads.device
+        dtype = grads.dtype
+        self._ensure_state(m, n, sizes, device, dtype)
+        assert self._rho_t is not None
+
+        rho_t = self._rho_t
+        beta = self._beta
+        eps = self.eps
+
+        pc_grads = grads.clone()
+        # Prefix offsets delimiting each parameter block along the row dimension.
+        offsets = [0]
+        for s in sizes:
+            offsets.append(offsets[-1] + s)
+
+        for i in range(m):
+            # Visit the other tasks in a random order.
+            others = [j for j in range(m) if j != i]
+            perm = torch.randperm(len(others))
+            shuffled_js = [others[idx] for idx in perm.tolist()]
+
+            for j in shuffled_js:
+                for k in range(len(sizes)):
+                    beg, end = offsets[k], offsets[k + 1]
+                    slice_i = pc_grads[i, beg:end]
+                    slice_j = grads[j, beg:end]
+
+                    norm_i = slice_i.norm()
+                    norm_j = slice_j.norm()
+                    denom = norm_i * norm_j + eps
+                    rho_ijk = slice_i.dot(slice_j) / denom
+
+                    bar = rho_t[i, j, k]
+                    if rho_ijk < bar:
+                        # Closed-form vaccine coefficient, clamped against tiny negative
+                        # floating-point values before the square roots.
+                        sqrt_1_rho2 = (1.0 - rho_ijk * rho_ijk).clamp(min=0.0).sqrt()
+                        sqrt_1_bar2 = (1.0 - bar * bar).clamp(min=0.0).sqrt()
+                        denom_w = norm_j * sqrt_1_bar2 + eps
+                        w = norm_i * (bar * sqrt_1_rho2 - rho_ijk * sqrt_1_bar2) / denom_w
+                        pc_grads[i, beg:end] = slice_i + slice_j * w
+
+                    # The EMA target is updated for every pair and block, modified or not.
+                    rho_t[i, j, k] = (1.0 - beta) * bar + beta * rho_ijk
+
+        return pc_grads.sum(dim=0)
diff --git a/tests/unit/aggregation/test_gradvac.py b/tests/unit/aggregation/test_gradvac.py
new file mode 100644
index 00000000..49a8770a
--- /dev/null
+++ b/tests/unit/aggregation/test_gradvac.py
@@ -0,0 +1,185 @@
+import torch
+import torch.nn as nn
+from pytest import mark, raises
+from torch import Tensor, tensor
+from torch.testing import assert_close
+from utils.tensors import ones_, randn_
+
+from torchjd.aggregation import DEFAULT_GRADVAC_EPS, GradVac
+
+from ._asserts import assert_expected_structure, assert_non_differentiable
+from ._inputs import scaled_matrices, typical_matrices
+
+scaled_pairs = [(GradVac(), m) for m in scaled_matrices]
+typical_pairs = [(GradVac(), m) for m in typical_matrices]
+requires_grad_pairs = [(GradVac(), ones_(3, 5, requires_grad=True))]
+
+
+def test_repr() -> None:
+    g = GradVac()
+    assert repr(g).startswith("GradVac(")
+    assert "beta=" in repr(g)
+    assert "group_type=" in repr(g)
+    assert "encoder=" in repr(g)
+    assert "eps=" in repr(g)
+
+
+def test_beta_out_of_range() -> None:
+    with raises(ValueError, match="beta"):
+        GradVac(beta=-0.1)
+    with raises(ValueError, match="beta"):
+        GradVac(beta=1.1)
+
+
+def test_eps_non_positive() -> None:
+    with raises(ValueError, match="eps"):
+        GradVac(eps=0.0)
+    with raises(ValueError, match="eps"):
+        GradVac(eps=-1e-9)
+
+
+def test_eps_setter_rejects_non_positive() -> None:
+    g = GradVac()
+    with raises(ValueError, match="eps"):
+        g.eps = 0.0
+
+
+def test_default_eps_constant() -> None:
+    assert DEFAULT_GRADVAC_EPS == 1e-8
+    assert GradVac().eps == DEFAULT_GRADVAC_EPS
+
+
+def test_eps_can_be_changed_between_steps() -> None:
+    j = tensor([[1.0, 0.0], [0.0, 1.0]])
+    agg = GradVac()
+    agg.eps = 1e-6
+    assert agg(j).isfinite().all()
+    agg.reset()
+    agg.eps = 1e-10
+    assert agg(j).isfinite().all()
+
+
+def test_group_type_invalid() -> None:
+    with raises(ValueError, match="group_type"):
+        GradVac(group_type=3)
+
+
+def test_group_type_0_rejects_encoder() -> None:
+    net = nn.Linear(1, 1)
+    with raises(ValueError, match="encoder"):
+        GradVac(encoder=net)
+
+
+def test_group_type_0_rejects_shared_params() -> None:
+    p = nn.Parameter(tensor([1.0]))
+    with raises(ValueError, match="shared_params"):
+        GradVac(shared_params=[p])
+
+
+def test_group_type_1_requires_encoder() -> None:
+    with raises(ValueError, match="encoder"):
+        GradVac(group_type=1)
+
+
+def test_group_type_1_rejects_shared_params() -> None:
+    net = nn.Linear(1, 1)
+    p = nn.Parameter(tensor([1.0]))
+    with raises(ValueError, match="shared_params"):
+        GradVac(group_type=1, encoder=net, shared_params=[p])
+
+
+def test_group_type_2_requires_shared_params() -> None:
+    with raises(ValueError, match="shared_params"):
+        GradVac(group_type=2)
+
+
+def test_group_type_2_rejects_encoder() -> None:
+    net = nn.Linear(1, 1)
+    with raises(ValueError, match="encoder"):
+        GradVac(group_type=2, encoder=net, shared_params=list(net.parameters()))
+
+
+def test_encoder_without_leaf_parameters() -> None:
+    class Empty(nn.Module):
+        pass
+
+    with raises(ValueError, match="encoder"):
+        GradVac(group_type=1, encoder=Empty())
+
+
+def test_shared_params_empty() -> None:
+    with raises(ValueError, match="shared_params"):
+        GradVac(group_type=2, shared_params=())
+
+
+def test_group_type_1_forward() -> None:
+    net = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 1))
+    d = sum(p.numel() for p in net.parameters())
+    j = randn_((2, d))
+    torch.manual_seed(0)
+    out = GradVac(group_type=1, encoder=net)(j)
+    assert out.shape == (d,)
+    assert out.isfinite().all()
+
+
+def test_group_type_2_forward() -> None:
+    net = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 1))
+    params = list(net.parameters())
+    d = sum(p.numel() for p in params)
+    j = randn_((2, d))
+    torch.manual_seed(0)
+    out = GradVac(group_type=2, shared_params=params)(j)
+    assert out.shape == (d,)
+    assert out.isfinite().all()
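+
+
+def test_group_type_1_and_2_agree_on_single_tensor_leaves() -> None:
+    """Hypothetical sanity check (an addition of this sketch, not from the reference
+    implementation): when every leaf module holds exactly one parameter tensor, ``all_layer``
+    and ``all_matrix`` induce identical blocks, so both groupings should produce the same
+    output under the same RNG seed."""
+
+    net = nn.Sequential(nn.Linear(2, 2, bias=False), nn.Linear(2, 1, bias=False))
+    d = sum(p.numel() for p in net.parameters())
+    j = randn_((2, d))
+    torch.manual_seed(0)
+    out1 = GradVac(group_type=1, encoder=net)(j)
+    torch.manual_seed(0)
+    out2 = GradVac(group_type=2, shared_params=list(net.parameters()))(j)
+    assert_close(out1, out2)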
+
+
+def test_jacobian_width_mismatch() -> None:
+    net = nn.Linear(2, 2)
+    d = sum(p.numel() for p in net.parameters())
+    agg = GradVac(group_type=1, encoder=net)
+    with raises(ValueError, match="Jacobian width"):
+        agg(tensor([[1.0] * (d - 1), [2.0] * (d - 1)]))
+
+
+def test_zero_rows_returns_zero_vector() -> None:
+    out = GradVac()(tensor([]).reshape(0, 3))
+    assert_close(out, tensor([0.0, 0.0, 0.0]))
+
+
+def test_zero_columns_returns_zero_vector() -> None:
+    """Handled inside forward before grouping validation."""
+
+    out = GradVac()(tensor([]).reshape(2, 0))
+    assert out.shape == (0,)
+
+
+def test_reproducible_with_manual_seed() -> None:
+    j = randn_((3, 8))
+    torch.manual_seed(12345)
+    a1 = GradVac(beta=0.3)
+    out1 = a1(j)
+    torch.manual_seed(12345)
+    a2 = GradVac(beta=0.3)
+    out2 = a2(j)
+    assert_close(out1, out2)
+
+
+def test_reset_restores_first_step_behavior() -> None:
+    j = tensor([[1.0, 0.0], [0.0, 1.0]])
+    torch.manual_seed(7)
+    agg = GradVac(beta=0.5)
+    first = agg(j)
+    agg(j)
+    agg.reset()
+    torch.manual_seed(7)
+    assert_close(first, agg(j))
+
+
+@mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs)
+def test_expected_structure(aggregator: GradVac, matrix: Tensor) -> None:
+    assert_expected_structure(aggregator, matrix)
+
+
+@mark.parametrize(["aggregator", "matrix"], requires_grad_pairs)
+def test_non_differentiable(aggregator: GradVac, matrix: Tensor) -> None:
+    assert_non_differentiable(aggregator, matrix)
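+
+
+def test_orthogonal_rows_pass_through_on_first_step() -> None:
+    """Hypothetical sanity check (an addition of this sketch, not from the reference
+    implementation): on the first step the EMA targets are zero, so orthogonal rows (cosine
+    exactly zero) trigger no correction and the result is the plain row sum, independently
+    of the RNG state."""
+
+    j = tensor([[1.0, 0.0], [0.0, 1.0]])
+    assert_close(GradVac()(j), tensor([1.0, 1.0]))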