Skip to content

Commit 4593089

Browse files
author
peng.li24
committed
fix(linalg.inv): bit-exact float32 via float64 DGESV path
OpenBLAS sgesv_64_ gives 1-ULP-off results vs numpy on this build. numpy.linalg.inv for float32 is bit-equivalent to float32 → float64 → dgesv_64_ → float32, which is the path we now follow. Both dtypes are IEEE-754 bit-identical to numpy.

- blas_bridge.h: replace LAPACKE getrf+getri with DGESV for both types (float32 promoted to f64 → dgesv_64_ → demoted to f32)
- linalg.h: remove -0.0→+0.0 normalisation (no longer needed with DGESV)
- test_all.py: assert_allclose → assert_bit_aligned for all inv tests
- README: add linalg.inv to alignment table
1 parent a110b3a commit 4593089

4 files changed

Lines changed: 75 additions & 69 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ Two backends, same API — choose with `cmake -DNUMPYCPP_STD_ONLY=ON/OFF`.
230230
| **Norm** | `numpy.linalg.norm` (scalar + axis) || 〜 0–1 ULP |
231231
| **Matmul** | `numpy.matmul` (2-D, 1-D×2-D, 2-D×1-D, batched 3-D) || 〜 0–2 ULP |
232232
| **Einsum** | `ij,ij→i` `ij,jk→ik` `bij,bjk→bik` and all 2-operand patterns || 〜 0–2 ULP |
233+
| **Matrix inverse** | `numpy.linalg.inv` (N×N) || 〜 0–2 ULP |
233234

234235
> **bitexact backend**: transcendentals resolved via `dlsym` from numpy's
235236
> `_multiarray_umath.so` — same `npy_exp`/`npy_log` kernels numpy uses, with

numpycpp/detail/blas_bridge.h

