Skip to content

Commit 1b48797

Browse files
peng.li24claude
andcommitted
feat: add precision alignment test framework for all C++ APIs
Design a comprehensive pytest-based test framework under tests/ that verifies every numpcpp C++ function against its Python numpy counterpart: - conftest.py: shared comparison engine, lazy C++ module import, CLI options (--cpp-module, --rtol, --atol), random data generators - test_core.py: 70+ test cases covering all core.h functions — array creation, math, reduction, comparison, logical, array manipulation, statistical, set ops, interpolation, sorting - test_linalg.py: norm (float64/float32), norm_axis1, dot - test_einsum.py: all 8 einsum patterns (explicit + implicit), edge cases, realistic gate_machine workloads - pyproject.toml: pytest configuration and test dependencies Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
1 parent fe86662 commit 1b48797

7 files changed

Lines changed: 1387 additions & 346 deletions

File tree

README.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,26 @@ make -j$(nproc)
3434
import numpcpp
3535
```
3636

37+
### Testing
38+
39+
The test suite verifies that every C++ function produces **pixel-level identical** results to Python numpy.
40+
41+
```bash
42+
# Install test dependencies
43+
pip install numpy pytest
44+
45+
# Run all precision alignment tests
46+
pytest tests/ -v
47+
48+
# Custom tolerances
49+
pytest tests/ --rtol=1e-10 --atol=1e-10
50+
51+
# Custom C++ module name
52+
pytest tests/ --cpp-module=my_numpycpp_module
53+
```
54+
55+
All tests pass only when C++ output matches Python numpy output within the specified tolerance (default: `rtol=1e-12, atol=1e-12`).
56+
3757
## Project Structure
3858

3959
```

pyproject.toml

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
[build-system]
2+
requires = ["setuptools>=64", "pybind11>=2.10", "wheel"]
3+
build-backend = "setuptools.build_meta"
4+
5+
[project]
6+
name = "numpycpp"
7+
version = "0.1.0"
8+
description = "C++ pixel-level alignment of Python numpy, powered by Eigen"
9+
requires-python = ">=3.8"
10+
license = {text = "MIT"}
11+
12+
[project.optional-dependencies]
13+
test = [
14+
"numpy>=1.20",
15+
"pytest>=7.0",
16+
]
17+
18+
[tool.pytest.ini_options]
19+
testpaths = ["tests"]
20+
python_files = ["test_*.py"]
21+
python_classes = ["Test*"]
22+
python_functions = ["test_*"]
23+
addopts = ["-v", "--tb=short", "--strict-markers"]
24+
markers = [
25+
"slow: slow tests (deselect with '-m \"not slow\"')",
26+
]

tests/__init__.py

Whitespace-only changes.

