diff --git a/pyproject.toml b/pyproject.toml index 4332dc06..82cbf7b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,9 +14,7 @@ authors = [ requires-python = ">=3.10" dependencies = [ "torch>=2.3.0", # Problems before 2.4.0, especially with autogram. - "quadprog>=0.1.9, != 0.1.10", # Doesn't work before 0.1.9, 0.1.10 is yanked "numpy>=1.21.2", # Does not work before 1.21. No python 3.10 wheel before 1.21.2. - "qpsolvers>=1.0.1", # Does not work before 1.0.1 ] classifiers = [ "Development Status :: 4 - Beta", @@ -96,8 +94,6 @@ plot = [ lower_bounds = [ "torch==2.3.0", "numpy==1.21.2", - "quadprog==0.1.9", - "qpsolvers==1.0.1", ] [project.optional-dependencies] diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py index 7e868f62..b2c052c3 100644 --- a/src/torchjd/aggregation/_dualproj.py +++ b/src/torchjd/aggregation/_dualproj.py @@ -1,10 +1,10 @@ from torch import Tensor -from torchjd._linalg import PSDMatrix, normalize, regularize +from torchjd._linalg import PSDMatrix, normalize from ._aggregator_bases import GramianWeightedAggregator from ._mean import MeanWeighting -from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights +from ._utils.dual_cone import project_weights from ._utils.non_differentiable import raise_non_differentiable_error from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting from ._weighting_bases import Weighting @@ -20,27 +20,18 @@ class DualProj(GramianWeightedAggregator): :param pref_vector: The preference vector used to combine the rows. If not provided, defaults to :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`. :param norm_eps: A small value to avoid division by zero when normalizing. - :param reg_eps: A small value to add to the diagonal of the gramian of the matrix. Due to - numerical errors when computing the gramian, it might not exactly be positive definite. - This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian - ensures that it is positive definite. - :param solver: The solver used to optimize the underlying optimization problem. """ def __init__( self, pref_vector: Tensor | None = None, norm_eps: float = 0.0001, - reg_eps: float = 0.0001, - solver: SUPPORTED_SOLVER = "quadprog", ) -> None: self._pref_vector = pref_vector self._norm_eps = norm_eps - self._reg_eps = reg_eps - self._solver: SUPPORTED_SOLVER = solver super().__init__( - DualProjWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver), + DualProjWeighting(pref_vector, norm_eps=norm_eps), ) # This prevents considering the computed weights as constant w.r.t. the matrix. @@ -48,8 +39,8 @@ def __init__( def __repr__(self) -> str: return ( - f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, norm_eps=" - f"{self._norm_eps}, reg_eps={self._reg_eps}, solver={repr(self._solver)})" + f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, " + f"norm_eps={self._norm_eps})" ) def __str__(self) -> str: @@ -64,29 +55,19 @@ class DualProjWeighting(Weighting[PSDMatrix]): :param pref_vector: The preference vector to use. If not provided, defaults to :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`. :param norm_eps: A small value to avoid division by zero when normalizing. - :param reg_eps: A small value to add to the diagonal of the gramian of the matrix. Due to - numerical errors when computing the gramian, it might not exactly be positive definite. - This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian - ensures that it is positive definite. - :param solver: The solver used to optimize the underlying optimization problem. """ def __init__( self, pref_vector: Tensor | None = None, norm_eps: float = 0.0001, - reg_eps: float = 0.0001, - solver: SUPPORTED_SOLVER = "quadprog", ) -> None: super().__init__() self._pref_vector = pref_vector self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting()) self.norm_eps = norm_eps - self.reg_eps = reg_eps - self.solver: SUPPORTED_SOLVER = solver def forward(self, gramian: PSDMatrix, /) -> Tensor: u = self.weighting(gramian) - G = regularize(normalize(gramian, self.norm_eps), self.reg_eps) - w = project_weights(u, G, self.solver) - return w + G = normalize(gramian, self.norm_eps) + return project_weights(u, G) diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py index 45f760be..8dd9c40a 100644 --- a/src/torchjd/aggregation/_upgrad.py +++ b/src/torchjd/aggregation/_upgrad.py @@ -1,11 +1,11 @@ import torch from torch import Tensor -from torchjd._linalg import PSDMatrix, normalize, regularize +from torchjd._linalg import PSDMatrix, normalize from ._aggregator_bases import GramianWeightedAggregator from ._mean import MeanWeighting -from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights +from ._utils.dual_cone import project_weights from ._utils.non_differentiable import raise_non_differentiable_error from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting from ._weighting_bases import Weighting @@ -21,27 +21,18 @@ class UPGrad(GramianWeightedAggregator): defaults to :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`. :param norm_eps: A small value to avoid division by zero when normalizing. - :param reg_eps: A small value to add to the diagonal of the gramian of the matrix. Due to - numerical errors when computing the gramian, it might not exactly be positive definite. - This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian - ensures that it is positive definite. - :param solver: The solver used to optimize the underlying optimization problem. """ def __init__( self, pref_vector: Tensor | None = None, norm_eps: float = 0.0001, - reg_eps: float = 0.0001, - solver: SUPPORTED_SOLVER = "quadprog", ) -> None: self._pref_vector = pref_vector self._norm_eps = norm_eps - self._reg_eps = reg_eps - self._solver: SUPPORTED_SOLVER = solver super().__init__( - UPGradWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver), + UPGradWeighting(pref_vector, norm_eps=norm_eps), ) # This prevents considering the computed weights as constant w.r.t. the matrix. @@ -49,8 +40,8 @@ def __init__( def __repr__(self) -> str: return ( - f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, norm_eps=" - f"{self._norm_eps}, reg_eps={self._reg_eps}, solver={repr(self._solver)})" + f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, " + f"norm_eps={self._norm_eps})" ) def __str__(self) -> str: @@ -65,29 +56,20 @@ class UPGradWeighting(Weighting[PSDMatrix]): :param pref_vector: The preference vector to use. If not provided, defaults to :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`. :param norm_eps: A small value to avoid division by zero when normalizing. - :param reg_eps: A small value to add to the diagonal of the gramian of the matrix. Due to - numerical errors when computing the gramian, it might not exactly be positive definite. - This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian - ensures that it is positive definite. - :param solver: The solver used to optimize the underlying optimization problem. """ def __init__( self, pref_vector: Tensor | None = None, norm_eps: float = 0.0001, - reg_eps: float = 0.0001, - solver: SUPPORTED_SOLVER = "quadprog", ) -> None: super().__init__() self._pref_vector = pref_vector self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting()) self.norm_eps = norm_eps - self.reg_eps = reg_eps - self.solver: SUPPORTED_SOLVER = solver def forward(self, gramian: PSDMatrix, /) -> Tensor: U = torch.diag(self.weighting(gramian)) - G = regularize(normalize(gramian, self.norm_eps), self.reg_eps) - W = project_weights(U, G, self.solver) + G = normalize(gramian, self.norm_eps) + W = project_weights(U, G) return torch.sum(W, dim=0) diff --git a/src/torchjd/aggregation/_utils/dual_cone.py b/src/torchjd/aggregation/_utils/dual_cone.py index b076366b..1a3bf036 100644 --- a/src/torchjd/aggregation/_utils/dual_cone.py +++ b/src/torchjd/aggregation/_utils/dual_cone.py @@ -1,62 +1,117 @@ -from typing import Literal, TypeAlias - -import numpy as np import torch -from qpsolvers import solve_qp from torch import Tensor -SUPPORTED_SOLVER: TypeAlias = Literal["quadprog"] - -def project_weights(U: Tensor, G: Tensor, solver: SUPPORTED_SOLVER) -> Tensor: +def project_weights(U: Tensor, G: Tensor) -> Tensor: """ Computes the tensor of weights corresponding to the projection of the vectors in `U` onto the rows of a matrix whose Gramian is provided. :param U: The tensor of weights corresponding to the vectors to project, of shape `[..., m]`. :param G: The Gramian matrix of shape `[m, m]`. It must be symmetric and positive definite. - :param solver: The quadratic programming solver to use. - :return: A tensor of projection weights with the same shape as `U`. """ - G_ = _to_array(G) - U_ = _to_array(U) + shape = U.shape + m = shape[-1] + batch_size = U.numel() // m - W = np.apply_along_axis(lambda u: _project_weight_vector(u, G_, solver), axis=-1, arr=U_) + # Cast to float64 for numerical stability + G64 = G.to(dtype=torch.float64) + U_flat = U.to(dtype=torch.float64).reshape(batch_size, m) - return torch.as_tensor(W, device=G.device, dtype=G.dtype) + W = _solve_batch_qp(U_flat, G64) + return W.reshape(shape).to(dtype=G.dtype) -def _project_weight_vector(u: np.ndarray, G: np.ndarray, solver: SUPPORTED_SOLVER) -> np.ndarray: + +def _solve_batch_qp(U: Tensor, G: Tensor) -> Tensor: r""" - Computes the weights `w` of the projection of `J^T u` onto the dual cone of the rows of `J`, - given `G = J J^T` and `u`. In other words, this computes the `w` that satisfies - `\pi_J(J^T u) = J^T w`, with `\pi_J` defined in Equation 3 of [1]. - - By Proposition 1 of [1], this is equivalent to solving for `v` the following quadratic program: - minimize v^T G v - subject to u \preceq v - - Reference: - [1] `Jacobian Descent For Multi-Objective Optimization `_. - - :param u: The vector of weights `u` of shape `[m]` corresponding to the vector `J^T u` to - project. - :param G: The Gramian matrix of `J`, equal to `J J^T`, and of shape `[m, m]`. It must be - symmetric and positive definite. - :param solver: The quadratic programming solver to use. - """ + Solves a batch of QPs sharing the same cost matrix using ADMM: - m = G.shape[0] - w = solve_qp(G, np.zeros(m), -np.eye(m), -u, solver=solver) + .. code-block:: text - if w is None: # This may happen when G has large values. - raise ValueError("Failed to solve the quadratic programming problem.") + minimize (1/2) v^T G v + subject to U[i] <= v (componentwise, for each row i of U) + + Three improvements over basic ADMM ensure convergence on ill-conditioned Gramians: - return w + - **Ruiz equilibration** (5 iterations): symmetrically scales G to bring all rows and + columns to unit infinity norm, reducing the effective condition number. + - **Adaptive rho**: the ADMM penalty parameter is updated every ``sqrt(m)`` iterations + when primal and dual residuals are severely imbalanced, triggering a cheap re-factorization. + - **Normalized stopping criteria**: convergence is checked against absolute + relative + tolerances, matching the OSQP / lqp_py conventions. + :param U: Lower-bound matrix of shape ``[B, m]``. + :param G: Shared cost matrix of shape ``[m, m]``, symmetric positive definite. + """ -def _to_array(tensor: Tensor) -> np.ndarray: - """Transforms a tensor into a numpy array with float64 dtype.""" + B, m = U.shape + device = G.device + I_m = torch.eye(m, dtype=torch.float64, device=device) + + # --- Ruiz equilibration --- + # Build D such that G_s = diag(D) @ G @ diag(D) has all row/column inf-norms ≈ 1. + # Variable substitution: v_orig = D * v_scaled => U_scaled = U / D. + G_s = G.clone() + D = torch.ones(m, dtype=torch.float64, device=device) + for _ in range(5): + delta = G_s.abs().amax(dim=1).clamp(min=1e-10).rsqrt() + G_s = G_s * (delta.unsqueeze(1) * delta.unsqueeze(0)) + D = D * delta + U_s = U / D # [B, m] — scaled lower bounds + + # --- Rho initialization --- + rho = max((G_s.norm("fro") / m).item(), 1e-6) + + # --- ADMM --- + L = torch.linalg.cholesky(G_s + rho * I_m) + + V = U_s.clone() # primal variable (scaled) + Z = U_s.clone() # auxiliary variable (scaled) + u = torch.zeros(B, m, dtype=torch.float64, device=device) # scaled dual variable + + eps_abs = eps_rel = 1e-7 + check_freq = round(m**0.5) + tau = 10.0 # adaptive-rho trigger threshold + + for k in range(2000): + Z_prev = Z + + # V-update: (G_s + rho*I) V = rho*(Z - u) + V = torch.cholesky_solve((rho * (Z - u)).T, L).T + + # Z-update: project onto {z : z >= U_s} + Z = (V + u).clamp(min=U_s) + + # Scaled dual update + primal_residual = V - Z + u = u + primal_residual + + if k % check_freq == 0: + primal_res = primal_residual.norm(torch.inf).item() + dual_res = (rho * (Z - Z_prev)).norm(torch.inf).item() + + tol_p = eps_abs + eps_rel * max( + V.norm(torch.inf).item(), + Z.norm(torch.inf).item(), + ) + tol_d = eps_abs + eps_rel * (rho * u).norm(torch.inf).item() + if primal_res < tol_p and dual_res < tol_d: + break + + # Adaptive rho: scale rho and rescale dual variable to maintain lambda = rho * u + if primal_res > tau * dual_res: + rho = min(rho * tau, 1e6) + u = u / tau + L = torch.linalg.cholesky(G_s + rho * I_m) + elif tau * primal_res < dual_res: + rho = max(rho / tau, 1e-6) + u = u * tau + L = torch.linalg.cholesky(G_s + rho * I_m) + + if (V - Z).norm(torch.inf) > 1e-3: + raise ValueError("Failed to solve the quadratic programming problem.") - return tensor.cpu().detach().numpy().astype(np.float64) + # Unscale: v_orig = D * v_scaled + return V * D diff --git a/tests/unit/aggregation/_utils/test_dual_cone.py b/tests/unit/aggregation/_utils/test_dual_cone.py index 68a8a75d..1aae0093 100644 --- a/tests/unit/aggregation/_utils/test_dual_cone.py +++ b/tests/unit/aggregation/_utils/test_dual_cone.py @@ -1,10 +1,9 @@ -import numpy as np import torch -from pytest import mark, raises +from pytest import mark from torch.testing import assert_close from utils.tensors import rand_, randn_ -from torchjd.aggregation._utils.dual_cone import _project_weight_vector, project_weights +from torchjd.aggregation._utils.dual_cone import project_weights @mark.parametrize("shape", [(5, 7), (9, 37), (2, 14), (32, 114), (50, 100)]) @@ -34,7 +33,7 @@ def test_solution_weights(shape: tuple[int, int]) -> None: G = J @ J.T u = rand_(shape[0]) - w = project_weights(u, G, "quadprog") + w = project_weights(u, G) dual_gap = w - u # Dual feasibility @@ -63,8 +62,8 @@ def test_scale_invariant(shape: tuple[int, int], scaling: float) -> None: G = J @ J.T u = rand_(shape[0]) - w = project_weights(u, G, "quadprog") - w_scaled = project_weights(u, scaling * G, "quadprog") + w = project_weights(u, G) + w_scaled = project_weights(u, scaling * G) assert_close(w_scaled, w) @@ -82,16 +81,7 @@ def test_tensorization_shape(shape: tuple[int, ...]) -> None: G = matrix @ matrix.T - W_tensor = project_weights(U_tensor, G, "quadprog") - W_matrix = project_weights(U_matrix, G, "quadprog") + W_tensor = project_weights(U_tensor, G) + W_matrix = project_weights(U_matrix, G) assert_close(W_matrix.reshape(shape), W_tensor) - - -def test_project_weight_vector_failure() -> None: - """Tests that `_project_weight_vector` raises an error when the input G has too large values.""" - - large_J = np.random.randn(10, 100) * 1e5 - large_G = large_J @ large_J.T - with raises(ValueError): - _project_weight_vector(np.ones(10), large_G, "quadprog") diff --git a/tests/unit/aggregation/test_dualproj.py b/tests/unit/aggregation/test_dualproj.py index 5bd0e71a..a0976d31 100644 --- a/tests/unit/aggregation/test_dualproj.py +++ b/tests/unit/aggregation/test_dualproj.py @@ -32,7 +32,7 @@ def test_non_conflicting(aggregator: DualProj, matrix: Tensor) -> None: @mark.parametrize(["aggregator", "matrix"], typical_pairs) def test_permutation_invariant(aggregator: DualProj, matrix: Tensor) -> None: - assert_permutation_invariant(aggregator, matrix, n_runs=5, atol=2e-07, rtol=2e-07) + assert_permutation_invariant(aggregator, matrix, n_runs=5, atol=2e-05, rtol=2e-05) @mark.parametrize(["aggregator", "matrix"], non_strong_pairs) @@ -46,20 +46,13 @@ def test_non_differentiable(aggregator: DualProj, matrix: Tensor) -> None: def test_representations() -> None: - A = DualProj(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver="quadprog") - assert ( - repr(A) == "DualProj(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver='quadprog')" - ) + A = DualProj(pref_vector=None, norm_eps=0.0001) + assert repr(A) == "DualProj(pref_vector=None, norm_eps=0.0001)" assert str(A) == "DualProj" A = DualProj( pref_vector=torch.tensor([1.0, 2.0, 3.0], device="cpu"), norm_eps=0.0001, - reg_eps=0.0001, - solver="quadprog", - ) - assert ( - repr(A) == "DualProj(pref_vector=tensor([1., 2., 3.]), norm_eps=0.0001, reg_eps=0.0001, " - "solver='quadprog')" ) + assert repr(A) == "DualProj(pref_vector=tensor([1., 2., 3.]), norm_eps=0.0001)" assert str(A) == "DualProj([1., 2., 3.])" diff --git a/tests/unit/aggregation/test_pcgrad.py b/tests/unit/aggregation/test_pcgrad.py index b776071d..082dcc8f 100644 --- a/tests/unit/aggregation/test_pcgrad.py +++ b/tests/unit/aggregation/test_pcgrad.py @@ -54,8 +54,6 @@ def test_equivalence_upgrad_sum_two_rows(shape: tuple[int, int]) -> None: upgrad_sum_weighting = UPGradWeighting( ones_((2,)), norm_eps=0.0, - reg_eps=0.0, - solver="quadprog", ) result = pc_grad_weighting(gramian) diff --git a/tests/unit/aggregation/test_upgrad.py b/tests/unit/aggregation/test_upgrad.py index 1859b662..c77b555c 100644 --- a/tests/unit/aggregation/test_upgrad.py +++ b/tests/unit/aggregation/test_upgrad.py @@ -33,7 +33,7 @@ def test_non_conflicting(aggregator: UPGrad, matrix: Tensor) -> None: @mark.parametrize(["aggregator", "matrix"], typical_pairs) def test_permutation_invariant(aggregator: UPGrad, matrix: Tensor) -> None: - assert_permutation_invariant(aggregator, matrix, n_runs=5, atol=5e-07, rtol=5e-07) + assert_permutation_invariant(aggregator, matrix, n_runs=5, atol=1e-05, rtol=1e-05) @mark.parametrize(["aggregator", "matrix"], typical_pairs) @@ -52,18 +52,13 @@ def test_non_differentiable(aggregator: UPGrad, matrix: Tensor) -> None: def test_representations() -> None: - A = UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver="quadprog") - assert repr(A) == "UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver='quadprog')" + A = UPGrad(pref_vector=None, norm_eps=0.0001) + assert repr(A) == "UPGrad(pref_vector=None, norm_eps=0.0001)" assert str(A) == "UPGrad" A = UPGrad( pref_vector=torch.tensor([1.0, 2.0, 3.0], device="cpu"), norm_eps=0.0001, - reg_eps=0.0001, - solver="quadprog", - ) - assert ( - repr(A) == "UPGrad(pref_vector=tensor([1., 2., 3.]), norm_eps=0.0001, reg_eps=0.0001, " - "solver='quadprog')" ) + assert repr(A) == "UPGrad(pref_vector=tensor([1., 2., 3.]), norm_eps=0.0001)" assert str(A) == "UPGrad([1., 2., 3.])" diff --git a/tests/unit/aggregation/test_values.py b/tests/unit/aggregation/test_values.py index 860f313d..1054f313 100644 --- a/tests/unit/aggregation/test_values.py +++ b/tests/unit/aggregation/test_values.py @@ -55,7 +55,7 @@ (AlignedMTL(), J_base, tensor([0.2133, 0.9673, 0.9673])), (ConFIG(), J_base, tensor([0.1588, 2.0706, 2.0706])), (Constant(tensor([1.0, 2.0])), J_base, tensor([8.0, 3.0, 3.0])), - (DualProj(), J_base, tensor([0.5563, 1.1109, 1.1109])), + (DualProj(), J_base, tensor([0.5556, 1.1111, 1.1111])), (GradDrop(), J_base, tensor([6.0, 2.0, 2.0])), (IMTLG(), J_base, tensor([0.0767, 1.0000, 1.0000])), (Krum(n_byzantine=1, n_selected=4), J_Krum, tensor([1.2500, 0.7500, 1.5000])), @@ -65,7 +65,7 @@ (Random(), J_base, tensor([-2.6229, 1.0000, 1.0000])), (Sum(), J_base, tensor([2.0, 2.0, 2.0])), (TrimmedMean(trim_number=1), J_TrimmedMean, tensor([1.5000, 2.5000])), - (UPGrad(), J_base, tensor([0.2929, 1.9004, 1.9004])), + (UPGrad(), J_base, tensor([0.2924, 1.9006, 1.9006])), ] G_base = J_base @ J_base.T @@ -74,7 +74,7 @@ WEIGHTING_PARAMETRIZATIONS = [ (AlignedMTLWeighting(), G_base, tensor([0.5591, 0.4083])), (ConstantWeighting(tensor([1.0, 2.0])), G_base, tensor([1.0, 2.0])), - (DualProjWeighting(), G_base, tensor([0.6109, 0.5000])), + (DualProjWeighting(), G_base, tensor([0.6111, 0.5000])), (IMTLGWeighting(), G_base, tensor([0.5923, 0.4077])), (KrumWeighting(1, 4), G_Krum, tensor([0.2500, 0.2500, 0.0000, 0.2500, 0.2500])), (MeanWeighting(), G_base, tensor([0.5000, 0.5000])), @@ -82,7 +82,7 @@ (PCGradWeighting(), G_base, tensor([2.2222, 1.5789])), (RandomWeighting(), G_base, tensor([0.8623, 0.1377])), (SumWeighting(), G_base, tensor([1.0, 1.0])), - (UPGradWeighting(), G_base, tensor([1.1109, 0.7894])), + (UPGradWeighting(), G_base, tensor([1.1111, 0.7895])), ] try: