Skip to content

Commit 95986fd

Browse files
author
peng.li24
committed
feat: expand astype dtype support — float32, int64, bool bidirectional
Supported conversions (5 dtypes × 4 targets = 20 combos): float64 → float32, int32, int64, bool; float32 → float64, int32, int64, bool; int32 → float64, float32, int64, bool; int64 → float64, float32, int32, bool; bool → float64, float32, int32, int64. Add 7 new tests: f64→f32, f32→f64, f64→int64, int→f64, int→f32, bool→f64, bool→int. Test count: 468 → 475.
1 parent 0fee119 commit 95986fd

4 files changed

Lines changed: 151 additions & 20 deletions

File tree

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ on:
88
branches: [master]
99

1010
jobs:
11-
# ---- Test: build module + run 468 precision tests --------------------------
11+
# ---- Test: build module + run 475 precision tests --------------------------
1212
test:
1313
runs-on: ubuntu-22.04
1414
steps:

README.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ We created `numpycpp` to keep NumPy's familiar usage patterns while letting C++
1515

1616
`numpycpp` is a **header-only C++ library** implementing numpy's core API (`numpy.*`, `numpy.linalg.*`, `numpy.einsum`) with **bit-level precision alignment**. Raw pointer + size interface. Zero external dependencies — pure C++17 standard library.
1717

18-
All APIs are tested against Python numpy under strict bit-level comparison: every IEEE 754 float bit must match exactly (468 tests, float64 + float32).
18+
All APIs are tested against Python numpy under strict bit-level comparison: every IEEE 754 float bit must match exactly (475 tests, float64 + float32).
1919

2020
**Bit-exact math** is achieved by resolving numpy's own math functions from `_multiarray_umath.so` at runtime. The SVML bridge auto-detects your CPU and selects the same path numpy uses: AVX‑512 SVML (`__svml_exp8`) when available, or scalar `npy_exp`/`npy_log`/etc. otherwise. AVX‑512 intrinsics are isolated behind `__attribute__((target))` — the binary is safe on any x86_64 CPU (no SIGILL). Every transcendental function produces the exact same IEEE 754 bits as numpy on **all architectures**.
2121

@@ -89,12 +89,12 @@ Add `-Ipath/to/numpycpp` to your compiler flags and include the headers directly
8989
### Testing
9090

9191
The test suite verifies **bit-level precision alignment** between every C++ function and Python numpy.
92-
No tolerance, no `atol`/`rtol` — raw IEEE 754 bits must match exactly. 468 tests, float64 + float32.
92+
No tolerance, no `atol`/`rtol` — raw IEEE 754 bits must match exactly. 475 tests, float64 + float32.
9393

9494
```bash
9595
cd tests
9696
make # compile C++ test module
97-
make test # run all 468 tests (silent mode: only failures print)
97+
make test # run all 475 tests (silent mode: only failures print)
9898
```
9999

100100
To run with verbose output:
@@ -142,7 +142,7 @@ LDFLAGS = -shared -ldl
142142
### Alignment status
143143

144144
The table below reflects the current bit-level parity between `numpycpp` C++ and Python numpy.
145-
All 468 tests pass under strict IEEE 754 bit comparison (float64 + float32).
145+
All 475 tests pass under strict IEEE 754 bit comparison (float64 + float32).
146146

147147
✅ = bit-exact on ALL architectures (SVML bridge with runtime CPU dispatch).
148148

@@ -189,7 +189,7 @@ numpycpp/
189189
│ └── einsum_py.h
190190
├── tests/ # bit-level precision tests + test module
191191
│ ├── module.cpp # pybind11 module for testing
192-
│ ├── test_all.py # single entry — all APIs, 468 tests, float64+float32
192+
│ ├── test_all.py # single entry — all APIs, 475 tests, float64+float32
193193
│ ├── conftest.py # silent-mode output suppression
194194
│ └── Makefile
195195
├── CMakeLists.txt # build & .deb packaging

pycpp/core_py.h

Lines changed: 117 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "../numpy/core.h"
1212
#include <vector>
1313
#include <cstring>
14+
#include <cstdint>
1415

1516
namespace py = pybind11;
1617

