Skip to content

Commit 4ee900e

Browse files
author
peng.li24
committed
refactor(pycpp): split monolithic core_py.h into 6 topic modules
pycpp/ now mirrors the numpy/ module structure:

- init_py.h — zeros_like, ones_like, full_like, zeros, ones, full
- elementwise_py.h — sqrt/exp/sin/…, comparison, logical, isnan/isinf, astype, truncate_to_float32, power, clip, hypot, arctan2, maximum, minimum
- reduce_py.h — sum, mean, max, min, any, all, std, var, cumsum, mean_axis
- manipulation_py.h — diff, stack, concatenate, vstack, hstack, transpose, flatten, squeeze, where, roll, flip, repeat, tile, argsort, argmax, argmin, slice (1D+ND), slice_assign, take_cols, take, compress, put, putmask
- io_py.h — isin, flatnonzero, intersect1d, interp, unwrap, asarray, array, array_get, to_vector, get_array, set_array
- linalg_py.h — dot, norm, matmul (2D/3D/mv/vm), einsum (absorbed from former einsum_py.h)
- pycpp.h — umbrella: #includes all 6 modules

Backward-compat shims:

- core_py.h → #include "pycpp.h"
- einsum_py.h → #include "pycpp.h"

tests/module.cpp: use single #include "pycpp.h"

pycpp/pyproject.toml: deleted (stale Python packaging artefact)

All 900 bit-exact tests pass.
1 parent b5e3248 commit 4ee900e

11 files changed

Lines changed: 1340 additions & 1529 deletions

File tree

pycpp/core_py.h

Lines changed: 9 additions & 1459 deletions
Large diffs are not rendered by default.

pycpp/einsum_py.h

Lines changed: 7 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,8 @@
1-
// Pybind11 wrappers for einsum native functions.
2-
1+
// ════════════════════════════════════════════════════════════════════════════
2+
// numpycpp — pycpp/einsum_py.h [BACKWARD-COMPAT SHIM]
3+
// Legacy header kept for backward compatibility.
4+
// einsum wrapper is now in pycpp/linalg_py.h.
5+
// This shim simply pulls in the umbrella header.
6+
// ════════════════════════════════════════════════════════════════════════════
37
#pragma once
4-
5-
#include <pybind11/pybind11.h>
6-
#include <pybind11/numpy.h>
7-
#include "../numpy/numpy.h"
8-
#include <vector>
9-
10-
namespace py = pybind11;
11-
12-
namespace numpy {
13-
14-
/// numpy.einsum(subscripts, *operands, out=None, dtype=None, order='K',
15-
/// casting='safe', optimize=False)
16-
// Currently supports 2-operand patterns only.
17-
template<typename T>
18-
py::array_t<T> einsum(const std::string& subscripts,
19-
const py::array_t<T>& a,
20-
const py::array_t<T>& b) {
21-
auto bufa = a.request(), bufb = b.request();
22-
23-
std::vector<ptrdiff_t> a_shape(bufa.shape.begin(), bufa.shape.end());
24-
std::vector<ptrdiff_t> b_shape(bufb.shape.begin(), bufb.shape.end());
25-
26-
auto out_shape = einsum_detail::einsum_output_shape(
27-
subscripts, a_shape.data(), static_cast<int>(a_shape.size()),
28-
b_shape.data(), static_cast<int>(b_shape.size()));
29-
30-
std::vector<py::ssize_t> py_out_shape(out_shape.begin(), out_shape.end());
31-
py::array_t<T> result(py_out_shape);
32-
33-
einsum_detail::einsum(
34-
subscripts,
35-
static_cast<const T*>(bufa.ptr), a_shape.data(), static_cast<int>(a_shape.size()),
36-
static_cast<const T*>(bufb.ptr), b_shape.data(), static_cast<int>(b_shape.size()),
37-
static_cast<T*>(result.request().ptr));
38-
39-
return result;
40-
}
41-
42-
} // namespace numpy
8+
#include "pycpp.h"

pycpp/elementwise_py.h

