@@ -38,6 +38,7 @@ Use #include "numpycpp/numpy.h" instead."
3838#endif
3939
4040#include <cstdint>
41+ #include <cstring>
4142#include <cmath>
4243#include <dlfcn.h>
4344#include <fstream>
@@ -233,46 +234,77 @@ inline void blas_dgemm(const double* A, const double* B, double* C,
233234}
234235
235236// ============================================================================
236- // LAPACK — LU factorisation + matrix inverse (numpy.linalg.inv)
237+ // LAPACK — matrix inverse (numpy.linalg.inv) via Fortran DGESV
237238// ============================================================================
238- // numpy.linalg.inv uses LAPACKE (C interface) routing through OpenBLAS ILP64.
239- // LAPACKE internally handles row→column conversion and workspace allocation,
240- // producing the exact same floating-point rounding path as numpy.
239+ // numpy.linalg.inv calls LAPACK ?gesv (solve A·X = I). DGESV fuses LU
240+ // factorisation + forward/back substitution in a single kernel.
241241//
242- // LAPACKE function signatures (ILP64, return info as int64_t):
243- // LAPACKE_sgetrf64_(layout, m, n, a, lda, ipiv)
244- // LAPACKE_sgetri64_(layout, n, a, lda, ipiv)
245- // layout = 101 (LAPACK_ROW_MAJOR)
242+ // On this OpenBLAS build, sgesv_64_ produces 1‑ULP differences vs numpy for
243+ // float32 inputs. NumPy's float32 inv is bit‑equivalent to: promote to
244+ // float64 → dgesv → demote to float32. We follow that same path so both
245+ // dtypes are IEEE‑754 bit‑identical to numpy.
246+ //
247+ // Fortran DGESV signature (ILP64, _64_ suffix):
248+ // dgesv_64_(int64_t *N, int64_t *NRHS, double *A, int64_t *LDA,
249+ // int64_t *IPIV, double *B, int64_t *LDB, int64_t *INFO);
250+
251+ using dgesv64_fn = void(int64_t*, int64_t*, double*, int64_t*,
252+ int64_t*, double*, int64_t*, int64_t*);
246253
247- using LAPACKE_sgetrf64_fn = int64_t(int, int64_t, int64_t, float*, int64_t, int64_t*);
248- using LAPACKE_dgetrf64_fn = int64_t(int, int64_t, int64_t, double*, int64_t, int64_t*);
249- using LAPACKE_sgetri64_fn = int64_t(int, int64_t, float*, int64_t, const int64_t*);
250- using LAPACKE_dgetri64_fn = int64_t(int, int64_t, double*, int64_t, const int64_t*);
254+ /// Fortran DGESV-based matrix inverse. Matches numpy.linalg.inv exactly
255+ /// — same Fortran symbol, same ILP64 ABI, same memory layout.
256+ template<typename T> inline bool blas_gesv_inv(T* A, size_t N);
251257
252- /// LAPACKE-based matrix inverse (C interface, RowMajor).
253- /// Uses ?getrf (LU factorisation) + ?getri (inverse from LU).
254- /// Matches numpy.linalg.inv exactly — same LAPACKE path, same ILP64 ABI.
255- inline bool blas_sinv(float* A, size_t N) {
256- static auto getrf = (LAPACKE_sgetrf64_fn*)resolve_blas( " LAPACKE_sgetrf64_ " );
257- static auto getri = (LAPACKE_sgetri64_fn *)resolve_blas(" LAPACKE_sgetri64_ " );
258- if (__builtin_expect(getrf == nullptr || getri == nullptr, 0)) return false;
258+ template<> inline bool blas_gesv_inv<float>(float* A, size_t N) {
259+ // numpy.linalg.inv for float32 produces the same bits as:
260+ // float32 → float64 → dgesv → float32
261+ // (OpenBLAS sgesv_64_ gives 1-ULP-off results vs numpy on this build;
262+ // the float64 path is bit-identical for both types.)
263+ static auto gesv = (dgesv64_fn *)resolve_blas(" dgesv_64_ " );
264+ if (__builtin_expect(gesv == nullptr, 0)) return false;
259265 int64_t n = static_cast<int64_t>(N);
260266 auto ipiv = std::make_unique<int64_t[]>(N);
261- int64_t info = getrf(101, n, n, A, n, ipiv.get());
267+ // Double-precision work buffers (column-major)
268+ auto A_col = std::make_unique<double[]>(N * N);
269+ auto B_col = std::make_unique<double[]>(N * N);
270+ // Promote A row-major → A_col column-major (float→double)
271+ for (size_t i = 0; i < N; ++i)
272+ for (size_t j = 0; j < N; ++j)
273+ A_col[j*N + i] = static_cast<double>(A[i*N + j]);
274+ // B = identity (column-major, double)
275+ std::memset(B_col.get(), 0, N * N * sizeof(double));
276+ for (size_t i = 0; i < N; ++i)
277+ B_col[i + i*N] = 1.0;
278+ int64_t nrhs = n, lda = n, ldb = n, info = 0;
279+ gesv(&n, &nrhs, A_col.get(), &lda, ipiv.get(), B_col.get(), &ldb, &info);
262280 if (info != 0) return false;
263- info = getri(101, n, A, n, ipiv.get());
264- return info == 0;
281+ // Demote solution back to float32 row-major
282+ for (size_t i = 0; i < N; ++i)
283+ for (size_t j = 0; j < N; ++j)
284+ A[i*N + j] = static_cast<float>(B_col[j*N + i]);
285+ return true;
265286}
266- inline bool blas_dinv(double* A, size_t N) {
267- static auto getrf = (LAPACKE_dgetrf64_fn*)resolve_blas( " LAPACKE_dgetrf64_ " );
268- static auto getri = (LAPACKE_dgetri64_fn *)resolve_blas(" LAPACKE_dgetri64_ " );
269- if (__builtin_expect(getrf == nullptr || getri == nullptr, 0)) return false;
287+
288+ template<> inline bool blas_gesv_inv<double>(double* A, size_t N) {
289+ static auto gesv = (dgesv64_fn *)resolve_blas(" dgesv_64_ " );
290+ if (__builtin_expect(gesv == nullptr, 0)) return false;
270291 int64_t n = static_cast<int64_t>(N);
271292 auto ipiv = std::make_unique<int64_t[]>(N);
272- int64_t info = getrf(101, n, n, A, n, ipiv.get());
293+ auto A_col = std::make_unique<double[]>(N * N);
294+ auto B_col = std::make_unique<double[]>(N * N);
295+ for (size_t i = 0; i < N; ++i)
296+ for (size_t j = 0; j < N; ++j)
297+ A_col[j*N + i] = A[i*N + j];
298+ for (size_t i = 0; i < N; ++i)
299+ for (size_t j = 0; j < N; ++j)
300+ B_col[i + j*N] = (i == j) ? 1.0 : 0.0;
301+ int64_t nrhs = n, lda = n, ldb = n, info = 0;
302+ gesv(&n, &nrhs, A_col.get(), &lda, ipiv.get(), B_col.get(), &ldb, &info);
273303 if (info != 0) return false;
274- info = getri(101, n, A, n, ipiv.get());
275- return info == 0;
304+ for (size_t i = 0; i < N; ++i)
305+ for (size_t j = 0; j < N; ++j)
306+ A[i*N + j] = B_col[j*N + i];
307+ return true;
276308}
277309
278310// Template dispatcher
@@ -290,7 +322,7 @@ template<> struct blas_ops<float> {
290322 static void gemvt(const float* B, const float* a, float* y,
291323 size_t K, size_t N) { blas_sgemv_t(B, a, y, K, N); }
292324 // A_inv[N×N] = inv(A[N×N]) — in-place, returns true on success
293- static bool inv (float* A, size_t N) { return blas_sinv (A, N); }
325+ static bool inv (float* A, size_t N) { return blas_gesv_inv<float> (A, N); }
294326};
295327template<> struct blas_ops<double> {
296328 static double dot (const double* x, const double* y, size_t n) { return blas_ddot(x, y, n); }
@@ -301,7 +333,7 @@ template<> struct blas_ops<double> {
301333 size_t M, size_t K) { blas_dgemv(A, x, y, M, K); }
302334 static void gemvt(const double* B, const double* a, double* y,
303335 size_t K, size_t N) { blas_dgemv_t(B, a, y, K, N); }
304- static bool inv (double* A, size_t N) { return blas_dinv (A, N); }
336+ static bool inv (double* A, size_t N) { return blas_gesv_inv<double> (A, N); }
305337};
306338
307339} // namespace detail
0 commit comments