diff --git a/README.md b/README.md index 3e69af448..dfce350f9 100644 --- a/README.md +++ b/README.md @@ -451,3 +451,5 @@ Artificial Intelligence. [81] Xu, H., Luo, D., & Carin, L. (2019). [Scalable Gromov-Wasserstein learning for graph partitioning and matching](https://proceedings.neurips.cc/paper/2019/hash/6e62a992c676f611616097dbea8ea030-Abstract.html). Neural Information Processing Systems (NeurIPS). [82] Bonet, C., Nadjahi, K., Séjourné, T., Fatras, K., & Courty, N. (2024). [Slicing Unbalanced Optimal Transport](https://openreview.net/forum?id=AjJTg5M0r8). Transactions on Machine Learning Research. + +[83] Genans, F., Godichon-Baggioni, A., Vialard, F.-X., & Wintenberger, O. (2025). [Decreasing Entropic Regularization Averaged Gradient for Semi-Discrete Optimal Transport](https://proceedings.neurips.cc/paper_files/paper/2025/file/d7efa12e98f5e0dd8b4f48cd60b4e3aa-Paper-Conference.pdf). Advances in Neural Information Processing Systems, 38, 146913-146949. diff --git a/RELEASES.md b/RELEASES.md index 0f8918cac..be95d5b57 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -14,6 +14,7 @@ This new release adds support for sparse cost matrices and a new lazy EMD solver - Add support for sparse cost matrices in EMD solver (PR #778, Issue #397) - Added UOT1D with Frank-Wolfe in `ot.unbalanced.uot_1d` (PR #765) - Add Sliced UOT and Unbalanced Sliced OT in `ot/unbalanced/_sliced.py` (PR #765) +- Add SGD-based semi-discrete OT solver in `ot.semidiscrete` and a gallery example. 
(PR #812)
 
 #### Closed issues
 
diff --git a/examples/others/plot_semidiscrete.py b/examples/others/plot_semidiscrete.py
new file mode 100644
index 000000000..69d475c31
--- /dev/null
+++ b/examples/others/plot_semidiscrete.py
@@ -0,0 +1,175 @@
+# -*- coding: utf-8 -*-
+r"""
+==================================
+Semi-discrete OT: a toy 2D problem
+==================================
+
+This example shows the :mod:`ot.semidiscrete` solver on a small 2D problem:
+a uniform source on :math:`[0, 1]^2` and 15 random target atoms with uniform
+weights. With so few atoms the Laguerre cells can be drawn by brute force on
+a grid.
+
+We call :func:`ot.semidiscrete.solve_semidiscrete` with its default
+arguments: the underlying algorithm is **Projected Averaged SGD**, and the
+default ``decreasing_reg=True`` adds the **DRAG** entropic-regularization
+schedule of [83]_, which improves convergence.
+
+For the returned potential :math:`g` we report:
+
+- the empirical Laguerre-cell masses (mean and max absolute deviation from
+  :math:`1/15`);
+- the semi-dual objective
+  :math:`\langle g, b\rangle + \mathbb{E}_X[\varphi_g(X)]` estimated by
+  Monte Carlo, where the c-transform
+  :math:`\varphi_g(x) = \min_j\big(c(x, y_j) - g_j\big)` is computed by
+  :func:`ot.semidiscrete.c_transform`. The solver **maximises** this
+  objective.
+
+.. [83] Genans, F., Godichon-Baggioni, A., Vialard, F.-X., Wintenberger, O.
+    (2025). *Decreasing Entropic Regularization Averaged Gradient for
+    Semi-Discrete Optimal Transport.* NeurIPS 2025.
+"""
+
+# Author: Ferdinand Genans
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 1
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+from ot.semidiscrete import (
+    solve_semidiscrete,
+    atom_weights,
+    c_transform,
+)
+
+##############################################################################
+# Toy 2D problem
+# --------------
+
+rng = np.random.default_rng(42)
+
+
+def source_sampler(batch_size):
+    """Draw ``batch_size`` points uniformly from the unit square ``[0, 1]^2``."""
+    return rng.random((batch_size, 2))
+
+
+n_atoms = 15
+target_positions = 0.1 + 0.8 * np.random.default_rng(0).random((n_atoms, 2))
+
+
+def plot_laguerre_cells(target, g, ax, title, resolution=300):
+    """Draw the Laguerre cells of the potential ``g`` on the axes ``ax``.
+
+    Each node of a ``resolution x resolution`` grid over ``[0, 1]^2`` is
+    labelled with the atom selected by ``atom_weights(..., reg=0.0)``; the
+    label image is drawn with one colour per atom and the atoms themselves
+    are overlaid as black-edged markers.
+    """
+    xs = np.linspace(0, 1, resolution)
+    ys = np.linspace(0, 1, resolution)
+    XX, YY = np.meshgrid(xs, ys)
+    grid = np.stack([XX.ravel(), YY.ravel()], axis=1)
+    labels = atom_weights(target, grid, g, reg=0.0).argmax(axis=1)
+    image = labels.reshape(resolution, resolution)
+    cmap = plt.get_cmap("tab20", target.shape[0])
+    ax.imshow(
+        image,
+        origin="lower",
+        extent=(0, 1, 0, 1),
+        cmap=cmap,
+        alpha=0.55,
+        vmin=-0.5,
+        vmax=target.shape[0] - 0.5,
+        interpolation="nearest",
+    )
+    # Target points share the colour of their Laguerre cell.
+    ax.scatter(
+        target[:, 0],
+        target[:, 1],
+        s=80,
+        c=[cmap(i) for i in range(target.shape[0])],
+        edgecolor="black",
+        linewidths=1.2,
+        zorder=3,
+    )
+    ax.set_title(title)
+    ax.set_aspect("equal")
+    ax.set_xlim(0, 1)
+    ax.set_ylim(0, 1)
+
+
+##############################################################################
+# Solve and visualise
+# -------------------
+#
+# A single call to :func:`solve_semidiscrete` runs DRAG with the default
+# arguments (``decreasing_reg=True``). We show the initial Voronoi cells
+# (:math:`g = 0`) next to the Laguerre cells at the optimum.
+# In this problem the cost between a source sample and a target atom is at most
+# 1.0, so we pass ``max_cost=1.0``. Knowing this bound, the potential values are
+# clipped to ``[-max_cost, max_cost]``, where an optimal potential is known to
+# lie ([83]_, Lemma 1), which speeds up convergence.
+g_drag = solve_semidiscrete(
+    target_positions,
+    source_sampler,
+    n_iter=20_000,
+    batch_size=16,
+    max_cost=1.0,
+)
+
+fig, axes = plt.subplots(1, 2, figsize=(11, 5.5))
+plot_laguerre_cells(target_positions, np.zeros(n_atoms), axes[0], "Voronoi (g = 0)")
+plot_laguerre_cells(target_positions, g_drag, axes[1], "DRAG")
+plt.tight_layout()
+plt.show()
+
+
+##############################################################################
+# Cell masses and Monte Carlo cost
+# --------------------------------
+#
+# At the optimum each Laguerre cell should carry mass :math:`1/15`. We report
+# the empirical mass error and the semi-dual objective
+#
+# .. math::
+#     \mathcal{S}(g) = \langle g, b\rangle + \mathbb{E}_X[\varphi_g(X)]
+#
+# estimated by Monte Carlo. The solver maximises :math:`\mathcal{S}`.
+
+
+def cell_masses(target, g, sampler, n_samples=100_000):
+    """Monte Carlo estimate of the source mass assigned to each atom.
+
+    Draws ``n_samples`` points with ``sampler`` and returns, for each atom,
+    the fraction of points that ``atom_weights(..., reg=0.0)`` assigns to it.
+    """
+    labels = atom_weights(target, sampler(n_samples), g, reg=0.0).argmax(axis=1)
+    counts = np.bincount(labels, minlength=target.shape[0])
+    return counts / n_samples
+
+
+def mc_cost(target, g, sampler, n_samples=100_000):
+    """Monte Carlo estimate of the semi-dual objective at ``g``.
+
+    Uses uniform atom weights and ``n_samples`` points drawn with ``sampler``.
+    """
+    b = np.full(target.shape[0], 1.0 / target.shape[0])
+    samples = sampler(n_samples)
+    return float(g @ b + c_transform(target, samples, g, reg=0.0).mean())
+
+
+target_mass = 1.0 / n_atoms
+m_drag = cell_masses(target_positions, g_drag, source_sampler)
+cost_drag = mc_cost(target_positions, g_drag, source_sampler)
+
+print(f"Target mass per cell: {target_mass:.4f}")
+print(
+    f"DRAG — mean abs. mass error: "
+    f"{np.mean(np.abs(m_drag - target_mass)):.4f}"
+    f" max: {np.max(np.abs(m_drag - target_mass)):.4f}"
+    f" semi-dual cost (MC): {cost_drag:.5f}"
+)
diff --git a/ot/__init__.py b/ot/__init__.py
index 75f17fed6..dd4e068b1 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -36,6 +36,7 @@
 from . 
import gaussian
 from . import lowrank
 from . import gmm
+from . import semidiscrete
 
 # OT functions
 from .lp import (
@@ -145,6 +146,7 @@
     "factored",
     "lowrank",
     "gmm",
+    "semidiscrete",
     "binary_search_circle",
     "wasserstein_circle",
     "semidiscrete_wasserstein2_unif_circle",
diff --git a/ot/semidiscrete.py b/ot/semidiscrete.py
new file mode 100644
index 000000000..64ed5d1d1
--- /dev/null
+++ b/ot/semidiscrete.py
@@ -0,0 +1,329 @@
+# -*- coding: utf-8 -*-
+"""
+Semi-discrete optimal transport: continuous source, discrete target.
+
+Backend-agnostic semi-dual solver based on the Projected Averaged SGD
+of [1]_, with an optional decreasing entropic regularization schedule
+(DRAG, [2]_). Works with any backend supported by :mod:`ot.backend`
+(NumPy, PyTorch, JAX, CuPy, TensorFlow).
+
+References
+----------
+.. [1] Genans, Godichon-Baggioni, Vialard, Wintenberger (2025).
+    "Stochastic Optimization in Semi-Discrete Optimal Transport:
+    Convergence Analysis and Minimax Rate." NeurIPS 2025.
+.. [2] Genans, Godichon-Baggioni, Vialard, Wintenberger (2025).
+    "Decreasing Entropic Regularization Averaged Gradient for
+    Semi-Discrete Optimal Transport." NeurIPS 2025.
+"""
+
+# Author: Ferdinand Genans
+#
+# License: MIT License
+
+import math
+
+import numpy as np
+
+from .backend import get_backend
+
+
+def _quadratic_cost(x, y, nx):
+    r"""Default cost: :math:`\tfrac{1}{2} \|x - y\|^2`.
+
+    Expands the square as ``|x|^2 + |y|^2 - 2 <x, y>`` and returns the
+    ``(n_samples, n_atoms)`` matrix for the rows of ``x`` and ``y``.
+    """
+    x_sq = nx.sum(x**2, axis=1)[:, None]
+    y_sq = nx.sum(y**2, axis=1)[None, :]
+    cross = nx.einsum("ij,kj->ik", x, y)
+    return 0.5 * (x_sq + y_sq - 2.0 * cross)
+
+
+def _setup(target_positions, target_weights, cost):
+    """Resolve backend, default weights and default cost.
+
+    Returns ``(nx, m, target_weights, log_weights, cost)``: the backend of
+    ``target_positions``, the number of atoms, the weights (uniform ``1 / m``
+    when ``None``), their elementwise log, and the cost callable
+    (``0.5 * ||x - y||^2`` when ``None``).
+    """
+    nx = get_backend(target_positions)
+    m = target_positions.shape[0]
+    if target_weights is None:
+        target_weights = nx.full((m,), 1.0 / m, type_as=target_positions)
+    if cost is None:
+
+        def cost(x, y):
+            return _quadratic_cost(x, y, nx)
+
+    return nx, m, target_weights, nx.log(target_weights), cost
+
+
+def _atom_weights(score, reg, log_b, nx):
+    """Row-stochastic weights ``(batch, m)`` from ``score = g - C``.
+
+    Softmax of ``score / reg + log_b`` when ``reg > 0``, one-hot of
+    ``argmax(score, axis=1)`` otherwise (``log_b`` is then unused).
+    """
+    if reg > 0:
+        log_w = score / reg + log_b[None, :]
+        log_w = log_w - nx.logsumexp(log_w, axis=1)[:, None]
+        return nx.exp(log_w)
+    m = score.shape[1]
+    idx = nx.argmax(score, axis=1)
+    arange_m = nx.from_numpy(np.arange(m), type_as=score)
+    mask = idx[:, None] == arange_m[None, :]
+    one = nx.full((1,), 1.0, type_as=score)
+    zero = nx.full((1,), 0.0, type_as=score)
+    return nx.where(mask, one, zero)
+
+
+def atom_weights(
+    target_positions,
+    source_samples,
+    semi_dual_potential,
+    target_weights=None,
+    cost=None,
+    reg=0.0,
+):
+    r"""Row-stochastic atom-assignment weights induced by ``semi_dual_potential``.
+
+    Parameters
+    ----------
+    target_positions : array-like, shape (n_atoms, d)
+        Positions of the target atoms; its backend drives the computation.
+    source_samples : array-like, shape (n_samples, d)
+        Source points to assign.
+    semi_dual_potential : array-like, shape (n_atoms,)
+        Potential :math:`g`, e.g. as returned by :func:`solve_semidiscrete`.
+    target_weights : array-like, shape (n_atoms,), optional
+        Atom weights :math:`b`. Defaults to uniform.
+    cost : callable, optional
+        ``cost(x, y)`` returns the ``(n_samples, n_atoms)`` cost matrix.
+        Defaults to ``0.5 * ||x - y||^2``.
+    reg : float, default=0.0
+        Entropic regularization. ``reg > 0`` gives the softmax of
+        ``(g - C) / reg + log(b)``; otherwise each row is the one-hot
+        indicator of ``argmax_j (g_j - C_ij)``.
+
+    Returns
+    -------
+    w : array, shape (n_samples, n_atoms)
+        ``w[i, j]`` is the (entropic) probability that sample ``x_i`` is
+        transported to atom ``y_j``.
+    """
+    nx, _, _, log_b, cost_fn = _setup(target_positions, target_weights, cost)
+    score = semi_dual_potential[None, :] - cost_fn(source_samples, target_positions)
+    return _atom_weights(score, reg, log_b, nx)
+
+
+def ot_map(
+    target_positions,
+    source_samples,
+    semi_dual_potential,
+    target_weights=None,
+    cost=None,
+    reg=0.0,
+):
+    r"""Transport map :math:`T(x) = \sum_j w_j(x)\, y_j` induced by the potential.
+
+    ``w`` is computed by :func:`atom_weights` with the same arguments, so the
+    result has shape ``(n_samples, d)``: with the default ``reg=0.0`` each
+    sample is sent to a single atom, with ``reg > 0`` to a weighted average
+    of the atoms.
+    """
+    w = atom_weights(
+        target_positions,
+        source_samples,
+        semi_dual_potential,
+        target_weights=target_weights,
+        cost=cost,
+        reg=reg,
+    )
+    return w @ target_positions
+
+
+def c_transform(
+    target_positions,
+    source_samples,
+    semi_dual_potential,
+    target_weights=None,
+    cost=None,
+    reg=0.0,
+):
+    r"""Pointwise (entropic) c-transform of ``semi_dual_potential``.
+
+    - ``reg > 0``: :math:`\varphi_g(x) = -\varepsilon \log \sum_j b_j
+      \exp\!\big((g_j - c(x, y_j))/\varepsilon\big)`.
+    - otherwise: :math:`\varphi_g(x) = \min_j\, c(x, y_j) - g_j`.
+
+    Takes the same arguments as :func:`atom_weights` and returns an array of
+    shape ``(n_samples,)``.
+    """
+    nx, _, _, log_b, cost_fn = _setup(target_positions, target_weights, cost)
+    score = semi_dual_potential[None, :] - cost_fn(source_samples, target_positions)
+    # Branch exactly like ``_atom_weights`` so that both helpers agree on
+    # which regime a given ``reg`` selects.
+    if reg > 0:
+        return -reg * nx.logsumexp(score / reg + log_b[None, :], axis=1)
+    return -nx.max(score, axis=1)
+
+
+def solve_semidiscrete(
+    target_positions,
+    source_sampler,
+    target_weights=None,
+    cost=None,
+    reg=0.0,
+    n_iter=10_000,
+    batch_size=32,
+    lr0=None,
+    lr_exponent=2.0 / 3.0,
+    init_potential=None,
+    decreasing_reg=True,
+    decreasing_reg_initial_eps=0.1,
+    decreasing_reg_exponent=0.5,
+    max_cost=None,
+    polyak_average=True,
+    log=False,
+):
+    r"""Solve semi-discrete OT by Polyak-averaged SGD on the semi-dual.
+
+    Maximizes the semi-dual :math:`g \mapsto \langle g, b \rangle + \mathbb{E}_X[\varphi_g(X)]`
+    by averaged stochastic gradient ascent, with optional projection
+    (``max_cost``) and, by default, the decreasing regularization of the
+    DRAG algorithm [1]_.
+    Here :math:`\varphi_g` denotes the (entropic) c-transform of :math:`g`,
+
+    .. math::
+        \varphi_g(x) = \begin{cases}
+        \min_j \big(c(x, y_j) - g_j\big) & \text{if } \mathrm{reg} = 0, \\
+        -\varepsilon \log \sum_j b_j \exp\!\big((g_j - c(x, y_j))/\varepsilon\big)
+        & \text{if } \mathrm{reg} = \varepsilon > 0,
+        \end{cases}
+
+    cf. :func:`c_transform`.
+
+    With ``decreasing_reg=True`` the regularization at iteration ``t`` is
+    :math:`\varepsilon_t = \max(\text{reg},\, \varepsilon_0 / t^\alpha)` — large
+    at first for smoothness, then annealed towards ``reg``. This is the
+    DRAG schedule of [1]_.
+
+    Parameters
+    ----------
+    target_positions : array-like, shape (n_atoms, d)
+        Positions of the target atoms. The backend of this array drives
+        all subsequent computations.
+    source_sampler : callable
+        ``source_sampler(batch_size)`` returns a ``(batch_size, d)`` array
+        of source samples, in the same backend as ``target_positions``.
+    target_weights : array-like, shape (n_atoms,), optional
+        Atom weights. Defaults to uniform.
+    cost : callable, optional
+        ``cost(x, y)`` returns the ``(n_samples, n_atoms)`` cost matrix.
+        Defaults to ``0.5 * ||x - y||^2``.
+    reg : float, default=0.0
+        Entropic regularization (target value when ``decreasing_reg=True``).
+    n_iter : int, default=10000
+        Number of SGD iterations. With ``n_iter < 1`` no step is taken and
+        the starting iterate is returned.
+    batch_size : int, default=32
+        Number of source samples requested from ``source_sampler`` at each
+        iteration.
+    lr0 : float, optional
+        Initial learning rate. Defaults to ``sqrt(n_atoms * batch_size)``.
+    lr_exponent : float, default=2/3
+        Step size decays as ``lr0 / t**lr_exponent``.
+    init_potential : array-like, shape (n_atoms,), optional
+        Starting iterate; defaults to zero. Not mutated.
+    decreasing_reg : bool, default=True
+        Enable the DRAG decreasing-regularization schedule.
+    decreasing_reg_initial_eps : float, default=0.1
+        Initial regularization in the DRAG schedule.
+    decreasing_reg_exponent : float, default=0.5
+        Decay exponent of the DRAG schedule.
+    max_cost : float, optional
+        If given, clip each iterate to ``[-max_cost, max_cost]``.
+    polyak_average : bool, default=True
+        If True, return the uniform average of the iterates; else the last.
+    log : bool, default=False
+        If True, also return a small ``dict`` with the last iterate.
+
+    Returns
+    -------
+    semi_dual_potential : array, shape (n_atoms,)
+    info : dict, optional
+        Returned only when ``log=True``.
+
+    References
+    ----------
+    .. [1] Genans, Godichon-Baggioni, Vialard, Wintenberger (2025).
+        "Decreasing Entropic Regularization Averaged Gradient for
+        Semi-Discrete Optimal Transport." NeurIPS 2025.
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> from ot.semidiscrete import solve_semidiscrete
+    >>> rng = np.random.default_rng(0)
+    >>> target = np.linspace(0.0, 1.0, 10).reshape(-1, 1)
+    >>> g = solve_semidiscrete(
+    ...     target, lambda b: rng.random((b, 1)),
+    ...     n_iter=500, batch_size=32, max_cost=1.0,
+    ... )
+    """
+    nx, m, b, log_b, cost_fn = _setup(target_positions, target_weights, cost)
+
+    if init_potential is None:
+        g = nx.zeros((m,), type_as=target_positions)
+    else:
+        # Adding zeros builds a new array, so later updates never touch the
+        # caller's ``init_potential``.
+        g = init_potential + nx.zeros((m,), type_as=target_positions)
+
+    if lr0 is None:
+        lr0 = math.sqrt(m * batch_size)
+
+    g_avg = nx.zeros((m,), type_as=target_positions) if polyak_average else None
+
+    for t in range(1, n_iter + 1):
+        if decreasing_reg:
+            reg_t = max(reg, decreasing_reg_initial_eps / (t**decreasing_reg_exponent))
+        else:
+            reg_t = reg
+
+        x = source_sampler(batch_size)
+        score = g[None, :] - cost_fn(x, target_positions)
+        w = _atom_weights(score, reg_t, log_b, nx)
+        # ``grad`` estimates minus the semi-dual gradient, hence the descent step.
+        grad = nx.mean(w, axis=0) - b
+
+        lr_t = lr0 / (t**lr_exponent)
+        g = g - lr_t * grad
+        if max_cost is not None:
+            g = nx.clip(g, -max_cost, max_cost)
+        if polyak_average:
+            # Running uniform mean of the iterates produced so far.
+            g_avg = g_avg + (g - g_avg) / t
+
+    # With no iteration the running average is still its all-zero seed; return
+    # the starting iterate instead so that ``init_potential`` is not discarded.
+    result = g_avg if (polyak_average and n_iter >= 1) else g
+    if log:
+        return result, {
+            "n_iter": n_iter,
+            "batch_size": batch_size,
+            "max_cost": max_cost,
+            "polyak_average": polyak_average,
+            "last_potential": g,
+        }
+    return result
+
+
+__all__ = [
+    "solve_semidiscrete",
+    "atom_weights",
+    "ot_map",
+    "c_transform",
+]
diff --git a/test/test_semidiscrete.py b/test/test_semidiscrete.py
new file mode 100644
index 000000000..3ff048a1b
--- /dev/null
+++ b/test/test_semidiscrete.py
@@ -0,0 +1,397 @@
+# -*- coding: utf-8 -*-
+"""Tests for ``ot/semidiscrete.py``.
+
+We rely on three small toy problems whose optimal semi-dual potential is
+known (by symmetry, as a stored regression target, or in closed form) and
+check convergence to it. All tests run across every POT backend (via ``nx``).
+"""
+
+# License: MIT License
+
+import numpy as np
+import pytest
+
+from ot.semidiscrete import (
+    atom_weights,
+    c_transform,
+    ot_map,
+    solve_semidiscrete,
+)
+
+N_ITER = 2_000
+BATCH_SIZE = 16
+TOLERANCE = 0.05
+
+
+# ---------------------------------------------------------------------
+# Three toy problems with known optimal potentials.
+# Each builder returns numpy arrays so we can lift them onto any backend.
+# ---------------------------------------------------------------------
+
+
+def regular_grid_problem():
+    """Uniform source on ``[0, 1]^3``, 10 target atoms regularly placed on axis 0.
+
+    By symmetry the optimal centred potential is zero on every atom.
+    """
+    m, d = 10, 3
+    target = np.zeros((m, d))
+    target[:, 0] = (np.arange(m) + 0.5) / m
+    target[:, 1:] = 0.5
+    weights = np.full(m, 1.0 / m)
+    optimal = np.zeros(m)
+    return target, weights, optimal, 1.0, d, "uniform_cube"
+
+
+def nonuniform_weights_problem():
+    """Uniform source on ``[0, 1]^3``, fixed support, nonuniform weights.
+
+    The weights were computed by Monte Carlo so that the optimal centred
+    potential is exactly ``optimal_potential`` (used as a regression target).
+    """
+    d = 3
+    target = np.array(
+        [
+            [0.54488318, 0.4236548, 0.64589411],
+            [0.77815675, 0.87001215, 0.97861834],
+            [0.11827443, 0.63992102, 0.14335329],
+            [0.94466892, 0.52184832, 0.41466194],
+        ]
+    )
+    weights = np.array([0.3806643, 0.012264, 0.5503486, 0.0567231])
+    optimal_potential = np.array([0.77423369, 0.61209572, 0.94374808, 0.6818203])
+    return target, weights, optimal_potential, 1.0, d, "uniform_cube"
+
+
+def shifted_1d_problem():
+    r"""Uniform source on ``[delta, 1 + delta]``, atoms at ``k/m`` for ``k = 1..m``.
+
+    For this 1D problem with uniform target weights the optimal potential
+    has the closed form
+
+    .. math::
+        g^*_j = j \left( \frac{1}{2 m^2} - \frac{\delta}{m} \right),
+        \qquad j = 0, \dots, m - 1.
+
+    See Appendix E.1 in "Stochastic Optimization in Semi-Discrete Optimal
+    Transport: Convergence Analysis and Minimax Rate", Genans et al.
+    (NeurIPS 2025).
+    """
+    m, delta = 10, 0.5
+    target = np.linspace(1 / m, 1.0, m).reshape(m, 1)
+    weights = np.full(m, 1.0 / m)
+    optimal = np.zeros(m)
+    for j in range(1, m):
+        optimal[j] = optimal[j - 1] + 1 / (2 * m * m) - delta / m
+    return target, weights, optimal, 1.0 + delta, 1, ("shifted_1d", delta)
+
+
+ALL_PROBLEMS = [
+    pytest.param(regular_grid_problem, id="regular_grid"),
+    pytest.param(nonuniform_weights_problem, id="nonuniform_weights"),
+    pytest.param(shifted_1d_problem, id="shifted_1d"),
+]
+
+
+# ---------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------
+
+
+def make_sampler(kind, d, nx, target):
+    """Return a backend-aware sampler with a fixed numpy RNG inside."""
+    rng = np.random.default_rng(0)
+    if kind == "uniform_cube":
+
+        def sampler(b):
+            return nx.from_numpy(rng.random((b, d)), type_as=target)
+
+    else:
+        # ("shifted_1d", delta)
+        delta = kind[1]
+
+        def sampler(b):
+            return nx.from_numpy(delta + rng.random((b, d)), type_as=target)
+
+    return sampler
+
+
+def centered_l2_error(estimated, reference):
+    """L2 distance between the centred (mean-zero) potentials.
+
+    The semi-dual is invariant to a global additive constant, so we factor that out.
+    """
+    estimated = estimated - estimated.mean()
+    reference = reference - reference.mean()
+    return float(np.linalg.norm(estimated - reference))
+
+
+def lift(nx, target_np, weights_np):
+    """Move numpy ``target`` and ``weights`` arrays onto backend ``nx``."""
+    target = nx.from_numpy(target_np)
+    weights = nx.from_numpy(weights_np, type_as=target)
+    return target, weights
+
+
+# ---------------------------------------------------------------------
+# Convergence on every backend
+# ---------------------------------------------------------------------
+
+
+@pytest.mark.parametrize("build_problem", ALL_PROBLEMS)
+def test_solve_converges(nx, build_problem):
+    """Averaged SGD (no decreasing reg) reaches the optimum on every toy problem."""
+    target_np, weights_np, optimal, max_cost, d, kind = build_problem()
+    target, weights = lift(nx, target_np, weights_np)
+    sampler = make_sampler(kind, d, nx, target)
+
+    g = solve_semidiscrete(
+        target,
+        sampler,
+        target_weights=weights,
+        n_iter=N_ITER,
+        batch_size=BATCH_SIZE,
+        decreasing_reg=False,
+        max_cost=max_cost,
+    )
+    err = centered_l2_error(nx.to_numpy(g), optimal)
+    assert err < TOLERANCE, f"err={err:.4f}"
+
+
+@pytest.mark.parametrize("build_problem", ALL_PROBLEMS)
+def test_drag_converges(nx, build_problem):
+    """DRAG (decreasing entropic reg) reaches the optimum on every toy problem."""
+    target_np, weights_np, optimal, max_cost, d, kind = build_problem()
+    target, weights = lift(nx, target_np, weights_np)
+    sampler = make_sampler(kind, d, nx, target)
+
+    g = solve_semidiscrete(
+        target,
+        sampler,
+        target_weights=weights,
+        n_iter=N_ITER,
+        batch_size=BATCH_SIZE,
+        decreasing_reg=True,
+        max_cost=max_cost,
+    )
+    err = centered_l2_error(nx.to_numpy(g), optimal)
+    assert err < TOLERANCE, f"err={err:.4f}"
+
+
+# ---------------------------------------------------------------------
+# Entropic regime
+# ---------------------------------------------------------------------
+
+
+def test_entropic_solver_runs(nx):
+    """In the entropic regime, the solver and ``c_transform`` produce finite values."""
+    target_np, weights_np, _, max_cost, d, kind = regular_grid_problem()
+    target, weights = lift(nx, target_np, weights_np)
+    sampler = make_sampler(kind, d, nx, target)
+
+    g = solve_semidiscrete(
+        target,
+        sampler,
+        target_weights=weights,
+        reg=0.05,
+        n_iter=N_ITER,
+        batch_size=BATCH_SIZE,
+        max_cost=max_cost,
+    )
+    samples = sampler(64)
+    phi = c_transform(target, samples, g, target_weights=weights, reg=0.05)
+    assert np.isfinite(nx.to_numpy(g)).all()
+    assert np.isfinite(nx.to_numpy(phi)).all()
+
+
+# ---------------------------------------------------------------------
+# Custom cost
+# ---------------------------------------------------------------------
+
+
+def test_custom_quadratic_cost_matches_default(nx):
+    """A user-supplied quadratic cost reaches the same optimum as the default."""
+    target_np, weights_np, optimal, max_cost, d, kind = regular_grid_problem()
+    target, weights = lift(nx, target_np, weights_np)
+    sampler = make_sampler(kind, d, nx, target)
+
+    def cost(x, y):
+        diff = x[:, None, :] - y[None, :, :]
+        return 0.5 * nx.sum(diff**2, axis=2)
+
+    g = solve_semidiscrete(
+        target,
+        sampler,
+        target_weights=weights,
+        cost=cost,
+        n_iter=N_ITER,
+        batch_size=BATCH_SIZE,
+        max_cost=max_cost,
+    )
+    err = centered_l2_error(nx.to_numpy(g), optimal)
+    assert err < TOLERANCE, f"err={err:.4f}"
+
+
+# ---------------------------------------------------------------------
+# Helper functions
+# ---------------------------------------------------------------------
+
+
+@pytest.mark.parametrize("reg", [0.0, 0.1])
+def test_atom_weights_are_row_stochastic(nx, reg):
+    """``atom_weights`` returns nonnegative weights that sum to 1 per row."""
+    target_np, weights_np, _, _, d, kind = nonuniform_weights_problem()
+    target, weights = lift(nx, target_np, weights_np)
+    sampler = make_sampler(kind, d, nx, target)
+    samples = sampler(32)
+    g = nx.zeros((target_np.shape[0],), type_as=target)
+    w = atom_weights(target, samples, g, target_weights=weights, reg=reg)
+    w_np = nx.to_numpy(w)
+    assert w_np.shape == (32, target_np.shape[0])
+    assert (w_np >= 0).all()
+    np.testing.assert_allclose(w_np.sum(axis=1), 1.0, atol=1e-10)
+
+
+def test_ot_map_shape_and_finiteness(nx):
+    """``ot_map`` returns finite values with the source-sample shape."""
+    target_np, weights_np, _, _, d, kind = regular_grid_problem()
+    target, weights = lift(nx, target_np, weights_np)
+    sampler = make_sampler(kind, d, nx, target)
+    samples = sampler(16)
+    g = nx.zeros((target_np.shape[0],), type_as=target)
+    transported = ot_map(target, samples, g, target_weights=weights)
+    transported_np = nx.to_numpy(transported)
+    samples_np = nx.to_numpy(samples)
+    assert transported_np.shape == samples_np.shape
+    assert np.isfinite(transported_np).all()
+
+
+def test_c_transform_minimum_for_zero_potential(nx):
+    """At ``g = 0``, ``phi_g(x) = -max_j(-c(x, y_j)) = min_j c(x, y_j)``."""
+    target_np, weights_np, _, _, d, kind = regular_grid_problem()
+    target, weights = lift(nx, target_np, weights_np)
+    sampler = make_sampler(kind, d, nx, target)
+    samples = sampler(8)
+    g = nx.zeros((target_np.shape[0],), type_as=target)
+    phi = c_transform(target, samples, g, target_weights=weights)
+    samples_np = nx.to_numpy(samples)
+    target_np = nx.to_numpy(target)
+    diff = samples_np[:, None, :] - target_np[None, :, :]
+    expected = 0.5 * (diff**2).sum(axis=2).min(axis=1)
+    np.testing.assert_allclose(nx.to_numpy(phi), expected, atol=1e-10)
+
+
+# ---------------------------------------------------------------------
+# Solver options
+# ---------------------------------------------------------------------
+
+
+def test_warm_start_converges(nx):
+    """Splitting one run into two warm-started halves still converges."""
+    target_np, weights_np, optimal, max_cost, d, kind = regular_grid_problem()
+    target, weights = lift(nx, target_np, weights_np)
+    sampler = make_sampler(kind, d, nx, target)
+
+    half = solve_semidiscrete(
+        target,
+        sampler,
+        target_weights=weights,
+        n_iter=N_ITER // 2,
+        batch_size=BATCH_SIZE,
+        max_cost=max_cost,
+    )
+    g = solve_semidiscrete(
+        target,
+        sampler,
+        target_weights=weights,
+        n_iter=N_ITER // 2,
+        batch_size=BATCH_SIZE,
+        init_potential=half,
+        max_cost=max_cost,
+    )
+    err = centered_l2_error(nx.to_numpy(g), optimal)
+    assert err < TOLERANCE, f"err={err:.4f}"
+
+
+def test_init_potential_is_not_mutated(nx):
+    """The ``init_potential`` array passed by the caller is left intact."""
+    target_np, weights_np, _, max_cost, d, kind = regular_grid_problem()
+    target, weights = lift(nx, target_np, weights_np)
+    sampler = make_sampler(kind, d, nx, target)
+
+    init_np = np.full(target_np.shape[0], 0.5)
+    init = nx.from_numpy(init_np, type_as=target)
+    snapshot = nx.to_numpy(init).copy()
+
+    solve_semidiscrete(
+        target,
+        sampler,
+        target_weights=weights,
+        init_potential=init,
+        n_iter=10,
+        batch_size=4,
+        max_cost=max_cost,
+    )
+    np.testing.assert_array_equal(nx.to_numpy(init), snapshot)
+
+
+def test_projection_clamps_last_iterate(nx):
+    """With ``max_cost=b``, every coordinate of the last iterate lies in ``[-b, b]``."""
+    target_np, weights_np, _, _, d, kind = regular_grid_problem()
+    target, weights = lift(nx, target_np, weights_np)
+    sampler = make_sampler(kind, d, nx, target)
+
+    bound = 0.05
+    _, info = solve_semidiscrete(
+        target,
+        sampler,
+        target_weights=weights,
+        n_iter=300,
+        batch_size=4,
+        max_cost=bound,
+        log=True,
+    )
+    last = nx.to_numpy(info["last_potential"])
+    assert np.abs(last).max() <= bound + 1e-10
+
+
+def test_polyak_average_off_returns_last_iterate(nx):
+    """With ``polyak_average=False`` the returned potential equals the last iterate."""
+    target_np, weights_np, _, max_cost, d, kind = regular_grid_problem()
+    target, weights = lift(nx, target_np, weights_np)
+    sampler = make_sampler(kind, d, nx, target)
+
+    final, info = solve_semidiscrete(
+        target,
+        sampler,
+        target_weights=weights,
+        n_iter=20,
+        batch_size=4,
+        polyak_average=False,
+        max_cost=max_cost,
+        log=True,
+    )
+    np.testing.assert_array_equal(
+        nx.to_numpy(final), nx.to_numpy(info["last_potential"])
+    )
+
+
+def test_log_returns_metadata(nx):
+    """``log=True`` returns an info dict with the expected fields."""
+    target_np, weights_np, _, max_cost, d, kind = regular_grid_problem()
+    target, weights = lift(nx, target_np, weights_np)
+    sampler = make_sampler(kind, d, nx, target)
+
+    g, info = solve_semidiscrete(
+        target,
+        sampler,
+        target_weights=weights,
+        n_iter=50,
+        batch_size=4,
+        max_cost=max_cost,
+        log=True,
+    )
+    assert nx.to_numpy(g).shape == (target_np.shape[0],)
+    assert info["n_iter"] == 50
+    assert info["batch_size"] == 4
+    assert info["max_cost"] == max_cost