Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@ authors = [
requires-python = ">=3.10"
dependencies = [
"torch>=2.3.0", # Problems before 2.4.0, especially with autogram.
"quadprog>=0.1.9, != 0.1.10", # Doesn't work before 0.1.9, 0.1.10 is yanked
"numpy>=1.21.2", # Does not work before 1.21. No python 3.10 wheel before 1.21.2.
"qpsolvers>=1.0.1", # Does not work before 1.0.1
]
classifiers = [
"Development Status :: 4 - Beta",
Expand Down Expand Up @@ -96,8 +94,6 @@ plot = [
lower_bounds = [
"torch==2.3.0",
"numpy==1.21.2",
"quadprog==0.1.9",
"qpsolvers==1.0.1",
]

[project.optional-dependencies]
Expand Down
33 changes: 7 additions & 26 deletions src/torchjd/aggregation/_dualproj.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from torch import Tensor

from torchjd._linalg import PSDMatrix, normalize, regularize
from torchjd._linalg import PSDMatrix, normalize

from ._aggregator_bases import GramianWeightedAggregator
from ._mean import MeanWeighting
from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights
from ._utils.dual_cone import project_weights
from ._utils.non_differentiable import raise_non_differentiable_error
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
from ._weighting_bases import Weighting
Expand All @@ -20,36 +20,27 @@ class DualProj(GramianWeightedAggregator):
:param pref_vector: The preference vector used to combine the rows. If not provided, defaults to
:math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`.
:param norm_eps: A small value to avoid division by zero when normalizing.
:param reg_eps: A small value to add to the diagonal of the gramian of the matrix. Due to
numerical errors when computing the gramian, it might not exactly be positive definite.
This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian
ensures that it is positive definite.
:param solver: The solver used to optimize the underlying optimization problem.
"""

def __init__(
    self,
    pref_vector: Tensor | None = None,
    norm_eps: float = 0.0001,
) -> None:
    """
    Store the configuration and build the underlying weighting.

    :param pref_vector: Preference vector combining the rows; ``None`` means the mean vector.
    :param norm_eps: Small value avoiding division by zero when normalizing the Gramian.
    """
    # Kept for __repr__ / introspection.
    self._pref_vector = pref_vector
    self._norm_eps = norm_eps

    super().__init__(
        DualProjWeighting(pref_vector, norm_eps=norm_eps),
    )

    # This prevents considering the computed weights as constant w.r.t. the matrix.
    self.register_full_backward_pre_hook(raise_non_differentiable_error)

def __repr__(self) -> str:
    """Return an eval-style representation showing the stored configuration."""
    return (
        f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, "
        f"norm_eps={self._norm_eps})"
    )

def __str__(self) -> str:
Expand All @@ -64,29 +55,19 @@ class DualProjWeighting(Weighting[PSDMatrix]):
:param pref_vector: The preference vector to use. If not provided, defaults to
:math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`.
:param norm_eps: A small value to avoid division by zero when normalizing.
:param reg_eps: A small value to add to the diagonal of the gramian of the matrix. Due to
numerical errors when computing the gramian, it might not exactly be positive definite.
This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian
ensures that it is positive definite.
:param solver: The solver used to optimize the underlying optimization problem.
"""

def __init__(
    self,
    pref_vector: Tensor | None = None,
    norm_eps: float = 0.0001,
) -> None:
    """
    Store the configuration and resolve the base weighting.

    :param pref_vector: Preference vector; ``None`` falls back to the mean weighting.
    :param norm_eps: Small value avoiding division by zero when normalizing the Gramian.
    """
    super().__init__()
    self._pref_vector = pref_vector
    # The base weighting producing the vector to project (mean by default).
    self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting())
    self.norm_eps = norm_eps

def forward(self, gramian: PSDMatrix, /) -> Tensor:
    """Compute the projected weights from the Gramian of the matrix to aggregate."""
    # Weights from the base weighting, then projection expressed through the
    # normalized Gramian (normalization guards against badly scaled inputs).
    u = self.weighting(gramian)
    G = normalize(gramian, self.norm_eps)
    return project_weights(u, G)
32 changes: 7 additions & 25 deletions src/torchjd/aggregation/_upgrad.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import torch
from torch import Tensor

from torchjd._linalg import PSDMatrix, normalize, regularize
from torchjd._linalg import PSDMatrix, normalize

from ._aggregator_bases import GramianWeightedAggregator
from ._mean import MeanWeighting
from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights
from ._utils.dual_cone import project_weights
from ._utils.non_differentiable import raise_non_differentiable_error
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
from ._weighting_bases import Weighting
Expand All @@ -21,36 +21,27 @@ class UPGrad(GramianWeightedAggregator):
defaults to :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in
\mathbb{R}^m`.
:param norm_eps: A small value to avoid division by zero when normalizing.
:param reg_eps: A small value to add to the diagonal of the gramian of the matrix. Due to
numerical errors when computing the gramian, it might not exactly be positive definite.
This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian
ensures that it is positive definite.
:param solver: The solver used to optimize the underlying optimization problem.
"""

def __init__(
    self,
    pref_vector: Tensor | None = None,
    norm_eps: float = 0.0001,
) -> None:
    """
    Store the configuration and build the underlying weighting.

    :param pref_vector: Preference vector combining the rows; ``None`` means the mean vector.
    :param norm_eps: Small value avoiding division by zero when normalizing the Gramian.
    """
    # Kept for __repr__ / introspection.
    self._pref_vector = pref_vector
    self._norm_eps = norm_eps

    super().__init__(
        UPGradWeighting(pref_vector, norm_eps=norm_eps),
    )

    # This prevents considering the computed weights as constant w.r.t. the matrix.
    self.register_full_backward_pre_hook(raise_non_differentiable_error)

def __repr__(self) -> str:
    """Return an eval-style representation showing the stored configuration."""
    return (
        f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, "
        f"norm_eps={self._norm_eps})"
    )

def __str__(self) -> str:
Expand All @@ -65,29 +56,20 @@ class UPGradWeighting(Weighting[PSDMatrix]):
:param pref_vector: The preference vector to use. If not provided, defaults to
:math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`.
:param norm_eps: A small value to avoid division by zero when normalizing.
:param reg_eps: A small value to add to the diagonal of the gramian of the matrix. Due to
numerical errors when computing the gramian, it might not exactly be positive definite.
This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian
ensures that it is positive definite.
:param solver: The solver used to optimize the underlying optimization problem.
"""

def __init__(
    self,
    pref_vector: Tensor | None = None,
    norm_eps: float = 0.0001,
) -> None:
    """
    Store the configuration and resolve the base weighting.

    :param pref_vector: Preference vector; ``None`` falls back to the mean weighting.
    :param norm_eps: Small value avoiding division by zero when normalizing the Gramian.
    """
    super().__init__()
    self._pref_vector = pref_vector
    # The base weighting producing the per-row scaling (mean by default).
    self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting())
    self.norm_eps = norm_eps

def forward(self, gramian: PSDMatrix, /) -> Tensor:
    """Compute the projected weights from the Gramian of the matrix to aggregate."""
    # Each row of U is one scaled coordinate vector; project all rows in a batch,
    # then aggregate the projected weight vectors by summing over the row dimension.
    U = torch.diag(self.weighting(gramian))
    G = normalize(gramian, self.norm_eps)
    W = project_weights(U, G)
    return torch.sum(W, dim=0)
133 changes: 94 additions & 39 deletions src/torchjd/aggregation/_utils/dual_cone.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,117 @@
from typing import Literal, TypeAlias

import numpy as np
import torch
from qpsolvers import solve_qp
from torch import Tensor

SUPPORTED_SOLVER: TypeAlias = Literal["quadprog"]


def project_weights(U: Tensor, G: Tensor) -> Tensor:
    """
    Computes the tensor of weights corresponding to the projection of the vectors in `U` onto the
    rows of a matrix whose Gramian is provided.

    :param U: The tensor of weights corresponding to the vectors to project, of shape `[..., m]`.
    :param G: The Gramian matrix of shape `[m, m]`. It must be symmetric and positive definite.
    :return: A tensor of projection weights with the same shape as `U`.
    """

    shape = U.shape
    m = shape[-1]
    batch_size = U.numel() // m

    # Cast to float64 for numerical stability: the QP solve is sensitive to round-off on
    # ill-conditioned Gramians.
    G64 = G.to(dtype=torch.float64)
    U_flat = U.to(dtype=torch.float64).reshape(batch_size, m)

    W = _solve_batch_qp(U_flat, G64)

    # Restore the caller-facing shape and dtype.
    return W.reshape(shape).to(dtype=G.dtype)

def _project_weight_vector(u: np.ndarray, G: np.ndarray, solver: SUPPORTED_SOLVER) -> np.ndarray:

def _solve_batch_qp(U: Tensor, G: Tensor) -> Tensor:
r"""
Computes the weights `w` of the projection of `J^T u` onto the dual cone of the rows of `J`,
given `G = J J^T` and `u`. In other words, this computes the `w` that satisfies
`\pi_J(J^T u) = J^T w`, with `\pi_J` defined in Equation 3 of [1].

By Proposition 1 of [1], this is equivalent to solving for `v` the following quadratic program:
minimize v^T G v
subject to u \preceq v

Reference:
[1] `Jacobian Descent For Multi-Objective Optimization <https://arxiv.org/pdf/2406.16232>`_.

:param u: The vector of weights `u` of shape `[m]` corresponding to the vector `J^T u` to
project.
:param G: The Gramian matrix of `J`, equal to `J J^T`, and of shape `[m, m]`. It must be
symmetric and positive definite.
:param solver: The quadratic programming solver to use.
"""
Solves a batch of QPs sharing the same cost matrix using ADMM:

m = G.shape[0]
w = solve_qp(G, np.zeros(m), -np.eye(m), -u, solver=solver)
.. code-block:: text

if w is None: # This may happen when G has large values.
raise ValueError("Failed to solve the quadratic programming problem.")
minimize (1/2) v^T G v
subject to U[i] <= v (componentwise, for each row i of U)

Three improvements over basic ADMM ensure convergence on ill-conditioned Gramians:

return w
- **Ruiz equilibration** (5 iterations): symmetrically scales G to bring all rows and
columns to unit infinity norm, reducing the effective condition number.
- **Adaptive rho**: the ADMM penalty parameter is updated every ``sqrt(m)`` iterations
when primal and dual residuals are severely imbalanced, triggering a cheap re-factorization.
- **Normalized stopping criteria**: convergence is checked against absolute + relative
tolerances, matching the OSQP / lqp_py conventions.

:param U: Lower-bound matrix of shape ``[B, m]``.
:param G: Shared cost matrix of shape ``[m, m]``, symmetric positive definite.
"""

def _to_array(tensor: Tensor) -> np.ndarray:
"""Transforms a tensor into a numpy array with float64 dtype."""
B, m = U.shape
device = G.device
I_m = torch.eye(m, dtype=torch.float64, device=device)

# --- Ruiz equilibration ---
# Build D such that G_s = diag(D) @ G @ diag(D) has all row/column inf-norms ≈ 1.
# Variable substitution: v_orig = D * v_scaled => U_scaled = U / D.
G_s = G.clone()
D = torch.ones(m, dtype=torch.float64, device=device)
for _ in range(5):
delta = G_s.abs().amax(dim=1).clamp(min=1e-10).rsqrt()
G_s = G_s * (delta.unsqueeze(1) * delta.unsqueeze(0))
D = D * delta
U_s = U / D # [B, m] — scaled lower bounds

# --- Rho initialization ---
rho = max((G_s.norm("fro") / m).item(), 1e-6)

# --- ADMM ---
L = torch.linalg.cholesky(G_s + rho * I_m)

V = U_s.clone() # primal variable (scaled)
Z = U_s.clone() # auxiliary variable (scaled)
u = torch.zeros(B, m, dtype=torch.float64, device=device) # scaled dual variable

eps_abs = eps_rel = 1e-7
check_freq = round(m**0.5)
tau = 10.0 # adaptive-rho trigger threshold

for k in range(2000):
Z_prev = Z

# V-update: (G_s + rho*I) V = rho*(Z - u)
V = torch.cholesky_solve((rho * (Z - u)).T, L).T

# Z-update: project onto {z : z >= U_s}
Z = (V + u).clamp(min=U_s)

# Scaled dual update
primal_residual = V - Z
u = u + primal_residual

if k % check_freq == 0:
primal_res = primal_residual.norm(torch.inf).item()
dual_res = (rho * (Z - Z_prev)).norm(torch.inf).item()

tol_p = eps_abs + eps_rel * max(
V.norm(torch.inf).item(),
Z.norm(torch.inf).item(),
)
tol_d = eps_abs + eps_rel * (rho * u).norm(torch.inf).item()
if primal_res < tol_p and dual_res < tol_d:
break

# Adaptive rho: scale rho and rescale dual variable to maintain lambda = rho * u
if primal_res > tau * dual_res:
rho = min(rho * tau, 1e6)
u = u / tau
L = torch.linalg.cholesky(G_s + rho * I_m)
elif tau * primal_res < dual_res:
rho = max(rho / tau, 1e-6)
u = u * tau
L = torch.linalg.cholesky(G_s + rho * I_m)

if (V - Z).norm(torch.inf) > 1e-3:
raise ValueError("Failed to solve the quadratic programming problem.")

return tensor.cpu().detach().numpy().astype(np.float64)
# Unscale: v_orig = D * v_scaled
return V * D
24 changes: 7 additions & 17 deletions tests/unit/aggregation/_utils/test_dual_cone.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import numpy as np
import torch
from pytest import mark, raises
from pytest import mark
from torch.testing import assert_close
from utils.tensors import rand_, randn_

from torchjd.aggregation._utils.dual_cone import _project_weight_vector, project_weights
from torchjd.aggregation._utils.dual_cone import project_weights


@mark.parametrize("shape", [(5, 7), (9, 37), (2, 14), (32, 114), (50, 100)])
Expand Down Expand Up @@ -34,7 +33,7 @@ def test_solution_weights(shape: tuple[int, int]) -> None:
G = J @ J.T
u = rand_(shape[0])

w = project_weights(u, G, "quadprog")
w = project_weights(u, G)
dual_gap = w - u

# Dual feasibility
Expand Down Expand Up @@ -63,8 +62,8 @@ def test_scale_invariant(shape: tuple[int, int], scaling: float) -> None:
G = J @ J.T
u = rand_(shape[0])

w = project_weights(u, G, "quadprog")
w_scaled = project_weights(u, scaling * G, "quadprog")
w = project_weights(u, G)
w_scaled = project_weights(u, scaling * G)

assert_close(w_scaled, w)

Expand All @@ -82,16 +81,7 @@ def test_tensorization_shape(shape: tuple[int, ...]) -> None:

G = matrix @ matrix.T

W_tensor = project_weights(U_tensor, G, "quadprog")
W_matrix = project_weights(U_matrix, G, "quadprog")
W_tensor = project_weights(U_tensor, G)
W_matrix = project_weights(U_matrix, G)

assert_close(W_matrix.reshape(shape), W_tensor)


def test_project_weight_vector_failure() -> None:
"""Tests that `_project_weight_vector` raises an error when the input G has too large values."""

large_J = np.random.randn(10, 100) * 1e5
large_G = large_J @ large_J.T
with raises(ValueError):
_project_weight_vector(np.ones(10), large_G, "quadprog")
Loading
Loading