@@ -116,30 +117,132 @@ inline py::array ones_like(const py::array& arr, const std::string& dtype) {
116117
inline py::array astype(const py::array& arr, const std::string& dtype) {
117118
auto buf = arr.request();
118119
auto dt = arr.dtype();
120+
119121
// float64 input
120122
if (dt.is(py::dtype::of<double>())) {
121-
if (dtype == "int" || dtype == "int32" || dtype == "int64") {
122-
py::array_t<int> result(buf.shape);
123-
astype<int, double>(static_cast<const double*>(buf.ptr),
124-
static_cast<int*>(result.request().ptr), buf.size);
125-
return result;
123+
auto* src = static_cast<const double*>(buf.ptr);
124+
if (dtype == "float32" || dtype == "float") {
125+
py::array_t<float> r(buf.shape);
126+
astype<float, double>(src, static_cast<float*>(r.request().ptr), buf.size);
127+
return r;
128+
}
129+
if (dtype == "int" || dtype == "int32") {
130+
py::array_t<int> r(buf.shape);
131+
astype<int, double>(src, static_cast<int*>(r.request().ptr), buf.size);
132+
return r;
133+
}
134+
if (dtype == "int64") {
135+
py::array_t<int64_t> r(buf.shape);
136+
astype<int64_t, double>(src, static_cast<int64_t*>(r.request().ptr), buf.size);
137+
return r;
138+
}
139+
if (dtype == "bool") {
140+
py::array_t<bool> r(buf.shape);
141+
astype<bool, double>(src, static_cast<bool*>(r.request().ptr), buf.size);
142+
return r;
143+
}
144+
}
145+
146+
// float32 input
147+
if (dt.is(py::dtype::of<float>())) {
148+
auto* src = static_cast<const float*>(buf.ptr);
149+
if (dtype == "float64" || dtype == "double") {
150+
py::array_t<double> r(buf.shape);
151+
astype<double, float>(src, static_cast<double*>(r.request().ptr), buf.size);
152+
return r;
153+
}
154+
if (dtype == "int" || dtype == "int32") {
155+
py::array_t<int> r(buf.shape);
156+
astype<int, float>(src, static_cast<int*>(r.request().ptr), buf.size);
157+
return r;
158+
}
159+
if (dtype == "int64") {
160+
py::array_t<int64_t> r(buf.shape);
161+
astype<int64_t, float>(src, static_cast<int64_t*>(r.request().ptr), buf.size);
162+
return r;
126163
}
127164
if (dtype == "bool") {
128-
py::array_t<bool> result(buf.shape);
129-
astype<bool, double>(static_cast<const double*>(buf.ptr),
130-
static_cast<bool*>(result.request().ptr), buf.size);
131-
return result;
165+
py::array_t<bool> r(buf.shape);
166+
astype<bool, float>(src, static_cast<bool*>(r.request().ptr), buf.size);
167+
return r;
132168
}
133169
}
134-
// int input
170+
171+
// int32 input
135172
if (dt.is(py::dtype::of<int>())) {
173+
auto* src = static_cast<const int*>(buf.ptr);
174+
if (dtype == "float64" || dtype == "double") {
175+
py::array_t<double> r(buf.shape);
176+
astype<double, int>(src, static_cast<double*>(r.request().ptr), buf.size);
177+
return r;
178+
}
179+
if (dtype == "float32" || dtype == "float") {
180+
py::array_t<float> r(buf.shape);
181+
astype<float, int>(src, static_cast<float*>(r.request().ptr), buf.size);
182+
return r;
183+
}
184+
if (dtype == "int64") {
185+
py::array_t<int64_t> r(buf.shape);
186+
astype<int64_t, int>(src, static_cast<int64_t*>(r.request().ptr), buf.size);
187+
return r;
188+
}
136189
if (dtype == "bool") {
137-
py::array_t<bool> result(buf.shape);
138-
astype<bool, int>(static_cast<const int*>(buf.ptr),
139-
static_cast<bool*>(result.request().ptr), buf.size);
140-
return result;
190+
py::array_t<bool> r(buf.shape);
191+
astype<bool, int>(src, static_cast<bool*>(r.request().ptr), buf.size);
192+
return r;
141193
}
142194
}
195+
196+
// int64 input
197+
if (dt.is(py::dtype::of<int64_t>())) {
198+
auto* src = static_cast<const int64_t*>(buf.ptr);
199+
if (dtype == "float64" || dtype == "double") {
200+
py::array_t<double> r(buf.shape);
201+
astype<double, int64_t>(src, static_cast<double*>(r.request().ptr), buf.size);
202+
return r;
203+
}
204+
if (dtype == "float32" || dtype == "float") {
205+
py::array_t<float> r(buf.shape);
206+
astype<float, int64_t>(src, static_cast<float*>(r.request().ptr), buf.size);
207+
return r;
208+
}
209+
if (dtype == "int" || dtype == "int32") {
210+
py::array_t<int> r(buf.shape);
211+
astype<int, int64_t>(src, static_cast<int*>(r.request().ptr), buf.size);
212+
return r;
213+
}
214+
if (dtype == "bool") {
215+
py::array_t<bool> r(buf.shape);
216+
astype<bool, int64_t>(src, static_cast<bool*>(r.request().ptr), buf.size);
217+
return r;
218+
}
219+
}
220+
221+
// bool input
222+
if (dt.is(py::dtype::of<bool>())) {
223+
auto* src = static_cast<const bool*>(buf.ptr);
224+
if (dtype == "float64" || dtype == "double") {
225+
py::array_t<double> r(buf.shape);
226+
astype<double, bool>(src, static_cast<double*>(r.request().ptr), buf.size);
227+
return r;
228+
}
229+
if (dtype == "float32" || dtype == "float") {
230+
py::array_t<float> r(buf.shape);
231+
astype<float, bool>(src, static_cast<float*>(r.request().ptr), buf.size);
232+
return r;
233+
}
234+
if (dtype == "int" || dtype == "int32") {
235+
py::array_t<int> r(buf.shape);
236+
astype<int, bool>(src, static_cast<int*>(r.request().ptr), buf.size);
237+
return r;
238+
}
239+
if (dtype == "int64") {
240+
py::array_t<int64_t> r(buf.shape);
241+
astype<int64_t, bool>(src, static_cast<int64_t*>(r.request().ptr), buf.size);
242+
return r;
243+
}
244+
}
245+
143246
throw std::runtime_error("astype: unsupported conversion " + std::string(py::str(dt)) + " -> " + dtype);
144247
}
145248

tests/test_all.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,34 @@ def test_astype_bool_from_int(cpp):
471471
a = np.array([[0, 1, -1], [42, 0, 0]], dtype=np.int32)
472472
assert_bit_aligned(cpp.astype(a, "bool"), a.astype(bool), "astype_bool_from_int")
473473

474+
def test_astype_f64_to_f32(cpp):
475+
a = np.array([1.5, 2.7, -3.1], dtype=np.float64)
476+
assert_bit_aligned(cpp.astype(a, "float32"), a.astype(np.float32), "astype_f64_to_f32")
477+
478+
def test_astype_f32_to_f64(cpp):
479+
a = np.array([1.5, 2.7, -3.1], dtype=np.float32)
480+
assert_bit_aligned(cpp.astype(a, "float64"), a.astype(np.float64), "astype_f32_to_f64")
481+
482+
def test_astype_f64_to_int64(cpp):
483+
a = np.array([1.5, 2.7, -3.1], dtype=np.float64)
484+
assert_bit_aligned(cpp.astype(a, "int64"), a.astype(np.int64), "astype_f64_to_int64")
485+
486+
def test_astype_int_to_f64(cpp):
487+
a = np.array([1, 2, -3], dtype=np.int32)
488+
assert_bit_aligned(cpp.astype(a, "float64"), a.astype(np.float64), "astype_int_to_f64")
489+
490+
def test_astype_int_to_f32(cpp):
491+
a = np.array([1, 2, -3], dtype=np.int32)
492+
assert_bit_aligned(cpp.astype(a, "float32"), a.astype(np.float32), "astype_int_to_f32")
493+
494+
def test_astype_bool_to_f64(cpp):
495+
a = np.array([True, False, True], dtype=bool)
496+
assert_bit_aligned(cpp.astype(a, "float64"), a.astype(np.float64), "astype_bool_to_f64")
497+
498+
def test_astype_bool_to_int(cpp):
499+
a = np.array([True, False, True, False], dtype=bool)
500+
assert_bit_aligned(cpp.astype(a, "int"), a.astype(np.int32), "astype_bool_to_int")
501+
474502
def test_truncate_to_float32(cpp):
475503
a = np.array([1.0 / 3.0, np.pi, np.sqrt(2.0)], dtype=np.float64)
476504
py_r = a.astype(np.float32).astype(np.float64)

0 commit comments

Comments
 (0)