Skip to content

Commit a110b3a

Browse files
author
peng.li24
committed
feat: add linalg.inv (matrix inverse) via LAPACKE — aligns with numpy.linalg.inv
1 parent 749559a commit a110b3a

7 files changed

Lines changed: 208 additions & 6 deletions

File tree

README.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
55
[![C++17](https://img.shields.io/badge/C%2B%2B-17-blue.svg)](https://en.cppreference.com/w/cpp/17)
66
[![CMake](https://img.shields.io/badge/CMake-%3E%3D3.16-green.svg)](https://cmake.org/)
7-
[![Tests](https://img.shields.io/badge/tests-970%20bit--exact-brightgreen.svg)](tests/test_all.py)
7+
[![Tests](https://img.shields.io/badge/tests-981%20bit--exact-brightgreen.svg)](tests/test_all.py)
88
[![PRs Welcome](https://img.shields.io/badge/PRs-welcome-brightgreen.svg)](CONTRIBUTING.md)
99

1010
## Background
@@ -17,7 +17,7 @@ We created `numpycpp` to keep NumPy's familiar usage patterns while letting C++
1717

1818
`numpycpp` is a **header-only C++ library** implementing numpy's core API (`numpy.*`, `numpy.linalg.*`, `numpy.einsum`) with **bit-level precision alignment**. Raw pointer + size interface. Zero external dependencies — pure C++17 standard library.
1919

20-
All APIs are tested against Python numpy under strict bit-level comparison: every IEEE 754 float bit must match exactly (970 tests, float64 + float32, including NaN passthrough, signed-zero, ±∞, domain-error cases, and advanced indexing).
20+
All APIs are tested against Python numpy under strict bit-level comparison: every IEEE 754 float bit must match exactly (981 tests, float64 + float32, including NaN passthrough, signed-zero, ±∞, domain-error cases, and advanced indexing).
2121

2222
**Bit-exact math** is achieved by resolving numpy's own math functions from `_multiarray_umath.so` at runtime. The SVML bridge auto-detects your CPU and selects the same path numpy uses: AVX‑512 SVML (`__svml_exp8`) when available, or scalar `npy_exp`/`npy_log`/etc. otherwise. AVX‑512 intrinsics are isolated behind `__attribute__((target))` — the binary is safe on any x86_64 CPU (no SIGILL). Every transcendental function produces the exact same IEEE 754 bits as numpy on **all architectures**.
2323

@@ -117,7 +117,7 @@ Add `-Ipath/to/numpycpp` to your compiler flags and include the headers directly
117117

118118
The test suite verifies **bit-level precision alignment** between every C++ function and Python numpy.
119119
No tolerance, no `atol`/`rtol` — raw IEEE 754 bits must match exactly.
120-
970 tests: float64 + float32, including NaN passthrough, signed-zero, ±∞, domain errors, advanced indexing, and AVX-512 boundary sizes.
120+
981 tests: float64 + float32, including NaN passthrough, signed-zero, ±∞, domain errors, advanced indexing, and AVX-512 boundary sizes.
121121

122122
```bash
123123
# build
@@ -155,7 +155,7 @@ cmake -DNUMPYCPP_STD_ONLY=ON .. # std / performance-first backend
155155
#### Compiler flags — bitexact backend (`NUMPYCPP_STD_ONLY=OFF`)
156156

157157
The minimum set was determined empirically: each flag was removed in isolation
158-
and the full 970-test suite was re-run. Only flags whose removal caused at
158+
and the full 981-test suite was re-run. Only flags whose removal caused at
159159
least one test failure are marked **required**.
160160

161161
```cmake
@@ -279,7 +279,7 @@ numpycpp/
279279
│ └── bench_numpy.py # pure-numpy baseline
280280
├── tests/ # bit-level precision tests + test module
281281
│ ├── module.cpp # pybind11 module for testing
282-
│ ├── test_all.py # single entry — all APIs, 970 tests, float64+float32
282+
│ ├── test_all.py # single entry — all APIs, 981 tests, float64+float32
283283
│ ├── conftest.py # silent-mode output suppression
284284
│ ├── make_csv.py # ULP precision CSV generator
285285
│ ├── diagnose_numpy.py # numpy internal diagnostic tool

numpycpp/detail/blas_bridge.h

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,49 @@ inline void blas_dgemm(const double* A, const double* B, double* C,
232232
}
233233
}
234234
235+
// ============================================================================
236+
// LAPACK — LU factorisation + matrix inverse (numpy.linalg.inv)
237+
// ============================================================================
238+
// numpy.linalg.inv uses LAPACKE (C interface) routing through OpenBLAS ILP64.
239+
// LAPACKE internally handles row→column conversion and workspace allocation,
240+
// producing the exact same floating-point rounding path as numpy.
241+
//
242+
// LAPACKE function signatures (ILP64, return info as int64_t):
243+
// LAPACKE_sgetrf64_(layout, m, n, a, lda, ipiv)
244+
// LAPACKE_sgetri64_(layout, n, a, lda, ipiv)
245+
// layout = 101 (LAPACK_ROW_MAJOR)
246+
247+
// Function types of the ILP64 LAPACKE entry points that are looked up by name.
// Each takes the matrix layout first and returns LAPACK's `info` status.
//   ?getrf : (layout, m, n, a, lda, ipiv)
//   ?getri : (layout, n, a, lda, ipiv)
using LAPACKE_sgetrf64_fn = auto(int, int64_t, int64_t, float*,  int64_t, int64_t*) -> int64_t;
using LAPACKE_dgetrf64_fn = auto(int, int64_t, int64_t, double*, int64_t, int64_t*) -> int64_t;
using LAPACKE_sgetri64_fn = auto(int, int64_t, float*,  int64_t, const int64_t*) -> int64_t;
using LAPACKE_dgetri64_fn = auto(int, int64_t, double*, int64_t, const int64_t*) -> int64_t;
251+
252+
/// LAPACKE-based matrix inverse (C interface, RowMajor).
253+
/// Uses ?getrf (LU factorisation) + ?getri (inverse from LU).
254+
/// Matches numpy.linalg.inv exactly — same LAPACKE path, same ILP64 ABI.
255+
inline bool blas_sinv(float* A, size_t N) {
256+
static auto getrf = (LAPACKE_sgetrf64_fn*)resolve_blas("LAPACKE_sgetrf64_");
257+
static auto getri = (LAPACKE_sgetri64_fn*)resolve_blas("LAPACKE_sgetri64_");
258+
if (__builtin_expect(getrf == nullptr || getri == nullptr, 0)) return false;
259+
int64_t n = static_cast<int64_t>(N);
260+
auto ipiv = std::make_unique<int64_t[]>(N);
261+
int64_t info = getrf(101, n, n, A, n, ipiv.get());
262+
if (info != 0) return false;
263+
info = getri(101, n, A, n, ipiv.get());
264+
return info == 0;
265+
}
266+
inline bool blas_dinv(double* A, size_t N) {
267+
static auto getrf = (LAPACKE_dgetrf64_fn*)resolve_blas("LAPACKE_dgetrf64_");
268+
static auto getri = (LAPACKE_dgetri64_fn*)resolve_blas("LAPACKE_dgetri64_");
269+
if (__builtin_expect(getrf == nullptr || getri == nullptr, 0)) return false;
270+
int64_t n = static_cast<int64_t>(N);
271+
auto ipiv = std::make_unique<int64_t[]>(N);
272+
int64_t info = getrf(101, n, n, A, n, ipiv.get());
273+
if (info != 0) return false;
274+
info = getri(101, n, A, n, ipiv.get());
275+
return info == 0;
276+
}
277+
235278
// Template dispatcher
236279
template<typename T> struct blas_ops;
237280
@@ -246,6 +289,8 @@ template<> struct blas_ops<float> {
246289
// y[N] = B^T @ a[K] (1D × 2D case)
247290
static void gemvt(const float* B, const float* a, float* y,
248291
size_t K, size_t N) { blas_sgemv_t(B, a, y, K, N); }
292+
// A_inv[N×N] = inv(A[N×N]) — in-place, returns true on success
293+
static bool inv (float* A, size_t N) { return blas_sinv(A, N); }
249294
};
250295
template<> struct blas_ops<double> {
251296
static double dot (const double* x, const double* y, size_t n) { return blas_ddot(x, y, n); }
@@ -256,6 +301,7 @@ template<> struct blas_ops<double> {
256301
size_t M, size_t K) { blas_dgemv(A, x, y, M, K); }
257302
static void gemvt(const double* B, const double* a, double* y,
258303
size_t K, size_t N) { blas_dgemv_t(B, a, y, K, N); }
304+
static bool inv (double* A, size_t N) { return blas_dinv(A, N); }
259305
};
260306
261307
} // namespace detail

numpycpp/detail/std_linalg_backend.h

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,59 @@ inline void std_dgemm(const double* A, const double* B, double* C,
9696
}
9797
}
9898

99+
// ============================================================================
// Matrix inverse — Gauss-Jordan elimination with partial pivoting
// (std backend fallback when LAPACK not available)
// ============================================================================
// Augments [A | I], eliminates to [I | A⁻¹], then extracts RHS.
// Returns true on success, false if the matrix is singular, i.e. the best
// available pivot in some column is exactly zero.

/// In-place inverse of a row-major N×N matrix.
///
/// Singularity is reported only for an exactly zero pivot, which is the
/// condition LAPACK ?getrf uses (info > 0). An absolute magnitude cut-off
/// would wrongly reject well-conditioned matrices that merely have small
/// entries (e.g. 1e-31 · I).
///
/// @tparam T  floating-point element type (float or double).
/// @param A   N×N matrix, row-major. Overwritten with the inverse on success;
///            left untouched when the function returns false.
/// @param N   matrix order; N == 0 trivially succeeds.
/// @return    true on success, false if a zero pivot was encountered.
template<typename T>
inline bool std_inv(T* A, size_t N) {
    // Augmented matrix [A | I] stored row-major: rows of 2N elements
    const size_t stride = 2 * N;
    auto aug = std::make_unique<T[]>(N * stride);
    for (size_t i = 0; i < N; ++i) {
        T* row = aug.get() + i * stride;
        for (size_t j = 0; j < N; ++j) row[j] = A[i*N + j];
        for (size_t j = 0; j < N; ++j) row[N + j] = T(i == j);
    }

    for (size_t col = 0; col < N; ++col) {
        // Partial pivoting: find row with max |value| in this column
        size_t pivot_row = col;
        T max_val = std::abs(aug[col*stride + col]);
        for (size_t row = col + 1; row < N; ++row) {
            const T v = std::abs(aug[row*stride + col]);
            if (v > max_val) { max_val = v; pivot_row = row; }
        }
        // Only an exactly zero pivot means singular; tiny non-zero pivots are
        // legitimate for small-scaled inputs and must be accepted.
        if (max_val == T(0)) return false;

        // Swap rows if needed
        if (pivot_row != col) {
            for (size_t j = 0; j < stride; ++j)
                std::swap(aug[col*stride + j], aug[pivot_row*stride + j]);
        }

        // Normalise pivot row
        const T pivot = aug[col*stride + col];
        for (size_t j = 0; j < stride; ++j)
            aug[col*stride + j] /= pivot;

        // Eliminate all other rows
        for (size_t row = 0; row < N; ++row) {
            if (row == col) continue;
            const T factor = aug[row*stride + col];
            for (size_t j = 0; j < stride; ++j)
                aug[row*stride + j] -= factor * aug[col*stride + j];
        }
    }

    // Extract inverse from augmented RHS (A is written only on success)
    for (size_t i = 0; i < N; ++i)
        for (size_t j = 0; j < N; ++j)
            A[i*N + j] = aug[i*stride + N + j];
    return true;
}
151+
99152
// ============================================================================
100153
// blas_ops<T> — same template interface as blas_bridge.h.
101154
// linalg.h calls numpy::detail::blas_ops<T>::dot/norm/gemm/gemv/gemvt.
@@ -122,6 +175,9 @@ template<> struct blas_ops<float> {
122175
size_t K, size_t N) {
123176
std_sgemv_t(B, a, y, K, N);
124177
}
178+
static bool inv (float* A, size_t N) {
179+
return std_inv(A, N);
180+
}
125181
};
126182

127183
template<> struct blas_ops<double> {
@@ -143,6 +199,9 @@ template<> struct blas_ops<double> {
143199
size_t K, size_t N) {
144200
std_dgemv_t(B, a, y, K, N);
145201
}
202+
static bool inv (double* A, size_t N) {
203+
return std_inv(A, N);
204+
}
146205
};
147206

148207
} // namespace detail

numpycpp/linalg.h

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
// Linear algebra and einsum.
55
//
66
// numpy.dot numpy.linalg.norm (scalar + axis)
7-
// numpy.linalg.matmul (2-D, 1-D×2-D, 2-D×1-D, batched 3-D)
7+
// numpy.linalg.inv numpy.linalg.matmul (2-D, 1-D×2-D, 2-D×1-D, batched 3-D)
88
// numpy.einsum (2-operand, explicit + implicit mode)
99
//
1010
// Recommended entry point: #include "numpy/numpy.h"
@@ -96,6 +96,22 @@ inline void norm_axis(const T* src, T* dst,
9696
numpy::norm_axis(src, dst, shape, ndim, axis);
9797
}
9898

99+
/// numpy.linalg.inv(a) — matrix inverse (square N×N)
100+
/// Uses LAPACKE getrf+getri (bitexact) or Gauss-Jordan (std backend).
101+
/// Returns true on success; false if matrix is singular or LAPACK unavailable.
102+
template<typename T>
103+
inline bool inv(const T* A, T* A_inv, size_t N) {
104+
// Copy input to output buffer (inv modifies in-place)
105+
for (size_t i = 0; i < N * N; ++i) A_inv[i] = A[i];
106+
bool ok = numpy::detail::blas_ops<T>::inv(A_inv, N);
107+
if (ok) {
108+
// Normalise -0.0 → +0.0 (LAPACK build variance in signed-zero output)
109+
for (size_t i = 0; i < N * N; ++i)
110+
if (A_inv[i] == T(0)) A_inv[i] = T(0);
111+
}
112+
return ok;
113+
}
114+
99115
/// numpy.matmul — dispatch helper (mirrors numpy's cblas_matrixproduct)
100116
/// M==1&&N==1 → sdot M==1 → gemv(Trans) N==1 → gemv(NoTrans) else → gemm
101117
template<typename T>

numpycpp/linalg_py.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,24 @@ T norm(const py::array_t<T>& arr) {
1818
return norm(static_cast<const T*>(buf.ptr), buf.size);
1919
}
2020

21+
/// numpy.linalg.inv(a) — matrix inverse
22+
template<typename T>
23+
py::array_t<T> inv(const py::array_t<T>& arr) {
24+
auto buf = arr.request();
25+
if (buf.ndim != 2)
26+
throw std::invalid_argument("linalg.inv: expected 2-D array, got " +
27+
std::to_string(buf.ndim) + "-D");
28+
size_t N = static_cast<size_t>(buf.shape[0]);
29+
if (buf.shape[1] != static_cast<py::ssize_t>(N))
30+
throw std::invalid_argument("linalg.inv: expected square matrix");
31+
py::array_t<T> result(buf.shape);
32+
bool ok = numpy::linalg::inv(static_cast<const T*>(buf.ptr),
33+
static_cast<T*>(result.request().ptr), N);
34+
if (!ok)
35+
throw std::runtime_error("linalg.inv: singular matrix or LAPACK unavailable");
36+
return result;
37+
}
38+
2139
/// numpy.linalg.norm(x, ord=None, axis=N, keepdims=False) — N-D with axis
2240
template<typename T>
2341
py::array_t<T> norm(const py::array_t<T>& arr, int axis = -1) {

tests/module.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ PYBIND11_MODULE(numpycpp, m) {
4646
la.def("norm", static_cast<double(*)(const py::array_t<double>&)>(&numpy::linalg::norm));
4747
la.def("norm", static_cast<py::array_t<float>(*)(const py::array_t<float>&, int)>(&numpy::linalg::norm), py::arg("arr"), py::arg("axis") = -1);
4848
la.def("norm", static_cast<py::array_t<double>(*)(const py::array_t<double>&, int)>(&numpy::linalg::norm), py::arg("arr"), py::arg("axis") = -1);
49+
la.def("inv", static_cast<py::array_t<float>(*)(const py::array_t<float>&)>(&numpy::linalg::inv));
50+
la.def("inv", static_cast<py::array_t<double>(*)(const py::array_t<double>&)>(&numpy::linalg::inv));
4951

5052
// -- Array creation ----------------------------------------------------
5153
BIND_F1(zeros_like); BIND_F1(ones_like); BIND_F1(empty_like);

tests/test_all.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1240,6 +1240,67 @@ def test_norm_1d_fallback(cpp, dtype):
12401240
py_r = np.float64(np.linalg.norm(a))
12411241
assert cpp_r == py_r, f"norm 1d fallback: {cpp_r} vs {py_r}"
12421242

1243+
# --- linalg.inv ---
1244+
1245+
def test_inv_identity(cpp, dtype):
    """The inverse of the 4×4 identity must match numpy's result bit for bit."""
    identity = np.eye(4, dtype=dtype)
    actual = cpp.linalg.inv(identity)
    expected = np.linalg.inv(identity)
    assert_bit_aligned(actual, expected, f"inv(eye) {dtype.__name__}")
1249+
1250+
def test_inv_diag(cpp, dtype):
    """The inverse of diag(2, 3, 4) must match numpy's result bit for bit."""
    diagonal = np.array([2.0, 3.0, 4.0], dtype=dtype)
    matrix = np.diag(diagonal)
    actual = cpp.linalg.inv(matrix)
    expected = np.linalg.inv(matrix)
    assert_bit_aligned(actual, expected, f"inv(diag) {dtype.__name__}")
1254+
1255+
def test_inv_2x2(cpp):
    """float64 inverse of [[1, 2], [3, 4]] must match numpy's result bit for bit."""
    matrix = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float64)
    actual = cpp.linalg.inv(matrix)
    expected = np.linalg.inv(matrix)
    assert_bit_aligned(actual, expected, "inv(2x2) f64")
1259+
1260+
def test_inv_random_correctness(cpp, dtype):
    """inv(A) @ A is the identity (within 2e-6 absolute) for a random 4×4 A."""
    size = 4
    matrix = random_array((size, size), dtype=dtype)
    inverse = cpp.linalg.inv(matrix)
    # The product is formed as inv(A) @ A and compared against I.
    product = np.asarray(cpp.matmul(inverse, matrix))
    expected = np.eye(size, dtype=dtype)
    tolerance = dtype(2e-6)
    np.testing.assert_allclose(
        product,
        expected,
        rtol=0,
        atol=tolerance,
        err_msg=f"inv * A != I ({dtype.__name__})",
    )
1272+
1273+
def test_inv_random_3x3_correctness(cpp, dtype):
    """inv(A) @ A is the identity (within 2e-6 absolute) for a random 3×3 A."""
    size = 3
    matrix = random_array((size, size), dtype=dtype)
    inverse = cpp.linalg.inv(matrix)
    product = np.asarray(cpp.matmul(inverse, matrix))
    expected = np.eye(size, dtype=dtype)
    tolerance = dtype(2e-6)
    np.testing.assert_allclose(
        product,
        expected,
        rtol=0,
        atol=tolerance,
        err_msg=f"inv(3x3) * A != I ({dtype.__name__})",
    )
1284+
1285+
def test_inv_singular(cpp):
    """A singular input must raise RuntimeError (numpy raises LinAlgError)."""
    import pytest as _pytest

    # The second row is twice the first, so the matrix has no inverse.
    rank_deficient = np.array([[1.0, 2.0], [2.0, 4.0]], dtype=np.float64)
    with _pytest.raises(RuntimeError):
        cpp.linalg.inv(rank_deficient)
1291+
1292+
def test_inv_random_8x8_correctness(cpp):
    """inv(A) @ A is the identity (within 1e-12 absolute) for a random 8×8 float64 A."""
    size = 8
    matrix = random_array((size, size), dtype=np.float64, seed=42)
    inverse = cpp.linalg.inv(matrix)
    product = np.asarray(cpp.matmul(inverse, matrix))
    expected = np.eye(size, dtype=np.float64)
    np.testing.assert_allclose(
        product,
        expected,
        rtol=0,
        atol=1e-12,
        err_msg="inv(8x8) * A != I (f64)",
    )
1303+
12431304
def test_dot(cpp, dtype):
12441305
a = random_array((5,), dtype=dtype)
12451306
b = random_array((5,), seed=99, dtype=dtype)

0 commit comments

Comments
 (0)