|
11 | 11 | #include "../numpy/core.h" |
12 | 12 | #include <vector> |
13 | 13 | #include <cstring> |
| 14 | +#include <cstdint> |
14 | 15 |
|
15 | 16 | namespace py = pybind11; |
16 | 17 |
|
@@ -116,30 +117,132 @@ inline py::array ones_like(const py::array& arr, const std::string& dtype) { |
116 | 117 | inline py::array astype(const py::array& arr, const std::string& dtype) { |
117 | 118 | auto buf = arr.request(); |
118 | 119 | auto dt = arr.dtype(); |
| 120 | + |
119 | 121 | // float64 input |
120 | 122 | if (dt.is(py::dtype::of<double>())) { |
121 | | - if (dtype == "int" || dtype == "int32" || dtype == "int64") { |
122 | | - py::array_t<int> result(buf.shape); |
123 | | - astype<int, double>(static_cast<const double*>(buf.ptr), |
124 | | - static_cast<int*>(result.request().ptr), buf.size); |
125 | | - return result; |
| 123 | + auto* src = static_cast<const double*>(buf.ptr); |
| 124 | + if (dtype == "float32" || dtype == "float") { |
| 125 | + py::array_t<float> r(buf.shape); |
| 126 | + astype<float, double>(src, static_cast<float*>(r.request().ptr), buf.size); |
| 127 | + return r; |
| 128 | + } |
| 129 | + if (dtype == "int" || dtype == "int32") { |
| 130 | + py::array_t<int> r(buf.shape); |
| 131 | + astype<int, double>(src, static_cast<int*>(r.request().ptr), buf.size); |
| 132 | + return r; |
| 133 | + } |
| 134 | + if (dtype == "int64") { |
| 135 | + py::array_t<int64_t> r(buf.shape); |
| 136 | + astype<int64_t, double>(src, static_cast<int64_t*>(r.request().ptr), buf.size); |
| 137 | + return r; |
| 138 | + } |
| 139 | + if (dtype == "bool") { |
| 140 | + py::array_t<bool> r(buf.shape); |
| 141 | + astype<bool, double>(src, static_cast<bool*>(r.request().ptr), buf.size); |
| 142 | + return r; |
| 143 | + } |
| 144 | + } |
| 145 | + |
| 146 | + // float32 input |
| 147 | + if (dt.is(py::dtype::of<float>())) { |
| 148 | + auto* src = static_cast<const float*>(buf.ptr); |
| 149 | + if (dtype == "float64" || dtype == "double") { |
| 150 | + py::array_t<double> r(buf.shape); |
| 151 | + astype<double, float>(src, static_cast<double*>(r.request().ptr), buf.size); |
| 152 | + return r; |
| 153 | + } |
| 154 | + if (dtype == "int" || dtype == "int32") { |
| 155 | + py::array_t<int> r(buf.shape); |
| 156 | + astype<int, float>(src, static_cast<int*>(r.request().ptr), buf.size); |
| 157 | + return r; |
| 158 | + } |
| 159 | + if (dtype == "int64") { |
| 160 | + py::array_t<int64_t> r(buf.shape); |
| 161 | + astype<int64_t, float>(src, static_cast<int64_t*>(r.request().ptr), buf.size); |
| 162 | + return r; |
126 | 163 | } |
127 | 164 | if (dtype == "bool") { |
128 | | - py::array_t<bool> result(buf.shape); |
129 | | - astype<bool, double>(static_cast<const double*>(buf.ptr), |
130 | | - static_cast<bool*>(result.request().ptr), buf.size); |
131 | | - return result; |
| 165 | + py::array_t<bool> r(buf.shape); |
| 166 | + astype<bool, float>(src, static_cast<bool*>(r.request().ptr), buf.size); |
| 167 | + return r; |
132 | 168 | } |
133 | 169 | } |
134 | | - // int input |
| 170 | + |
| 171 | + // int32 input |
135 | 172 | if (dt.is(py::dtype::of<int>())) { |
| 173 | + auto* src = static_cast<const int*>(buf.ptr); |
| 174 | + if (dtype == "float64" || dtype == "double") { |
| 175 | + py::array_t<double> r(buf.shape); |
| 176 | + astype<double, int>(src, static_cast<double*>(r.request().ptr), buf.size); |
| 177 | + return r; |
| 178 | + } |
| 179 | + if (dtype == "float32" || dtype == "float") { |
| 180 | + py::array_t<float> r(buf.shape); |
| 181 | + astype<float, int>(src, static_cast<float*>(r.request().ptr), buf.size); |
| 182 | + return r; |
| 183 | + } |
| 184 | + if (dtype == "int64") { |
| 185 | + py::array_t<int64_t> r(buf.shape); |
| 186 | + astype<int64_t, int>(src, static_cast<int64_t*>(r.request().ptr), buf.size); |
| 187 | + return r; |
| 188 | + } |
136 | 189 | if (dtype == "bool") { |
137 | | - py::array_t<bool> result(buf.shape); |
138 | | - astype<bool, int>(static_cast<const int*>(buf.ptr), |
139 | | - static_cast<bool*>(result.request().ptr), buf.size); |
140 | | - return result; |
| 190 | + py::array_t<bool> r(buf.shape); |
| 191 | + astype<bool, int>(src, static_cast<bool*>(r.request().ptr), buf.size); |
| 192 | + return r; |
141 | 193 | } |
142 | 194 | } |
| 195 | + |
| 196 | + // int64 input |
| 197 | + if (dt.is(py::dtype::of<int64_t>())) { |
| 198 | + auto* src = static_cast<const int64_t*>(buf.ptr); |
| 199 | + if (dtype == "float64" || dtype == "double") { |
| 200 | + py::array_t<double> r(buf.shape); |
| 201 | + astype<double, int64_t>(src, static_cast<double*>(r.request().ptr), buf.size); |
| 202 | + return r; |
| 203 | + } |
| 204 | + if (dtype == "float32" || dtype == "float") { |
| 205 | + py::array_t<float> r(buf.shape); |
| 206 | + astype<float, int64_t>(src, static_cast<float*>(r.request().ptr), buf.size); |
| 207 | + return r; |
| 208 | + } |
| 209 | + if (dtype == "int" || dtype == "int32") { |
| 210 | + py::array_t<int> r(buf.shape); |
| 211 | + astype<int, int64_t>(src, static_cast<int*>(r.request().ptr), buf.size); |
| 212 | + return r; |
| 213 | + } |
| 214 | + if (dtype == "bool") { |
| 215 | + py::array_t<bool> r(buf.shape); |
| 216 | + astype<bool, int64_t>(src, static_cast<bool*>(r.request().ptr), buf.size); |
| 217 | + return r; |
| 218 | + } |
| 219 | + } |
| 220 | + |
| 221 | + // bool input |
| 222 | + if (dt.is(py::dtype::of<bool>())) { |
| 223 | + auto* src = static_cast<const bool*>(buf.ptr); |
| 224 | + if (dtype == "float64" || dtype == "double") { |
| 225 | + py::array_t<double> r(buf.shape); |
| 226 | + astype<double, bool>(src, static_cast<double*>(r.request().ptr), buf.size); |
| 227 | + return r; |
| 228 | + } |
| 229 | + if (dtype == "float32" || dtype == "float") { |
| 230 | + py::array_t<float> r(buf.shape); |
| 231 | + astype<float, bool>(src, static_cast<float*>(r.request().ptr), buf.size); |
| 232 | + return r; |
| 233 | + } |
| 234 | + if (dtype == "int" || dtype == "int32") { |
| 235 | + py::array_t<int> r(buf.shape); |
| 236 | + astype<int, bool>(src, static_cast<int*>(r.request().ptr), buf.size); |
| 237 | + return r; |
| 238 | + } |
| 239 | + if (dtype == "int64") { |
| 240 | + py::array_t<int64_t> r(buf.shape); |
| 241 | + astype<int64_t, bool>(src, static_cast<int64_t*>(r.request().ptr), buf.size); |
| 242 | + return r; |
| 243 | + } |
| 244 | + } |
| 245 | + |
143 | 246 | throw std::runtime_error("astype: unsupported conversion " + std::string(py::str(dt)) + " -> " + dtype); |
144 | 247 | } |
145 | 248 |
|
|
0 commit comments