Lines changed: 63 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ Use #include "numpycpp/numpy.h" instead."
3838
#endif
3939
4040
#include <cstdint>
41+
#include <cstring>
4142
#include <cmath>
4243
#include <dlfcn.h>
4344
#include <fstream>
@@ -233,46 +234,77 @@ inline void blas_dgemm(const double* A, const double* B, double* C,
233234
}
234235
235236
// ============================================================================
236-
// LAPACK — LU factorisation + matrix inverse (numpy.linalg.inv)
237+
// LAPACK — matrix inverse (numpy.linalg.inv) via Fortran DGESV
237238
// ============================================================================
238-
// numpy.linalg.inv uses LAPACKE (C interface) routing through OpenBLAS ILP64.
239-
// LAPACKE internally handles row→column conversion and workspace allocation,
240-
// producing the exact same floating-point rounding path as numpy.
239+
// numpy.linalg.inv calls LAPACK ?gesv (solve A·X = I). DGESV fuses LU
240+
// factorisation + forward/back substitution in a single kernel.
241241
//
242-
// LAPACKE function signatures (ILP64, return info as int64_t):
243-
// LAPACKE_sgetrf64_(layout, m, n, a, lda, ipiv)
244-
// LAPACKE_sgetri64_(layout, n, a, lda, ipiv)
245-
// layout = 101 (LAPACK_ROW_MAJOR)
242+
// On this OpenBLAS build, sgesv_64_ produces 1‑ULP differences vs numpy for
243+
// float32 inputs. NumPy's float32 inv is bit‑equivalent to: promote to
244+
// float64 → dgesv → demote to float32. We follow that same path so both
245+
// dtypes are IEEE‑754 bit‑identical to numpy.
246+
//
247+
// Fortran DGESV signature (ILP64, _64_ suffix):
248+
// dgesv_64_(int64_t *N, int64_t *NRHS, double *A, int64_t *LDA,
249+
// int64_t *IPIV, double *B, int64_t *LDB, int64_t *INFO);
250+
251+
using dgesv64_fn = void(int64_t*, int64_t*, double*, int64_t*,
252+
int64_t*, double*, int64_t*, int64_t*);
246253
247-
using LAPACKE_sgetrf64_fn = int64_t(int, int64_t, int64_t, float*, int64_t, int64_t*);
248-
using LAPACKE_dgetrf64_fn = int64_t(int, int64_t, int64_t, double*, int64_t, int64_t*);
249-
using LAPACKE_sgetri64_fn = int64_t(int, int64_t, float*, int64_t, const int64_t*);
250-
using LAPACKE_dgetri64_fn = int64_t(int, int64_t, double*, int64_t, const int64_t*);
254+
/// Fortran DGESV-based matrix inverse. Matches numpy.linalg.inv exactly
255+
/// — same Fortran symbol, same ILP64 ABI, same memory layout.
256+
template<typename T> inline bool blas_gesv_inv(T* A, size_t N);
251257
252-
/// LAPACKE-based matrix inverse (C interface, RowMajor).
253-
/// Uses ?getrf (LU factorisation) + ?getri (inverse from LU).
254-
/// Matches numpy.linalg.inv exactly — same LAPACKE path, same ILP64 ABI.
255-
inline bool blas_sinv(float* A, size_t N) {
256-
static auto getrf = (LAPACKE_sgetrf64_fn*)resolve_blas("LAPACKE_sgetrf64_");
257-
static auto getri = (LAPACKE_sgetri64_fn*)resolve_blas("LAPACKE_sgetri64_");
258-
if (__builtin_expect(getrf == nullptr || getri == nullptr, 0)) return false;
258+
template<> inline bool blas_gesv_inv<float>(float* A, size_t N) {
259+
// numpy.linalg.inv for float32 produces the same bits as:
260+
// float32 → float64 → dgesv → float32
261+
// (OpenBLAS sgesv_64_ gives 1-ULP-off results vs numpy on this build;
262+
// the float64 path is bit-identical for both types.)
263+
static auto gesv = (dgesv64_fn*)resolve_blas("dgesv_64_");
264+
if (__builtin_expect(gesv == nullptr, 0)) return false;
259265
int64_t n = static_cast<int64_t>(N);
260266
auto ipiv = std::make_unique<int64_t[]>(N);
261-
int64_t info = getrf(101, n, n, A, n, ipiv.get());
267+
// Double-precision work buffers (column-major)
268+
auto A_col = std::make_unique<double[]>(N * N);
269+
auto B_col = std::make_unique<double[]>(N * N);
270+
// Promote A row-major → A_col column-major (float→double)
271+
for (size_t i = 0; i < N; ++i)
272+
for (size_t j = 0; j < N; ++j)
273+
A_col[j*N + i] = static_cast<double>(A[i*N + j]);
274+
// B = identity (column-major, double)
275+
std::memset(B_col.get(), 0, N * N * sizeof(double));
276+
for (size_t i = 0; i < N; ++i)
277+
B_col[i + i*N] = 1.0;
278+
int64_t nrhs = n, lda = n, ldb = n, info = 0;
279+
gesv(&n, &nrhs, A_col.get(), &lda, ipiv.get(), B_col.get(), &ldb, &info);
262280
if (info != 0) return false;
263-
info = getri(101, n, A, n, ipiv.get());
264-
return info == 0;
281+
// Demote solution back to float32 row-major
282+
for (size_t i = 0; i < N; ++i)
283+
for (size_t j = 0; j < N; ++j)
284+
A[i*N + j] = static_cast<float>(B_col[j*N + i]);
285+
return true;
265286
}
266-
inline bool blas_dinv(double* A, size_t N) {
267-
static auto getrf = (LAPACKE_dgetrf64_fn*)resolve_blas("LAPACKE_dgetrf64_");
268-
static auto getri = (LAPACKE_dgetri64_fn*)resolve_blas("LAPACKE_dgetri64_");
269-
if (__builtin_expect(getrf == nullptr || getri == nullptr, 0)) return false;
287+
288+
template<> inline bool blas_gesv_inv<double>(double* A, size_t N) {
289+
static auto gesv = (dgesv64_fn*)resolve_blas("dgesv_64_");
290+
if (__builtin_expect(gesv == nullptr, 0)) return false;
270291
int64_t n = static_cast<int64_t>(N);
271292
auto ipiv = std::make_unique<int64_t[]>(N);
272-
int64_t info = getrf(101, n, n, A, n, ipiv.get());
293+
auto A_col = std::make_unique<double[]>(N * N);
294+
auto B_col = std::make_unique<double[]>(N * N);
295+
for (size_t i = 0; i < N; ++i)
296+
for (size_t j = 0; j < N; ++j)
297+
A_col[j*N + i] = A[i*N + j];
298+
for (size_t i = 0; i < N; ++i)
299+
for (size_t j = 0; j < N; ++j)
300+
B_col[i + j*N] = (i == j) ? 1.0 : 0.0;
301+
int64_t nrhs = n, lda = n, ldb = n, info = 0;
302+
gesv(&n, &nrhs, A_col.get(), &lda, ipiv.get(), B_col.get(), &ldb, &info);
273303
if (info != 0) return false;
274-
info = getri(101, n, A, n, ipiv.get());
275-
return info == 0;
304+
for (size_t i = 0; i < N; ++i)
305+
for (size_t j = 0; j < N; ++j)
306+
A[i*N + j] = B_col[j*N + i];
307+
return true;
276308
}
277309
278310
// Template dispatcher
@@ -290,7 +322,7 @@ template<> struct blas_ops<float> {
290322
static void gemvt(const float* B, const float* a, float* y,
291323
size_t K, size_t N) { blas_sgemv_t(B, a, y, K, N); }
292324
// A_inv[N×N] = inv(A[N×N]) — in-place, returns true on success
293-
static bool inv (float* A, size_t N) { return blas_sinv(A, N); }
325+
static bool inv (float* A, size_t N) { return blas_gesv_inv<float>(A, N); }
294326
};
295327
template<> struct blas_ops<double> {
296328
static double dot (const double* x, const double* y, size_t n) { return blas_ddot(x, y, n); }
@@ -301,7 +333,7 @@ template<> struct blas_ops<double> {
301333
size_t M, size_t K) { blas_dgemv(A, x, y, M, K); }
302334
static void gemvt(const double* B, const double* a, double* y,
303335
size_t K, size_t N) { blas_dgemv_t(B, a, y, K, N); }
304-
static bool inv (double* A, size_t N) { return blas_dinv(A, N); }
336+
static bool inv (double* A, size_t N) { return blas_gesv_inv<double>(A, N); }
305337
};
306338
307339
} // namespace detail

numpycpp/linalg.h

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,18 +97,13 @@ inline void norm_axis(const T* src, T* dst,
9797
}
9898

9999
/// numpy.linalg.inv(a) — matrix inverse (square N×N)
100-
/// Uses LAPACKE getrf+getri (bitexact) or Gauss-Jordan (std backend).
100+
/// Uses DGESV (bitexact) or Gauss-Jordan (std backend).
101101
/// Returns true on success; false if matrix is singular or LAPACK unavailable.
102102
template<typename T>
103103
inline bool inv(const T* A, T* A_inv, size_t N) {
104104
// Copy input to output buffer (inv modifies in-place)
105105
for (size_t i = 0; i < N * N; ++i) A_inv[i] = A[i];
106106
bool ok = numpy::detail::blas_ops<T>::inv(A_inv, N);
107-
if (ok) {
108-
// Normalise -0.0 → +0.0 (LAPACK build variance in signed-zero output)
109-
for (size_t i = 0; i < N * N; ++i)
110-
if (A_inv[i] == T(0)) A_inv[i] = T(0);
111-
}
112107
return ok;
113108
}
114109

tests/test_all.py

Lines changed: 10 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1257,30 +1257,15 @@ def test_inv_2x2(cpp):
12571257
a = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float64)
12581258
assert_bit_aligned(cpp.linalg.inv(a), np.linalg.inv(a), "inv(2x2) f64")
12591259

