@@ -74,63 +74,73 @@ inline py::array_t<double> full(const std::vector<py::ssize_t>& shape, double fi
7474 return result;
7575}
7676
77- // Bool specializations
78- // NOTE: _bool suffix — dtype-specific wrappers; pybind11 cannot deduce template
79- // argument from a Python dtype keyword, so each dtype needs its own binding.
80- inline py::array_t <bool > full_like_bool (const py::array_t <double >& arr, bool fill_value) {
77+ // Bool specializations — return-type disambiguation via dtype parameter
78+ // where pybind11 cannot distinguish overloads by return type alone.
79+ //
80+ // full_like(arr, bool_val): bool fill_value naturally disambiguates from
81+ // the template full_like(arr, T val) where T=double/float.
82+ // zeros_like(arr, dtype_str) / ones_like(arr, dtype_str): string dtype
83+ // parameter mirrors numpy's dtype= kwarg for disambiguation.
84+ inline py::array_t <bool > full_like (const py::array_t <double >& arr, bool fill_value) {
8185 auto buf = arr.request ();
8286 py::array_t <bool > result (buf.shape );
8387 std::fill_n (static_cast <bool *>(result.request ().ptr ), buf.size , fill_value);
8488 return result;
8589}
8690
87- inline py::array_t < bool > zeros_like_bool (const py::array_t < double > & arr) {
91+ inline py::array zeros_like (const py::array & arr, const std::string& dtype ) {
8892 auto buf = arr.request ();
89- py::array_t <bool > result (buf.shape );
90- std::fill_n (static_cast <bool *>(result.request ().ptr ), buf.size , false );
91- return result;
93+ if (dtype == " bool" ) {
94+ py::array_t <bool > result (buf.shape );
95+ std::fill_n (static_cast <bool *>(result.request ().ptr ), buf.size , false );
96+ return result;
97+ }
98+ throw std::runtime_error (" unsupported dtype: " + dtype);
9299}
93100
94- inline py::array_t < bool > ones_like_bool (const py::array_t < double > & arr) {
101+ inline py::array ones_like (const py::array & arr, const std::string& dtype ) {
95102 auto buf = arr.request ();
96- py::array_t <bool > result (buf.shape );
97- std::fill_n (static_cast <bool *>(result.request ().ptr ), buf.size , true );
98- return result;
103+ if (dtype == " bool" ) {
104+ py::array_t <bool > result (buf.shape );
105+ std::fill_n (static_cast <bool *>(result.request ().ptr ), buf.size , true );
106+ return result;
107+ }
108+ throw std::runtime_error (" unsupported dtype: " + dtype);
99109}
100110
101111// ============================================================================
102112// astype — ndarray.astype(dtype, order='K', casting='unsafe', subok=True, copy=True)
103- //
104- // NOTE: wrappers use distinct names (astype_int, astype_bool, astype_bool_from_int)
105- // instead of a single "astype" because pybind11 cannot resolve overloads that differ
106- // only by return type. Each (Tout, Tin) combination needs its own Python binding.
107113// ============================================================================
108114
109- // / ndarray.astype(int) — float64 → int
110- inline py::array_t <int > astype_int (const py::array_t <double >& arr) {
111- auto buf = arr.request ();
112- py::array_t <int > result (buf.shape );
113- astype<int , double >(static_cast <const double *>(buf.ptr ),
114- static_cast <int *>(result.request ().ptr ), buf.size );
115- return result;
116- }
117-
118- // / ndarray.astype(bool) — float64 → bool
119- inline py::array_t <bool > astype_bool (const py::array_t <double >& arr) {
120- auto buf = arr.request ();
121- py::array_t <bool > result (buf.shape );
122- astype<bool , double >(static_cast <const double *>(buf.ptr ),
123- static_cast <bool *>(result.request ().ptr ), buf.size );
124- return result;
125- }
126-
127- // / ndarray.astype(bool) — int → bool
128- inline py::array_t <bool > astype_bool_from_int (const py::array_t <int >& arr) {
129- auto buf = arr.request ();
130- py::array_t <bool > result (buf.shape );
131- astype<bool , int >(static_cast <const int *>(buf.ptr ),
132- static_cast <bool *>(result.request ().ptr ), buf.size );
133- return result;
115+ // / ndarray.astype(dtype) — unified dtype dispatch
116+ inline py::array astype (const py::array& arr, const std::string& dtype) {
117+ auto buf = arr.request ();
118+ auto dt = arr.dtype ();
119+ // float64 input
120+ if (dt.is (py::dtype::of<double >())) {
121+ if (dtype == " int" || dtype == " int32" || dtype == " int64" ) {
122+ py::array_t <int > result (buf.shape );
123+ astype<int , double >(static_cast <const double *>(buf.ptr ),
124+ static_cast <int *>(result.request ().ptr ), buf.size );
125+ return result;
126+ }
127+ if (dtype == " bool" ) {
128+ py::array_t <bool > result (buf.shape );
129+ astype<bool , double >(static_cast <const double *>(buf.ptr ),
130+ static_cast <bool *>(result.request ().ptr ), buf.size );
131+ return result;
132+ }
133+ }
134+ // int input
135+ if (dt.is (py::dtype::of<int >())) {
136+ if (dtype == " bool" ) {
137+ py::array_t <bool > result (buf.shape );
138+ astype<bool , int >(static_cast <const int *>(buf.ptr ),
139+ static_cast <bool *>(result.request ().ptr ), buf.size );
140+ return result;
141+ }
142+ }
143+ throw std::runtime_error (" astype: unsupported conversion " + std::string (py::str (dt)) + " -> " + dtype);
134144}
135145
136146// / float64 → float32 → float64 roundtrip (precision testing helper)
0 commit comments