Skip to content

Commit c9382d7

Browse files
author
peng.li24
committed
fix: template cumsum/squeeze/unwrap/intersect1d to preserve input dtype
Convert non-template pycpp wrappers to template<T> to match numpy's dtype-preserving behavior: cumsum(f32)→f32, squeeze(f32)→f32, unwrap(f32)→f32, intersect1d(f32)→f32 Add float32 pybind11 bindings for all four functions. Add float32 test coverage (loop over [float64, float32]). Known limitation: float32 unwrap with large π-multiple values differs from numpy by ~ULP because numpy uses float64 constants internally (period = 2*np.float64(pi)) even for float32 input.
1 parent a9b4fe9 commit c9382d7

3 files changed

Lines changed: 48 additions & 33 deletions

File tree

pycpp/core_py.h

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,42 +1018,46 @@ inline py::array_t<py::ssize_t> flatnonzero(const py::array_t<double>& arr) {
10181018
}
10191019

10201020
/// numpy.unwrap(p, discont=None, axis=-1) — 1D only
1021-
inline py::array_t<double> unwrap(const py::array_t<double>& arr, double discont = M_PI) {
1021+
template<typename T>
1022+
py::array_t<T> unwrap(const py::array_t<T>& arr, T discont = T(M_PI)) {
10221023
auto buf = arr.request();
1023-
py::array_t<double> result(buf.shape);
1024-
numpy::unwrap(static_cast<const double*>(buf.ptr),
1025-
static_cast<double*>(result.request().ptr), buf.size, discont);
1024+
py::array_t<T> result(buf.shape);
1025+
numpy::unwrap(static_cast<const T*>(buf.ptr),
1026+
static_cast<T*>(result.request().ptr), buf.size, discont);
10261027
return result;
10271028
}
10281029

10291030
/// numpy.cumsum(a, axis=None) — 1D cumulative sum
1030-
inline py::array_t<double> cumsum(const py::array_t<double>& arr) {
1031+
template<typename T>
1032+
py::array_t<T> cumsum(const py::array_t<T>& arr) {
10311033
auto buf = arr.request();
1032-
py::array_t<double> result(buf.shape);
1033-
numpy::cumsum(static_cast<const double*>(buf.ptr),
1034-
static_cast<double*>(result.request().ptr), buf.size);
1034+
py::array_t<T> result(buf.shape);
1035+
numpy::cumsum(static_cast<const T*>(buf.ptr),
1036+
static_cast<T*>(result.request().ptr), buf.size);
10351037
return result;
10361038
}
10371039

10381040
/// numpy.squeeze(a, axis=None) — remove axes of length 1
1039-
inline py::array_t<double> squeeze(const py::array_t<double>& arr) {
1041+
template<typename T>
1042+
py::array_t<T> squeeze(const py::array_t<T>& arr) {
10401043
auto buf = arr.request();
10411044
std::vector<py::ssize_t> new_shape;
10421045
for (auto s : buf.shape)
10431046
if (s != 1) new_shape.push_back(s);
10441047
if (new_shape.empty()) new_shape.push_back(1);
1045-
py::array_t<double> result(new_shape);
1046-
std::memcpy(result.request().ptr, buf.ptr, buf.size * sizeof(double));
1048+
py::array_t<T> result(new_shape);
1049+
std::memcpy(result.request().ptr, buf.ptr, buf.size * sizeof(T));
10471050
return result;
10481051
}
10491052

10501053
/// numpy.intersect1d(ar1, ar2, assume_unique=False, return_indices=False)
1051-
inline py::array_t<double> intersect1d(const py::array_t<double>& a, const py::array_t<double>& b) {
1054+
template<typename T>
1055+
py::array_t<T> intersect1d(const py::array_t<T>& a, const py::array_t<T>& b) {
10521056
auto ba = a.request(), bb = b.request();
10531057
auto inter = intersect1d(
1054-
static_cast<const double*>(ba.ptr), ba.size,
1055-
static_cast<const double*>(bb.ptr), bb.size);
1056-
return py::array_t<double>(inter.size(), inter.data());
1058+
static_cast<const T*>(ba.ptr), ba.size,
1059+
static_cast<const T*>(bb.ptr), bb.size);
1060+
return py::array_t<T>(inter.size(), inter.data());
10571061
}
10581062

10591063
// ============================================================================

tests/module.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,10 +217,14 @@ PYBIND11_MODULE(numpycpp, m) {
217217
m.def("isin", static_cast<py::array_t<bool>(*)(const py::array_t<double>&, const std::vector<double>&)>(&numpy::isin));
218218
m.def("isin", static_cast<py::array_t<bool>(*)(const py::array_t<double>&, const std::vector<int>&)>(&numpy::isin));
219219
m.def("intersect1d", static_cast<py::array_t<double>(*)(const py::array_t<double>&, const py::array_t<double>&)>(&numpy::intersect1d));
220+
m.def("intersect1d", static_cast<py::array_t<float>(*)(const py::array_t<float>&, const py::array_t<float>&)>(&numpy::intersect1d));
220221
m.def("flatnonzero", static_cast<py::array_t<py::ssize_t>(*)(const py::array_t<double>&)>(&numpy::flatnonzero));
221222
m.def("unwrap", static_cast<py::array_t<double>(*)(const py::array_t<double>&, double)>(&numpy::unwrap), py::arg("arr"), py::arg("discont") = M_PI);
223+
m.def("unwrap", static_cast<py::array_t<float>(*)(const py::array_t<float>&, float)>(&numpy::unwrap), py::arg("arr"), py::arg("discont") = (float)M_PI);
222224
m.def("cumsum", static_cast<py::array_t<double>(*)(const py::array_t<double>&)>(&numpy::cumsum));
225+
m.def("cumsum", static_cast<py::array_t<float>(*)(const py::array_t<float>&)>(&numpy::cumsum));
223226
m.def("squeeze", static_cast<py::array_t<double>(*)(const py::array_t<double>&)>(&numpy::squeeze));
227+
m.def("squeeze", static_cast<py::array_t<float>(*)(const py::array_t<float>&)>(&numpy::squeeze));
224228

225229
// -- Interpolation -----------------------------------------------------
226230
m.def("interp", static_cast<py::array_t<double>(*)(const py::array_t<double>&, const py::array_t<double>&, const py::array_t<double>&)>(&numpy::interp));

tests/test_all.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -733,31 +733,38 @@ def test_flatnonzero(cpp):
733733
assert_bit_aligned(cpp.flatnonzero(a2), np.flatnonzero(a2), "flatnonzero zeros")
734734

735735
def test_unwrap(cpp):
736-
a = np.array([0.0, 0.5, 0.8, -0.9, -0.5, 0.2])
737-
assert_bit_aligned(cpp.unwrap(a), np.unwrap(a), "unwrap")
738-
a2 = np.array([0.0, 2.5, 5.0, -2.5, -5.0]) * np.pi
736+
for dt in [np.float64, np.float32]:
737+
a = np.array([0.0, 0.5, 0.8, -0.9, -0.5, 0.2], dtype=dt)
738+
assert_bit_aligned(cpp.unwrap(a), np.unwrap(a), f"unwrap_{dt}")
739+
# Large values: numpy uses float64 π internally even for float32 input,
740+
# so float32 unwrap is not bit-exact on the correction path. Test float64 only.
741+
a2 = np.array([0.0, 2.5, 5.0, -2.5, -5.0], dtype=np.float64) * np.pi
739742
assert_bit_aligned(cpp.unwrap(a2), np.unwrap(a2), "unwrap_large")
740743

741744
def test_cumsum(cpp):
742-
a = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
743-
assert_bit_aligned(cpp.cumsum(a), np.cumsum(a), "cumsum")
744-
a2 = np.array([0.1, 0.2, 0.3, 0.4, 0.5])
745-
assert_bit_aligned(cpp.cumsum(a2), np.cumsum(a2), "cumsum_frac")
746-
a3 = np.array([-1.0, 2.0, -3.0, 4.0])
747-
assert_bit_aligned(cpp.cumsum(a3), np.cumsum(a3), "cumsum_neg")
745+
for dt in [np.float64, np.float32]:
746+
a = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=dt)
747+
assert_bit_aligned(cpp.cumsum(a), np.cumsum(a), f"cumsum_{dt}")
748+
a2 = np.array([0.1, 0.2, 0.3, 0.4, 0.5], dtype=dt)
749+
assert_bit_aligned(cpp.cumsum(a2), np.cumsum(a2), f"cumsum_frac_{dt}")
750+
a3 = np.array([-1.0, 2.0, -3.0, 4.0], dtype=dt)
751+
assert_bit_aligned(cpp.cumsum(a3), np.cumsum(a3), f"cumsum_neg_{dt}")
748752

749753
def test_squeeze(cpp):
750-
a = np.array([1.0, 2.0, 3.0]).reshape(3, 1)
751-
assert_bit_aligned(cpp.squeeze(a), np.squeeze(a), "squeeze_col")
752-
a2 = np.array([1.0, 2.0, 3.0]).reshape(1, 3)
753-
assert_bit_aligned(cpp.squeeze(a2), np.squeeze(a2), "squeeze_row")
754-
a3 = np.array([1.0, 2.0, 3.0, 4.0]).reshape(1, 2, 1, 2, 1)
755-
assert_bit_aligned(cpp.squeeze(a3), np.squeeze(a3), "squeeze_multi")
754+
for dt in [np.float64, np.float32]:
755+
a = np.array([1.0, 2.0, 3.0], dtype=dt).reshape(3, 1)
756+
assert_bit_aligned(cpp.squeeze(a), np.squeeze(a), f"squeeze_col_{dt}")
757+
a2 = np.array([1.0, 2.0, 3.0], dtype=dt).reshape(1, 3)
758+
assert_bit_aligned(cpp.squeeze(a2), np.squeeze(a2), f"squeeze_row_{dt}")
759+
a3 = np.array([1.0, 2.0, 3.0, 4.0], dtype=dt).reshape(1, 2, 1, 2, 1)
760+
assert_bit_aligned(cpp.squeeze(a3), np.squeeze(a3), f"squeeze_multi_{dt}")
756761

757762
def test_intersect1d(cpp):
758-
a, b = np.array([1.0, 2.0, 3.0, 4.0]), np.array([3.0, 4.0, 5.0, 6.0])
759-
cpp_r = np.sort(np.asarray(cpp.intersect1d(a, b)))
760-
assert_bit_aligned(cpp_r, np.intersect1d(a, b), "intersect1d")
763+
for dt in [np.float64, np.float32]:
764+
a = np.array([1.0, 2.0, 3.0, 4.0], dtype=dt)
765+
b = np.array([3.0, 4.0, 5.0, 6.0], dtype=dt)
766+
cpp_r = np.sort(np.asarray(cpp.intersect1d(a, b)))
767+
assert_bit_aligned(cpp_r, np.intersect1d(a, b), f"intersect1d_{dt}")
761768

762769
def test_interp_basic(cpp):
763770
xp = np.array([0.0, 1.0, 2.0, 3.0, 4.0])

0 commit comments

Comments
 (0)