Skip to content

Commit 0fee119

Browse files
author
peng.li24
committed
fix: rename _bool/_t/astype_* wrappers to numpy-consistent names
- full_like_bool(arr, bool_val) → full_like(arr, bool_val) (bool fill_value naturally disambiguates from double/float overloads) - zeros_like_bool(arr) → zeros_like(arr, dtype="bool") (string dtype mirrors numpy's dtype= kwarg for disambiguation) - ones_like_bool(arr) → ones_like(arr, dtype="bool") - astype_int/astype_bool/astype_bool_from_int → astype(arr, dtype) (unified runtime dtype dispatch: double↔int, double↔bool, int↔bool) - Remove zeros_t/ones_t/full_t _t suffix templates (prior commit) All Python-exposed names now match numpy exactly.
1 parent b70efd3 commit 0fee119

3 files changed

Lines changed: 65 additions & 54 deletions

File tree

pycpp/core_py.h

Lines changed: 51 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -74,63 +74,73 @@ inline py::array_t<double> full(const std::vector<py::ssize_t>& shape, double fi
7474
return result;
7575
}
7676

77-
// Bool specializations
78-
// NOTE: _bool suffix — dtype-specific wrappers; pybind11 cannot deduce template
79-
// argument from a Python dtype keyword, so each dtype needs its own binding.
80-
inline py::array_t<bool> full_like_bool(const py::array_t<double>& arr, bool fill_value) {
77+
// Bool specializations — return-type disambiguation via dtype parameter
78+
// where pybind11 cannot distinguish overloads by return type alone.
79+
//
80+
// full_like(arr, bool_val): bool fill_value naturally disambiguates from
81+
// the template full_like(arr, T val) where T=double/float.
82+
// zeros_like(arr, dtype_str) / ones_like(arr, dtype_str): string dtype
83+
// parameter mirrors numpy's dtype= kwarg for disambiguation.
84+
inline py::array_t<bool> full_like(const py::array_t<double>& arr, bool fill_value) {
8185
auto buf = arr.request();
8286
py::array_t<bool> result(buf.shape);
8387
std::fill_n(static_cast<bool*>(result.request().ptr), buf.size, fill_value);
8488
return result;
8589
}
8690

87-
inline py::array_t<bool> zeros_like_bool(const py::array_t<double>& arr) {
91+
inline py::array zeros_like(const py::array& arr, const std::string& dtype) {
8892
auto buf = arr.request();
89-
py::array_t<bool> result(buf.shape);
90-
std::fill_n(static_cast<bool*>(result.request().ptr), buf.size, false);
91-
return result;
93+
if (dtype == "bool") {
94+
py::array_t<bool> result(buf.shape);
95+
std::fill_n(static_cast<bool*>(result.request().ptr), buf.size, false);
96+
return result;
97+
}
98+
throw std::runtime_error("unsupported dtype: " + dtype);
9299
}
93100

94-
inline py::array_t<bool> ones_like_bool(const py::array_t<double>& arr) {
101+
inline py::array ones_like(const py::array& arr, const std::string& dtype) {
95102
auto buf = arr.request();
96-
py::array_t<bool> result(buf.shape);
97-
std::fill_n(static_cast<bool*>(result.request().ptr), buf.size, true);
98-
return result;
103+
if (dtype == "bool") {
104+
py::array_t<bool> result(buf.shape);
105+
std::fill_n(static_cast<bool*>(result.request().ptr), buf.size, true);
106+
return result;
107+
}
108+
throw std::runtime_error("unsupported dtype: " + dtype);
99109
}
100110

101111
// ============================================================================
102112
// astype — ndarray.astype(dtype, order='K', casting='unsafe', subok=True, copy=True)
103-
//
104-
// NOTE: wrappers use distinct names (astype_int, astype_bool, astype_bool_from_int)
105-
// instead of a single "astype" because pybind11 cannot resolve overloads that differ
106-
// only by return type. Each (Tout, Tin) combination needs its own Python binding.
107113
// ============================================================================
108114

109-
/// ndarray.astype(int) — float64 → int
110-
inline py::array_t<int> astype_int(const py::array_t<double>& arr) {
111-
auto buf = arr.request();
112-
py::array_t<int> result(buf.shape);
113-
astype<int, double>(static_cast<const double*>(buf.ptr),
114-
static_cast<int*>(result.request().ptr), buf.size);
115-
return result;
116-
}
117-
118-
/// ndarray.astype(bool) — float64 → bool
119-
inline py::array_t<bool> astype_bool(const py::array_t<double>& arr) {
120-
auto buf = arr.request();
121-
py::array_t<bool> result(buf.shape);
122-
astype<bool, double>(static_cast<const double*>(buf.ptr),
123-
static_cast<bool*>(result.request().ptr), buf.size);
124-
return result;
125-
}
126-
127-
/// ndarray.astype(bool) — int → bool
128-
inline py::array_t<bool> astype_bool_from_int(const py::array_t<int>& arr) {
129-
auto buf = arr.request();
130-
py::array_t<bool> result(buf.shape);
131-
astype<bool, int>(static_cast<const int*>(buf.ptr),
132-
static_cast<bool*>(result.request().ptr), buf.size);
133-
return result;
115+
/// ndarray.astype(dtype) — unified dtype dispatch
116+
inline py::array astype(const py::array& arr, const std::string& dtype) {
117+
auto buf = arr.request();
118+
auto dt = arr.dtype();
119+
// float64 input
120+
if (dt.is(py::dtype::of<double>())) {
121+
if (dtype == "int" || dtype == "int32" || dtype == "int64") {
122+
py::array_t<int> result(buf.shape);
123+
astype<int, double>(static_cast<const double*>(buf.ptr),
124+
static_cast<int*>(result.request().ptr), buf.size);
125+
return result;
126+
}
127+
if (dtype == "bool") {
128+
py::array_t<bool> result(buf.shape);
129+
astype<bool, double>(static_cast<const double*>(buf.ptr),
130+
static_cast<bool*>(result.request().ptr), buf.size);
131+
return result;
132+
}
133+
}
134+
// int input
135+
if (dt.is(py::dtype::of<int>())) {
136+
if (dtype == "bool") {
137+
py::array_t<bool> result(buf.shape);
138+
astype<bool, int>(static_cast<const int*>(buf.ptr),
139+
static_cast<bool*>(result.request().ptr), buf.size);
140+
return result;
141+
}
142+
}
143+
throw std::runtime_error("astype: unsupported conversion " + std::string(py::str(dt)) + " -> " + dtype);
134144
}
135145

136146
/// float64 → float32 → float64 roundtrip (precision testing helper)

tests/module.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,11 @@ PYBIND11_MODULE(numpycpp, m) {
6565
m.def("full_like", static_cast<py::array_t<float>(*)(const py::array_t<float>&, float)>(&numpy::full_like));
6666
// NOTE: _bool suffix — dtype-specific wrappers needed because pybind11
6767
// cannot deduce template argument from a Python dtype keyword.
68-
m.def("full_like_bool", &numpy::full_like_bool);
69-
m.def("zeros_like_bool", &numpy::zeros_like_bool);
70-
m.def("ones_like_bool", &numpy::ones_like_bool);
68+
m.def("full_like", static_cast<py::array_t<bool>(*)(const py::array_t<double>&, bool)>(&numpy::full_like));
69+
m.def("zeros_like", static_cast<py::array(*)(const py::array&, const std::string&)>(&numpy::zeros_like),
70+
py::arg("arr"), py::arg("dtype"));
71+
m.def("ones_like", static_cast<py::array(*)(const py::array&, const std::string&)>(&numpy::ones_like),
72+
py::arg("arr"), py::arg("dtype"));
7173
m.def("zeros", &numpy::zeros);
7274
m.def("ones", &numpy::ones);
7375
m.def("full", static_cast<py::array_t<double>(*)(const std::vector<py::ssize_t>&, double)>(&numpy::full));
@@ -77,9 +79,8 @@ PYBIND11_MODULE(numpycpp, m) {
7779
// single "astype" — pybind11 cannot resolve overloads that differ only
7880
// by return type (e.g. astype<double> vs astype<bool> both take
7981
// py::array_t<double>). Each dtype combo needs a distinct Python name.
80-
m.def("astype_int", static_cast<py::array_t<int>(*)(const py::array_t<double>&)>(&numpy::astype_int));
81-
m.def("astype_bool", static_cast<py::array_t<bool>(*)(const py::array_t<double>&)>(&numpy::astype_bool));
82-
m.def("astype_bool_from_int", static_cast<py::array_t<bool>(*)(const py::array_t<int>&)>(&numpy::astype_bool_from_int));
82+
m.def("astype", static_cast<py::array(*)(const py::array&, const std::string&)>(&numpy::astype),
83+
py::arg("arr"), py::arg("dtype"));
8384
m.def("truncate_to_float32", static_cast<py::array_t<double>(*)(const py::array_t<double>&)>(&numpy::truncate_to_float32));
8485

8586
// -- Element-wise math -------------------------------------------------

tests/test_all.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -443,16 +443,16 @@ def test_full(shape, fill_val, cpp):
443443
@pytest.mark.parametrize("value", [True, False])
444444
def test_full_like_bool(value, cpp):
445445
a = random_array((3, 4))
446-
assert_bit_aligned(cpp.full_like_bool(a, value),
447-
np.full_like(a, value, dtype=bool), f"full_like_bool({value})")
446+
assert_bit_aligned(cpp.full_like(a, value),
447+
np.full_like(a, value, dtype=bool), f"full_like({value})")
448448

449449
def test_zeros_like_bool(cpp):
450450
a = random_array((3, 4))
451-
assert_bit_aligned(cpp.zeros_like_bool(a), np.zeros_like(a, dtype=bool), "zeros_like_bool")
451+
assert_bit_aligned(cpp.zeros_like(a, "bool"), np.zeros_like(a, dtype=bool), "zeros_like")
452452

453453
def test_ones_like_bool(cpp):
454454
a = random_array((3, 4))
455-
assert_bit_aligned(cpp.ones_like_bool(a), np.ones_like(a, dtype=bool), "ones_like_bool")
455+
assert_bit_aligned(cpp.ones_like(a, "bool"), np.ones_like(a, dtype=bool), "ones_like")
456456

457457

458458
# ============================================================================
@@ -461,15 +461,15 @@ def test_ones_like_bool(cpp):
461461

462462
def test_astype_int(cpp):
463463
a = np.array([[1.7, 2.3], [-3.9, 0.5]], dtype=np.float64)
464-
assert_bit_aligned(cpp.astype_int(a), a.astype(np.int32), "astype_int")
464+
assert_bit_aligned(cpp.astype(a, "int"), a.astype(np.int32), "astype_int")
465465

466466
def test_astype_bool(cpp):
467467
a = np.array([[0.0, 1.0, -1.0], [3.14, 0.0, 0.0]], dtype=np.float64)
468-
assert_bit_aligned(cpp.astype_bool(a), a.astype(bool), "astype_bool")
468+
assert_bit_aligned(cpp.astype(a, "bool"), a.astype(bool), "astype_bool")
469469

470470
def test_astype_bool_from_int(cpp):
471471
a = np.array([[0, 1, -1], [42, 0, 0]], dtype=np.int32)
472-
assert_bit_aligned(cpp.astype_bool_from_int(a), a.astype(bool), "astype_bool_from_int")
472+
assert_bit_aligned(cpp.astype(a, "bool"), a.astype(bool), "astype_bool_from_int")
473473

474474
def test_truncate_to_float32(cpp):
475475
a = np.array([1.0 / 3.0, np.pi, np.sqrt(2.0)], dtype=np.float64)

0 commit comments

Comments
 (0)