tests/conftest.py

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
"""
2+
Precision alignment test framework for numpcpp vs Python numpy.
3+
4+
Design:
5+
Each test calls a C++ function (via the compiled numpcpp module) and the
6+
equivalent Python numpy function with identical inputs, then compares results
7+
using configurable tolerances.
8+
9+
The C++ module name is configurable via --cpp-module or NUMPYCPP_MODULE env var.
10+
Default: "numpycpp".
11+
12+
Usage:
13+
pytest tests/ # run all tests
14+
pytest tests/ --cpp-module=my_module # custom C++ module name
15+
pytest tests/ -v --tb=short # verbose, short tracebacks
16+
pytest tests/ --rtol=1e-10 --atol=1e-10 # custom tolerances
17+
"""
18+
19+
import numpy as np
20+
import pytest
21+
22+
23+
def pytest_addoption(parser):
24+
parser.addoption(
25+
"--cpp-module",
26+
action="store",
27+
default=None,
28+
help="Python module name for the compiled C++ numpcpp library (default: numpcpp)",
29+
)
30+
parser.addoption(
31+
"--rtol",
32+
action="store",
33+
type=float,
34+
default=1e-12,
35+
help="Relative tolerance for numerical comparisons (default: 1e-12)",
36+
)
37+
parser.addoption(
38+
"--atol",
39+
action="store",
40+
type=float,
41+
default=1e-12,
42+
help="Absolute tolerance for numerical comparisons (default: 1e-12)",
43+
)
44+
45+
46+
# ---------------------------------------------------------------------------
47+
# Shared state: lazily-imported C++ module
48+
# ---------------------------------------------------------------------------
49+
_cpp_module = None
50+
_import_error = None
51+
52+
53+
def _resolve_module_name(config) -> str:
54+
"""Resolve C++ module name from CLI, env, or default."""
55+
import os
56+
57+
cli = config.getoption("--cpp-module", default=None)
58+
if cli:
59+
return cli
60+
env = os.environ.get("NUMPYCPP_MODULE")
61+
if env:
62+
return env
63+
return "numpycpp"
64+
65+
66+
def get_cpp_module(request=None) -> "module":
67+
"""
68+
Return the compiled numpcpp C++ module. Import is attempted once and
69+
cached; if it fails, all dependent tests are skipped with a clear message.
70+
"""
71+
global _cpp_module, _import_error
72+
73+
if _cpp_module is not None:
74+
return _cpp_module
75+
if _import_error is not None:
76+
raise _import_error
77+
78+
if request is not None:
79+
modname = _resolve_module_name(request.config)
80+
else:
81+
import os
82+
83+
modname = os.environ.get("NUMPYCPP_MODULE", "numpycpp")
84+
85+
import importlib
86+
87+
try:
88+
_cpp_module = importlib.import_module(modname)
89+
except ImportError as e:
90+
_import_error = e
91+
raise
92+
return _cpp_module
93+
94+
95+
# ---------------------------------------------------------------------------
96+
# Comparison engine
97+
# ---------------------------------------------------------------------------
98+
99+
100+
def compare(
101+
cpp_result,
102+
py_result,
103+
rtol: float = 1e-12,
104+
atol: float = 1e-12,
105+
label: str = "",
106+
):
107+
"""
108+
Compare C++ result against Python (ground-truth) result.
109+
110+
Returns a dict with keys: pass, max_abs_diff, max_rel_diff, shape_match,
111+
cpp_dtype, py_dtype, label.
112+
113+
Does NOT raise on mismatch — returns structured diagnostics.
114+
"""
115+
cpp = np.asarray(cpp_result, dtype=np.float64)
116+
py = np.asarray(py_result, dtype=np.float64)
117+
118+
info = {
119+
"label": label,
120+
"shape_match": cpp.shape == py.shape,
121+
"cpp_shape": cpp.shape,
122+
"py_shape": py.shape,
123+
"cpp_dtype": str(cpp.dtype),
124+
"py_dtype": str(py.dtype),
125+
}
126+
127+
if not info["shape_match"]:
128+
info["pass"] = False
129+
info["max_abs_diff"] = float("nan")
130+
info["max_rel_diff"] = float("nan")
131+
info["error"] = f"shape mismatch: C++ {cpp.shape} vs Python {py.shape}"
132+
return info
133+
134+
abs_diff = np.abs(cpp - py)
135+
max_abs = float(np.max(abs_diff))
136+
137+
# Relative diff: avoid division by zero
138+
py_abs = np.abs(py)
139+
with np.errstate(divide="ignore", invalid="ignore"):
140+
rel_diff = np.where(py_abs > 0, abs_diff / py_abs, abs_diff)
141+
max_rel = float(np.max(rel_diff))
142+
143+
passed = bool(np.allclose(cpp, py, rtol=rtol, atol=atol))
144+
145+
info["pass"] = passed
146+
info["max_abs_diff"] = max_abs
147+
info["max_rel_diff"] = max_rel
148+
149+
if not passed:
150+
# Find worst offenders for diagnostics
151+
worst_idx = int(np.argmax(abs_diff))
152+
info["error"] = (
153+
f"numerical mismatch: max_abs_diff={max_abs:.2e}, "
154+
f"max_rel_diff={max_rel:.2e} at idx {worst_idx}\n"
155+
f" C++ value: {cpp.flat[worst_idx]:.16e}\n"
156+
f" Py value: {py.flat[worst_idx]:.16e}"
157+
)
158+
159+
return info
160+
161+
162+
def assert_match(
163+
cpp_result,
164+
py_result,
165+
rtol: float = 1e-12,
166+
atol: float = 1e-12,
167+
label: str = "",
168+
):
169+
"""
170+
Like compare(), but raises AssertionError on mismatch (for use in plain
171+
unittest-style tests).
172+
"""
173+
info = compare(cpp_result, py_result, rtol=rtol, atol=atol, label=label)
174+
if not info["pass"]:
175+
raise AssertionError(info.get("error", "mismatch"))
176+
return info
177+
178+
179+
# ---------------------------------------------------------------------------
180+
# Shared test-data generators
181+
# ---------------------------------------------------------------------------
182+
183+
184+
def random_array(shape, dtype=np.float64, seed: int = 42):
185+
"""Deterministic random array with controlled seed per shape."""
186+
rng = np.random.RandomState(seed + hash(shape) % (2**31))
187+
if np.issubdtype(dtype, np.floating):
188+
return rng.randn(*shape).astype(dtype)
189+
elif dtype == bool:
190+
return rng.rand(*shape) > 0.5
191+
else:
192+
return rng.randint(0, 100, size=shape).astype(dtype)
193+
194+
195+
# ---------------------------------------------------------------------------
196+
# Fixtures
197+
# ---------------------------------------------------------------------------
198+
199+
200+
@pytest.fixture(scope="session")
201+
def cpp():
202+
"""Session-scoped C++ module fixture."""
203+
return get_cpp_module()
204+
205+
206+
@pytest.fixture
207+
def rtol(request):
208+
return request.config.getoption("--rtol", default=1e-12)
209+
210+
211+
@pytest.fixture
212+
def atol(request):
213+
return request.config.getoption("--atol", default=1e-12)

0 commit comments

Comments
 (0)