Lines changed: 330 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,330 @@
1+
// ════════════════════════════════════════════════════════════════════════════
2+
// numpycpp — pycpp/elementwise_py.h [PUBLIC HEADER]
3+
// Pybind11 wrappers: element-wise operations and type conversion.
4+
// Unary: sqrt abs exp log sin cos tan cbrt expm1 log1p log10 log2
5+
// arcsin arccos arctan round floor ceil degrees radians sign
6+
// power clip
7+
// Binary: hypot arctan2 maximum minimum
8+
// Comparison: greater less equal greater_equal less_equal not_equal
9+
// Logical: logical_and logical_or logical_not logical_xor
10+
// Special: isnan isinf isfinite
11+
// Type conversion: astype truncate_to_float32
12+
// ════════════════════════════════════════════════════════════════════════════
13+
#pragma once
14+
15+
#include <pybind11/pybind11.h>
16+
#include <pybind11/numpy.h>
17+
#include "../numpy/numpy.h"
18+
#include <algorithm>
19+
#include <stdexcept>
20+
#include <cstdint>
21+
22+
namespace py = pybind11;
23+
24+
namespace numpy {
25+
26+
// ============================================================================
27+
// Unary element-wise — template macro
28+
// ============================================================================
29+
#define DEF_ELEMWISE(name) \
30+
template<typename T> \
31+
py::array_t<T> name(const py::array_t<T>& arr) { \
32+
auto buf = arr.request(); \
33+
py::array_t<T> result(buf.shape); \
34+
numpy::name(static_cast<const T*>(buf.ptr), \
35+
static_cast<T*>(result.request().ptr), buf.size); \
36+
return result; \
37+
}
38+
39+
DEF_ELEMWISE(sqrt)
40+
DEF_ELEMWISE(abs)
41+
DEF_ELEMWISE(exp)
42+
DEF_ELEMWISE(log)
43+
DEF_ELEMWISE(sin)
44+
DEF_ELEMWISE(cos)
45+
DEF_ELEMWISE(tan)
46+
DEF_ELEMWISE(cbrt)
47+
DEF_ELEMWISE(expm1)
48+
DEF_ELEMWISE(log1p)
49+
DEF_ELEMWISE(log10)
50+
DEF_ELEMWISE(log2)
51+
DEF_ELEMWISE(arcsin)
52+
DEF_ELEMWISE(arccos)
53+
DEF_ELEMWISE(arctan)
54+
DEF_ELEMWISE(round)
55+
DEF_ELEMWISE(floor)
56+
DEF_ELEMWISE(ceil)
57+
DEF_ELEMWISE(degrees)
58+
DEF_ELEMWISE(radians)
59+
DEF_ELEMWISE(sign)
60+
#undef DEF_ELEMWISE
61+
62+
/// numpy.power(x1, x2) — scalar exponent
63+
template<typename T>
64+
py::array_t<T> power(const py::array_t<T>& arr, T exponent) {
65+
auto buf = arr.request();
66+
py::array_t<T> result(buf.shape);
67+
numpy::power(static_cast<const T*>(buf.ptr),
68+
static_cast<T*>(result.request().ptr), buf.size, exponent);
69+
return result;
70+
}
71+
72+
/// numpy.clip(a, a_min, a_max)
73+
template<typename T>
74+
py::array_t<T> clip(const py::array_t<T>& arr, T min_val, T max_val) {
75+
auto buf = arr.request();
76+
py::array_t<T> result(buf.shape);
77+
numpy::clip(static_cast<const T*>(buf.ptr),
78+
static_cast<T*>(result.request().ptr), buf.size, min_val, max_val);
79+
return result;
80+
}
81+
82+
// ============================================================================
83+
// Binary element-wise
84+
// ============================================================================
85+
86+
/// numpy.hypot(x1, x2)
87+
template<typename T>
88+
py::array_t<T> hypot(const py::array_t<T>& a, const py::array_t<T>& b) {
89+
auto ba = a.request(), bb = b.request();
90+
py::array_t<T> result(ba.shape);
91+
numpy::hypot(static_cast<const T*>(ba.ptr), static_cast<const T*>(bb.ptr),
92+
static_cast<T*>(result.request().ptr), std::min(ba.size, bb.size));
93+
return result;
94+
}
95+
96+
/// numpy.arctan2(x1, x2) — array-array
97+
template<typename T>
98+
py::array_t<T> arctan2(const py::array_t<T>& a, const py::array_t<T>& b) {
99+
auto ba = a.request(), bb = b.request();
100+
py::array_t<T> result(ba.shape);
101+
numpy::arctan2(static_cast<const T*>(ba.ptr), static_cast<const T*>(bb.ptr),
102+
static_cast<T*>(result.request().ptr), std::min(ba.size, bb.size));
103+
return result;
104+
}
105+
106+
/// numpy.arctan2(x1, x2) — array-scalar
107+
template<typename T>
108+
py::array_t<T> arctan2(const py::array_t<T>& a, T b) {
109+
auto buf = a.request();
110+
py::array_t<T> result(buf.shape);
111+
numpy::arctan2(static_cast<const T*>(buf.ptr),
112+
static_cast<T*>(result.request().ptr), buf.size, b);
113+
return result;
114+
}
115+
116+
/// numpy.maximum(x1, x2) — array-array
117+
template<typename T>
118+
py::array_t<T> maximum(const py::array_t<T>& a, const py::array_t<T>& b) {
119+
auto ba = a.request(), bb = b.request();
120+
py::array_t<T> result(ba.shape);
121+
numpy::maximum(static_cast<const T*>(ba.ptr), static_cast<const T*>(bb.ptr),
122+
static_cast<T*>(result.request().ptr), std::min(ba.size, bb.size));
123+
return result;
124+
}
125+
126+
/// numpy.maximum(x1, x2) — array-scalar
127+
template<typename T>
128+
py::array_t<T> maximum(const py::array_t<T>& a, T b) {
129+
auto buf = a.request();
130+
py::array_t<T> result(buf.shape);
131+
numpy::maximum(static_cast<const T*>(buf.ptr),
132+
static_cast<T*>(result.request().ptr), buf.size, b);
133+
return result;
134+
}
135+
136+
/// numpy.minimum(x1, x2) — array-array
137+
template<typename T>
138+
py::array_t<T> minimum(const py::array_t<T>& a, const py::array_t<T>& b) {
139+
auto ba = a.request(), bb = b.request();
140+
py::array_t<T> result(ba.shape);
141+
numpy::minimum(static_cast<const T*>(ba.ptr), static_cast<const T*>(bb.ptr),
142+
static_cast<T*>(result.request().ptr), std::min(ba.size, bb.size));
143+
return result;
144+
}
145+
146+
/// numpy.minimum(x1, x2) — array-scalar
147+
template<typename T>
148+
py::array_t<T> minimum(const py::array_t<T>& a, T b) {
149+
auto buf = a.request();
150+
py::array_t<T> result(buf.shape);
151+
numpy::minimum(static_cast<const T*>(buf.ptr),
152+
static_cast<T*>(result.request().ptr), buf.size, b);
153+
return result;
154+
}
155+
156+
// ============================================================================
157+
// Comparison — T → bool
158+
// ============================================================================
159+
#define DEF_COMPARE(name) \
160+
template<typename T> \
161+
py::array_t<bool> name(const py::array_t<T>& a, T b) { \
162+
auto buf = a.request(); \
163+
py::array_t<bool> result(buf.shape); \
164+
numpy::name(static_cast<const T*>(buf.ptr), \
165+
static_cast<bool*>(result.request().ptr), buf.size, b); \
166+
return result; \
167+
}
168+
169+
DEF_COMPARE(greater)
170+
DEF_COMPARE(less)
171+
DEF_COMPARE(equal)
172+
DEF_COMPARE(greater_equal)
173+
DEF_COMPARE(less_equal)
174+
#undef DEF_COMPARE
175+
176+
/// numpy.not_equal(x1, x2) — scalar
177+
template<typename T>
178+
py::array_t<bool> not_equal(const py::array_t<T>& a, T b) {
179+
auto buf = a.request();
180+
py::array_t<bool> result(buf.shape);
181+
numpy::not_equal_scalar(static_cast<const T*>(buf.ptr),
182+
static_cast<bool*>(result.request().ptr), buf.size, b);
183+
return result;
184+
}
185+
186+
/// numpy.not_equal(x1, x2) — array
187+
template<typename T>
188+
py::array_t<bool> not_equal(const py::array_t<T>& a, const py::array_t<T>& b) {
189+
auto ba = a.request(), bb = b.request();
190+
py::array_t<bool> result(ba.shape);
191+
numpy::not_equal_array(static_cast<const T*>(ba.ptr),
192+
static_cast<const T*>(bb.ptr),
193+
static_cast<bool*>(result.request().ptr),
194+
std::min(ba.size, bb.size));
195+
return result;
196+
}
197+
198+
// ============================================================================
199+
// Logical — bool → bool
200+
// ============================================================================
201+
202+
/// numpy.logical_and(x1, x2) — bool arrays
///
/// Returns a bool array shaped like `a`, filled by the pointer-based
/// numpy::logical_and overload from the flat data of `a` and `b`.
///
/// @throws std::invalid_argument if `a` and `b` hold a different number of
///         elements (pybind11's default translation reports this to Python
///         as ValueError).
///
/// Only the total element count is compared; shapes are not, and no
/// broadcasting is performed.
/// NOTE(review): strides are not consulted, so both operands are presumably
/// expected to be C-contiguous — confirm against callers.
inline py::array_t<bool> logical_and(const py::array_t<bool>& a,
                                     const py::array_t<bool>& b) {
    auto ba = a.request();
    auto bb = b.request();

    // The result takes a's shape, so every element must be written. Running
    // the kernel over min(size) would leave the tail uninitialised whenever b
    // is shorter, so mismatched operands are rejected up front.
    if (ba.size != bb.size) {
        throw std::invalid_argument(
            "logical_and: operands must have the same number of elements");
    }

    py::array_t<bool> result(ba.shape);
    numpy::logical_and(static_cast<const bool*>(ba.ptr),
                       static_cast<const bool*>(bb.ptr),
                       result.mutable_data(),
                       ba.size);
    return result;
}
212+
213+
/// numpy.logical_or(x1, x2) — bool arrays
///
/// Returns a bool array shaped like `a`, filled by the pointer-based
/// numpy::logical_or overload from the flat data of `a` and `b`.
///
/// @throws std::invalid_argument if `a` and `b` hold a different number of
///         elements (pybind11's default translation reports this to Python
///         as ValueError).
///
/// Only the total element count is compared; shapes are not, and no
/// broadcasting is performed.
/// NOTE(review): strides are not consulted, so both operands are presumably
/// expected to be C-contiguous — confirm against callers.
inline py::array_t<bool> logical_or(const py::array_t<bool>& a,
                                    const py::array_t<bool>& b) {
    auto ba = a.request();
    auto bb = b.request();

    // The result takes a's shape, so every element must be written. Running
    // the kernel over min(size) would leave the tail uninitialised whenever b
    // is shorter, so mismatched operands are rejected up front.
    if (ba.size != bb.size) {
        throw std::invalid_argument(
            "logical_or: operands must have the same number of elements");
    }

    py::array_t<bool> result(ba.shape);
    numpy::logical_or(static_cast<const bool*>(ba.ptr),
                      static_cast<const bool*>(bb.ptr),
                      result.mutable_data(),
                      ba.size);
    return result;
}
223+
224+
/// numpy.logical_not(x) — bool array
///
/// Returns a bool array shaped like `a`. The input pointer, output pointer and
/// element count are forwarded to the pointer-based numpy::logical_not
/// overload.
/// NOTE(review): strides are not consulted, so `a` is presumably expected to
/// be C-contiguous — confirm against callers.
inline py::array_t<bool> logical_not(const py::array_t<bool>& a) {
    auto info = a.request();
    py::array_t<bool> out(info.shape);

    const bool* src = static_cast<const bool*>(info.ptr);
    bool* dst = out.mutable_data();
    numpy::logical_not(src, dst, info.size);
    return out;
}
231+
232+
/// numpy.logical_xor(x1, x2) — bool arrays
///
/// Returns a bool array shaped like `a`, filled by the pointer-based
/// numpy::logical_xor overload from the flat data of `a` and `b`.
///
/// @throws std::invalid_argument if `a` and `b` hold a different number of
///         elements (pybind11's default translation reports this to Python
///         as ValueError).
///
/// Only the total element count is compared; shapes are not, and no
/// broadcasting is performed.
/// NOTE(review): strides are not consulted, so both operands are presumably
/// expected to be C-contiguous — confirm against callers.
inline py::array_t<bool> logical_xor(const py::array_t<bool>& a,
                                     const py::array_t<bool>& b) {
    auto ba = a.request();
    auto bb = b.request();

    // The result takes a's shape, so every element must be written. Running
    // the kernel over min(size) would leave the tail uninitialised whenever b
    // is shorter, so mismatched operands are rejected up front.
    if (ba.size != bb.size) {
        throw std::invalid_argument(
            "logical_xor: operands must have the same number of elements");
    }

    py::array_t<bool> result(ba.shape);
    numpy::logical_xor(static_cast<const bool*>(ba.ptr),
                       static_cast<const bool*>(bb.ptr),
                       result.mutable_data(),
                       ba.size);
    return result;
}
242+
243+
// ============================================================================
244+
// Special value tests — T → bool
245+
// ============================================================================
246+
#define DEF_SPECIAL(name) \
247+
template<typename T> \
248+
py::array_t<bool> name(const py::array_t<T>& arr) { \
249+
auto buf = arr.request(); \
250+
py::array_t<bool> result(buf.shape); \
251+
numpy::name(static_cast<const T*>(buf.ptr), \
252+
static_cast<bool*>(result.request().ptr), buf.size); \
253+
return result; \
254+
}
255+
DEF_SPECIAL(isnan)
256+
DEF_SPECIAL(isinf)
257+
DEF_SPECIAL(isfinite)
258+
#undef DEF_SPECIAL
259+
260+
// ============================================================================
261+
// Type conversion
262+
// ============================================================================
263+
264+
/// ndarray.astype(dtype) — unified dtype dispatch
265+
inline py::array astype(const py::array& arr, const std::string& dtype) {
266+
auto buf = arr.request();
267+
auto dt = arr.dtype();
268+
269+
#define _ASTYPE_CASE(SrcT, dst_str, DstT) \
270+
if (dt.is(py::dtype::of<SrcT>()) && (dtype == dst_str)) { \
271+
py::array_t<DstT> r(buf.shape); \
272+
numpy::astype<DstT, SrcT>(static_cast<const SrcT*>(buf.ptr), \
273+
static_cast<DstT*>(r.request().ptr), buf.size); \
274+
return r; \
275+
}
276+
277+
// float64
278+
_ASTYPE_CASE(double, "float32", float)
279+
_ASTYPE_CASE(double, "float", float)
280+
_ASTYPE_CASE(double, "int", int)
281+
_ASTYPE_CASE(double, "int32", int)
282+
_ASTYPE_CASE(double, "int64", int64_t)
283+
_ASTYPE_CASE(double, "bool", bool)
284+
// float32
285+
_ASTYPE_CASE(float, "float64", double)
286+
_ASTYPE_CASE(float, "double", double)
287+
_ASTYPE_CASE(float, "int", int)
288+
_ASTYPE_CASE(float, "int32", int)
289+
_ASTYPE_CASE(float, "int64", int64_t)
290+
_ASTYPE_CASE(float, "bool", bool)
291+
// int32
292+
_ASTYPE_CASE(int, "float64", double)
293+
_ASTYPE_CASE(int, "double", double)
294+
_ASTYPE_CASE(int, "float32", float)
295+
_ASTYPE_CASE(int, "float", float)
296+
_ASTYPE_CASE(int, "int64", int64_t)
297+
_ASTYPE_CASE(int, "bool", bool)
298+
// int64
299+
_ASTYPE_CASE(int64_t, "float64", double)
300+
_ASTYPE_CASE(int64_t, "double", double)
301+
_ASTYPE_CASE(int64_t, "float32", float)
302+
_ASTYPE_CASE(int64_t, "float", float)
303+
_ASTYPE_CASE(int64_t, "int", int)
304+
_ASTYPE_CASE(int64_t, "int32", int)
305+
_ASTYPE_CASE(int64_t, "bool", bool)
306+
// bool
307+
_ASTYPE_CASE(bool, "float64", double)
308+
_ASTYPE_CASE(bool, "double", double)
309+
_ASTYPE_CASE(bool, "float32", float)
310+
_ASTYPE_CASE(bool, "float", float)
311+
_ASTYPE_CASE(bool, "int", int)
312+
_ASTYPE_CASE(bool, "int32", int)
313+
_ASTYPE_CASE(bool, "int64", int64_t)
314+
#undef _ASTYPE_CASE
315+
316+
throw std::runtime_error("astype: unsupported conversion " +
317+
std::string(py::str(dt)) + " -> " + dtype);
318+
}
319+
320+
/// float64 → float32 → float64 roundtrip
321+
inline py::array_t<double> truncate_to_float32(const py::array_t<double>& arr) {
322+
auto buf = arr.request();
323+
py::array_t<double> result(buf.shape);
324+
numpy::truncate_to_float32(static_cast<const double*>(buf.ptr),
325+
static_cast<double*>(result.mutable_data()),
326+
buf.size);
327+
return result;
328+
}
329+
330+
} // namespace numpy

0 commit comments

Comments
 (0)