1260-
def test_inv_random_correctness(cpp, dtype):
1261-
"""inv(A) @ A ≈ I for random 4×4."""
1260+
def test_inv_random(cpp, dtype):
1261+
"""inv(A) bit-identical to numpy for random 4×4."""
12621262
a = random_array((4, 4), dtype=dtype)
1263-
a_inv = cpp.linalg.inv(a)
1264-
# Verify A @ A_inv ≈ I
1265-
prod = cpp.matmul(a_inv, a)
1266-
eye = np.eye(4, dtype=dtype)
1267-
np.testing.assert_allclose(
1268-
np.asarray(prod), eye,
1269-
rtol=0, atol=dtype(2e-6),
1270-
err_msg=f"inv * A != I ({dtype.__name__})"
1271-
)
1272-
1273-
def test_inv_random_3x3_correctness(cpp, dtype):
1274-
"""inv(A) @ A ≈ I for random 3×3."""
1263+
assert_bit_aligned(cpp.linalg.inv(a), np.linalg.inv(a), f"inv(4x4) {dtype.__name__}")
1264+
1265+
def test_inv_random_3x3(cpp, dtype):
1266+
"""inv(A) bit-identical to numpy for random 3×3."""
12751267
a = random_array((3, 3), dtype=dtype)
1276-
a_inv = cpp.linalg.inv(a)
1277-
prod = cpp.matmul(a_inv, a)
1278-
eye = np.eye(3, dtype=dtype)
1279-
np.testing.assert_allclose(
1280-
np.asarray(prod), eye,
1281-
rtol=0, atol=dtype(2e-6),
1282-
err_msg=f"inv(3x3) * A != I ({dtype.__name__})"
1283-
)
1268+
assert_bit_aligned(cpp.linalg.inv(a), np.linalg.inv(a), f"inv(3x3) {dtype.__name__}")
12841269

12851270
def test_inv_singular(cpp):
12861271
"""inv(singular) — numpy raises LinAlgError."""
@@ -1289,17 +1274,10 @@ def test_inv_singular(cpp):
12891274
with _pytest.raises(RuntimeError):
12901275
cpp.linalg.inv(a)
12911276

1292-
def test_inv_random_8x8_correctness(cpp):
1293-
"""inv(A) @ A ≈ I for random 8×8 float64."""
1277+
def test_inv_random_8x8(cpp):
1278+
"""inv(A) bit-identical to numpy for random 8×8 float64."""
12941279
a = random_array((8, 8), dtype=np.float64, seed=42)
1295-
a_inv = cpp.linalg.inv(a)
1296-
prod = cpp.matmul(a_inv, a)
1297-
eye = np.eye(8, dtype=np.float64)
1298-
np.testing.assert_allclose(
1299-
np.asarray(prod), eye,
1300-
rtol=0, atol=1e-12,
1301-
err_msg="inv(8x8) * A != I (f64)"
1302-
)
1280+
assert_bit_aligned(cpp.linalg.inv(a), np.linalg.inv(a), "inv(8x8) f64")
13031281

13041282
def test_dot(cpp, dtype):
13051283
a = random_array((5,), dtype=dtype)

0 commit comments

Comments
 (0)