Skip to content

Commit 69a6689

Browse files
author
peng.li24
committed
feat: add N-D concatenate with axis support — bit-exact with numpy
Add an axis-aware N-D concatenate to the native core (core.h), and update the pybind11 wrapper (core_py.h) and the bindings (module.cpp) to accept an axis parameter. The native implementation uses leading-slice block copies: within each leading slice every input array contributes one contiguous run of elements, so a single memcpy per array per slice suffices. Per-array strides account for the differing sizes along the concatenation axis. Also improve the vstack/hstack wrappers: (1) vstack: 1D arrays are reshaped to (1, N) before stacking; (2) hstack: uses axis=1 for 2D+ arrays. Adds 30+ new concatenate tests covering 1D–5D, all axes, float32/64, large arrays, identity, zeros/ones, and edge cases.
1 parent b4dd568 commit 69a6689

4 files changed

Lines changed: 296 additions & 15 deletions

File tree

numpy/core.h

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,7 @@ inline void stack(const T* const* arrays, T* dst, size_t n_arrays, size_t elem_s
530530
}
531531

532532
/// numpy.concatenate((a1, a2, ...), axis=0, out=None, dtype=None, casting=...)
533+
/// 1D flat: arrays are treated as flat buffers, concatenated sequentially.
533534
template<typename T>
534535
inline void concatenate(const T* const* arrays, T* dst, const size_t* sizes, size_t n_arrays) {
535536
size_t off = 0;
@@ -539,6 +540,97 @@ inline void concatenate(const T* const* arrays, T* dst, const size_t* sizes, siz
539540
}
540541
}
541542

543+
/// numpy.concatenate((a1, a2, ...), axis=0, ...) — N-D with axis support.
///
/// Every input is read as a dense C-ordered buffer whose shape equals `shape`
/// except along `axis`, where input i has extent `axis_sizes[i]`. `dst` is
/// written in C order and must have room for
/// prod(shape[d] for d != axis) * sum(axis_sizes) elements.
///
/// @param arrays      `n_arrays` pointers to the input buffers.
/// @param dst         output buffer; copied into with memcpy, so it must not
///                    overlap any input.
/// @param shape       common shape (`ndim` entries); `shape[axis]` is not read.
/// @param ndim        number of dimensions; values <= 0 make the call a no-op.
/// @param axis        concatenation axis; a negative value counts from the end.
///                    An axis outside [-ndim, ndim) makes the call a no-op
///                    rather than indexing `shape` out of bounds.
/// @param axis_sizes  extent of each input along `axis` (`n_arrays` entries).
/// @param n_arrays    number of inputs; 0 makes the call a no-op.
///
/// Strategy: split the index space into "leading slices" (the product of the
/// dims before `axis`). Inside one slice, input i owns a single contiguous run
/// of `axis_sizes[i] * trailing` elements, where `trailing` is the product of
/// the dims after `axis`. Because each input is dense, the run for slice `s`
/// starts at element `s * axis_sizes[i] * trailing`, and the output is filled
/// strictly front to back — so one memcpy per input per slice is enough and no
/// per-dimension stride table is required.
template<typename T>
inline void concatenate(const T* const* arrays, T* dst,
                        const ptrdiff_t* shape, int ndim, int axis,
                        const size_t* axis_sizes, size_t n_arrays) {
    if (n_arrays == 0 || ndim <= 0) return;

    // Normalize, then reject an axis that would index past `shape`.
    if (axis < 0) axis += ndim;
    if (axis < 0 || axis >= ndim) return;

    // Number of leading slices = product of dims before the axis.
    ptrdiff_t n_slices = 1;
    for (int d = 0; d < axis; ++d) n_slices *= shape[d];

    // Elements per unit step along the axis = product of dims after the axis.
    ptrdiff_t trailing = 1;
    for (int d = axis + 1; d < ndim; ++d) trailing *= shape[d];

    // A zero extent means there is nothing to copy; a negative one is invalid
    // and must not be turned into a huge unsigned byte count.
    if (n_slices <= 0 || trailing <= 0) return;

    const size_t run_unit = static_cast<size_t>(trailing);
    T* out = dst;
    for (ptrdiff_t s = 0; s < n_slices; ++s) {
        for (size_t i = 0; i < n_arrays; ++i) {
            const size_t count = axis_sizes[i] * run_unit;
            // Empty inputs may legitimately carry a null data pointer.
            if (count == 0) continue;
            std::memcpy(out, arrays[i] + static_cast<size_t>(s) * count,
                        count * sizeof(T));
            out += count;
        }
    }
}
633+
542634
/// numpy.where(condition, x, y) — scalar x, y
543635
template<typename T>
544636
inline void where_scalar(const bool* cond, T* dst, size_t n, T x, T y) {

pycpp/core_py.h

Lines changed: 76 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -884,28 +884,93 @@ py::array_t<T> stack(const std::vector<py::array_t<T>>& arrays) {
884884

885885
/// numpy.concatenate((a1, a2, ...), axis=0, out=None, dtype=None, casting=...)
886886
template<typename T>
887-
py::array_t<T> concatenate(const std::vector<py::array_t<T>>& arrays) {
887+
py::array_t<T> concatenate(const std::vector<py::array_t<T>>& arrays, int axis = 0) {
888888
if (arrays.empty()) return py::array_t<T>{};
889-
py::ssize_t total = 0;
890-
for (const auto& arr : arrays) total += arr.request().size;
891-
py::array_t<T> result({total});
892-
T* dst = static_cast<T*>(result.request().ptr);
893-
py::ssize_t off = 0;
889+
890+
auto buf0 = arrays[0].request();
891+
int ndim = static_cast<int>(buf0.ndim);
892+
893+
if (axis < 0) axis += ndim;
894+
if (axis < 0 || axis >= ndim)
895+
throw std::invalid_argument("concatenate: axis out of range");
896+
897+
// Validate that all arrays have same number of dimensions
894898
for (const auto& arr : arrays) {
895-
auto buf = arr.request();
896-
std::memcpy(dst + off, static_cast<const T*>(buf.ptr), buf.size * sizeof(T));
897-
off += buf.size;
899+
if (arr.request().ndim != ndim)
900+
throw std::invalid_argument("concatenate: all arrays must have same number of dimensions");
901+
}
902+
903+
// Collect shape (from first array) and per-array axis sizes
904+
std::vector<ptrdiff_t> shape(ndim);
905+
for (int d = 0; d < ndim; ++d) shape[d] = buf0.shape[d];
906+
907+
std::vector<size_t> axis_sizes(arrays.size());
908+
for (size_t i = 0; i < arrays.size(); ++i) {
909+
auto buf = arrays[i].request();
910+
axis_sizes[i] = static_cast<size_t>(buf.shape[axis]);
911+
}
912+
913+
// Validate non-axis dimensions match
914+
for (size_t i = 0; i < arrays.size(); ++i) {
915+
auto buf = arrays[i].request();
916+
for (int d = 0; d < ndim; ++d) {
917+
if (d == axis) continue;
918+
if (buf.shape[d] != shape[d])
919+
throw std::invalid_argument(
920+
"concatenate: all arrays must have same shape except along axis");
921+
}
898922
}
923+
924+
// Compute output shape
925+
std::vector<ptrdiff_t> out_shape = shape;
926+
ptrdiff_t total_axis = 0;
927+
for (size_t i = 0; i < arrays.size(); ++i)
928+
total_axis += static_cast<ptrdiff_t>(axis_sizes[i]);
929+
out_shape[axis] = total_axis;
930+
931+
std::vector<py::ssize_t> py_out_shape(out_shape.begin(), out_shape.end());
932+
py::array_t<T> result(py_out_shape);
933+
T* dst = static_cast<T*>(result.request().ptr);
934+
935+
// Build pointer array
936+
std::vector<const T*> ptrs(arrays.size());
937+
for (size_t i = 0; i < arrays.size(); ++i)
938+
ptrs[i] = static_cast<const T*>(arrays[i].request().ptr);
939+
940+
numpy::concatenate(ptrs.data(), dst, shape.data(), ndim, axis,
941+
axis_sizes.data(), arrays.size());
899942
return result;
900943
}
901944

902945
/// numpy.vstack(tup, *, dtype=None, casting=...)
903946
template<typename T>
904-
py::array_t<T> vstack(const std::vector<py::array_t<T>>& arrays) { return stack(arrays); }
947+
py::array_t<T> vstack(const std::vector<py::array_t<T>>& arrays) {
948+
if (arrays.empty()) return py::array_t<T>{};
949+
int ndim = static_cast<int>(arrays[0].request().ndim);
950+
if (ndim == 1) {
951+
// numpy.vstack: 1D arrays are reshaped to (1, N) before stacking
952+
auto buf0 = arrays[0].request();
953+
py::array_t<T> result({static_cast<py::ssize_t>(arrays.size()), static_cast<py::ssize_t>(buf0.size)});
954+
T* dst = static_cast<T*>(result.request().ptr);
955+
for (size_t i = 0; i < arrays.size(); ++i) {
956+
auto buf = arrays[i].request();
957+
std::memcpy(dst + i * buf0.size, static_cast<const T*>(buf.ptr),
958+
buf.size * sizeof(T));
959+
}
960+
return result;
961+
}
962+
return concatenate(arrays, 0);
963+
}
905964

906965
/// numpy.hstack(tup, *, dtype=None, casting=...)
907966
template<typename T>
908-
py::array_t<T> hstack(const std::vector<py::array_t<T>>& arrays) { return concatenate(arrays); }
967+
py::array_t<T> hstack(const std::vector<py::array_t<T>>& arrays) {
968+
if (arrays.empty()) return py::array_t<T>{};
969+
int ndim = static_cast<int>(arrays[0].request().ndim);
970+
// 1D arrays: hstack is identical to concatenate along axis=0
971+
// 2D+ arrays: hstack concatenates along axis=1
972+
return concatenate(arrays, (ndim == 1) ? 0 : 1);
973+
}
909974

910975
/// numpy.where(condition, x, y) — scalar x, y
911976
template<typename T>

tests/module.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,12 @@ PYBIND11_MODULE(numpycpp, m) {
196196
py::arg("arr"), py::arg("n") = 1, py::arg("axis") = -1);
197197
m.def("diff", static_cast<py::array_t<float>(*)(const py::array_t<float>&, int, int)>(&numpy::diff),
198198
py::arg("arr"), py::arg("n") = 1, py::arg("axis") = -1);
199-
BIND_F_STACK(stack); BIND_F_STACK(concatenate); BIND_F_STACK(vstack); BIND_F_STACK(hstack);
199+
BIND_F_STACK(stack); BIND_F_STACK(vstack); BIND_F_STACK(hstack);
200+
201+
m.def("concatenate", static_cast<py::array_t<double>(*)(const std::vector<py::array_t<double>>&, int)>(&numpy::concatenate),
202+
py::arg("arrays"), py::arg("axis") = 0);
203+
m.def("concatenate", static_cast<py::array_t<float>(*)(const std::vector<py::array_t<float>>&, int)>(&numpy::concatenate),
204+
py::arg("arrays"), py::arg("axis") = 0);
200205

201206
m.def("where", static_cast<py::array_t<double>(*)(const py::array_t<bool>&, double, double)>(&numpy::where));
202207
m.def("where", static_cast<py::array_t<float>(*)(const py::array_t<bool>&, float, float)>(&numpy::where));

tests/test_all.py

Lines changed: 122 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -678,17 +678,136 @@ def test_stack(cpp, dtype):
678678
arrays = [random_array((3,), seed=i, dtype=dtype) for i in range(4)]
679679
assert_bit_aligned(cpp.stack(arrays), np.stack(arrays), "stack")
680680

681-
def test_concatenate(cpp, dtype):
681+
def test_concatenate_1d(cpp, dtype):
682682
arrays = [random_array((3,), seed=i, dtype=dtype) for i in range(3)]
683-
assert_bit_aligned(cpp.concatenate(arrays), np.concatenate(arrays), "concatenate")
683+
assert_bit_aligned(cpp.concatenate(arrays), np.concatenate(arrays), "concatenate 1d")
684+
685+
def test_concatenate_2d_axis0(cpp, dtype):
686+
arrays = [random_array((2, 3), seed=i, dtype=dtype) for i in range(3)]
687+
assert_bit_aligned(cpp.concatenate(arrays, 0), np.concatenate(arrays, axis=0), "concatenate 2d axis=0")
688+
# Verify default axis=0
689+
assert_bit_aligned(cpp.concatenate(arrays), np.concatenate(arrays), "concatenate 2d default axis")
690+
691+
def test_concatenate_2d_axis1(cpp, dtype):
692+
arrays = [random_array((3, 2), seed=i, dtype=dtype) for i in range(3)]
693+
assert_bit_aligned(cpp.concatenate(arrays, 1), np.concatenate(arrays, axis=1), "concatenate 2d axis=1")
694+
695+
def test_concatenate_2d_axis_neg1(cpp, dtype):
696+
arrays = [random_array((3, 2), seed=i, dtype=dtype) for i in range(3)]
697+
assert_bit_aligned(cpp.concatenate(arrays, -1), np.concatenate(arrays, axis=-1), "concatenate 2d axis=-1")
698+
699+
def test_concatenate_3d_axis0(cpp, dtype):
700+
arrays = [random_array((2, 3, 4), seed=i, dtype=dtype) for i in range(2)]
701+
assert_bit_aligned(cpp.concatenate(arrays, 0), np.concatenate(arrays, axis=0), "concatenate 3d axis=0")
702+
703+
def test_concatenate_3d_axis1(cpp, dtype):
704+
arrays = [random_array((3, 2, 4), seed=i, dtype=dtype) for i in range(2)]
705+
assert_bit_aligned(cpp.concatenate(arrays, 1), np.concatenate(arrays, axis=1), "concatenate 3d axis=1")
706+
707+
def test_concatenate_3d_axis2(cpp, dtype):
708+
arrays = [random_array((3, 4, 2), seed=i, dtype=dtype) for i in range(2)]
709+
assert_bit_aligned(cpp.concatenate(arrays, 2), np.concatenate(arrays, axis=2), "concatenate 3d axis=2")
710+
711+
def test_concatenate_two_arrays(cpp, dtype):
712+
arrays = [random_array((5,), seed=0, dtype=dtype), random_array((7,), seed=1, dtype=dtype)]
713+
assert_bit_aligned(cpp.concatenate(arrays), np.concatenate(arrays), "concatenate two")
714+
715+
def test_concatenate_single(cpp, dtype):
716+
arrays = [random_array((5,), dtype=dtype)]
717+
assert_bit_aligned(cpp.concatenate(arrays), np.concatenate(arrays), "concatenate single")
684718

685719
def test_vstack(cpp, dtype):
686720
arrays = [random_array((1, 3), seed=i, dtype=dtype) for i in range(4)]
687721
assert_bit_aligned(cpp.vstack(arrays), np.vstack(arrays), "vstack")
688722

723+
def test_vstack_1d(cpp, dtype):
724+
arrays = [random_array((3,), seed=i, dtype=dtype) for i in range(4)]
725+
assert_bit_aligned(cpp.vstack(arrays), np.vstack(arrays), "vstack 1d")
726+
689727
def test_hstack(cpp, dtype):
690728
arrays = [random_array((3,), seed=i, dtype=dtype) for i in range(3)]
691-
assert_bit_aligned(cpp.hstack(arrays), np.hstack(arrays), "hstack")
729+
assert_bit_aligned(cpp.hstack(arrays), np.hstack(arrays), "hstack 1d")
730+
731+
def test_hstack_2d(cpp, dtype):
732+
arrays = [random_array((3, 2), seed=i, dtype=dtype) for i in range(3)]
733+
assert_bit_aligned(cpp.hstack(arrays), np.hstack(arrays), "hstack 2d")
734+
735+
# -- Concatenate complex / edge-case tests ----------------------------------
736+
737+
def test_concatenate_4d_axis0(cpp, dtype):
738+
arrays = [random_array((2, 3, 4, 5), seed=i, dtype=dtype) for i in range(2)]
739+
assert_bit_aligned(cpp.concatenate(arrays, 0), np.concatenate(arrays, axis=0), "concatenate 4d axis=0")
740+
741+
def test_concatenate_4d_axis2(cpp, dtype):
742+
arrays = [random_array((2, 3, 2, 5), seed=i, dtype=dtype) for i in range(2)]
743+
assert_bit_aligned(cpp.concatenate(arrays, 2), np.concatenate(arrays, axis=2), "concatenate 4d axis=2")
744+
745+
def test_concatenate_4d_axis_neg2(cpp, dtype):
746+
arrays = [random_array((2, 3, 2, 5), seed=i, dtype=dtype) for i in range(2)]
747+
assert_bit_aligned(cpp.concatenate(arrays, -2), np.concatenate(arrays, axis=-2), "concatenate 4d axis=-2")
748+
749+
def test_concatenate_unequal_axis_sizes(cpp, dtype):
750+
"""Concatenate arrays of different sizes along the concatenation axis."""
751+
a = random_array((3, 2), seed=1, dtype=dtype)
752+
b = random_array((3, 4), seed=2, dtype=dtype)
753+
c = random_array((3, 1), seed=3, dtype=dtype)
754+
assert_bit_aligned(cpp.concatenate([a, b, c], 1),
755+
np.concatenate([a, b, c], axis=1), "concat unequal axis sizes")
756+
757+
def test_concatenate_many_arrays(cpp, dtype):
758+
"""Concatenate 10 arrays along axis=0."""
759+
arrays = [random_array((3,), seed=i, dtype=dtype) for i in range(10)]
760+
assert_bit_aligned(cpp.concatenate(arrays), np.concatenate(arrays), "concat 10 arrays")
761+
762+
def test_concatenate_large_3d(cpp, dtype):
763+
"""Large 3D concatenation along middle axis."""
764+
arrays = [random_array((50, 20, 30), seed=i, dtype=dtype) for i in range(3)]
765+
assert_bit_aligned(cpp.concatenate(arrays, 1), np.concatenate(arrays, axis=1), "concat large 3d axis=1")
766+
767+
def test_concatenate_large_2d_axis0(cpp, dtype):
768+
"""Large 2D concatenation — 500 rows each, 4 arrays."""
769+
arrays = [random_array((500, 10), seed=i, dtype=dtype) for i in range(4)]
770+
assert_bit_aligned(cpp.concatenate(arrays, 0), np.concatenate(arrays, axis=0), "concat large 2d axis=0")
771+
772+
def test_concatenate_large_2d_axis1(cpp, dtype):
773+
"""Large 2D concatenation — 500 cols each, 3 arrays."""
774+
arrays = [random_array((10, 500), seed=i, dtype=dtype) for i in range(3)]
775+
assert_bit_aligned(cpp.concatenate(arrays, 1), np.concatenate(arrays, axis=1), "concat large 2d axis=1")
776+
777+
def test_concatenate_identity(cpp, dtype):
778+
"""Concatenating a single array returns identical copy."""
779+
a = random_array((3, 4), seed=42, dtype=dtype)
780+
assert_bit_aligned(cpp.concatenate([a], 0), np.concatenate([a], axis=0), "concat identity")
781+
assert_bit_aligned(cpp.concatenate([a], 1), np.concatenate([a], axis=1), "concat identity axis=1")
782+
783+
def test_concatenate_zeros(cpp, dtype):
784+
"""Concatenate arrays of zeros."""
785+
a = np.zeros((2, 3), dtype=dtype)
786+
b = np.zeros((2, 5), dtype=dtype)
787+
assert_bit_aligned(cpp.concatenate([a, b], 1), np.concatenate([a, b], axis=1), "concat zeros")
788+
789+
def test_concatenate_ones(cpp, dtype):
790+
"""Concatenate arrays of ones."""
791+
a = np.ones((3, 2), dtype=dtype)
792+
b = np.ones((5, 2), dtype=dtype)
793+
assert_bit_aligned(cpp.concatenate([a, b], 0), np.concatenate([a, b], axis=0), "concat ones")
794+
795+
def test_concatenate_3d_axis_neg2(cpp, dtype):
796+
"""3D concatenate along axis=-2 (middle axis)."""
797+
arrays = [random_array((2, 3, 4), seed=i, dtype=dtype) for i in range(3)]
798+
assert_bit_aligned(cpp.concatenate(arrays, -2), np.concatenate(arrays, axis=-2), "concat 3d axis=-2")
799+
800+
def test_concatenate_3d_axis_neg3(cpp, dtype):
801+
"""3D concatenate along axis=-3 (first axis)."""
802+
arrays = [random_array((2, 3, 4), seed=i, dtype=dtype) for i in range(2)]
803+
assert_bit_aligned(cpp.concatenate(arrays, -3), np.concatenate(arrays, axis=-3), "concat 3d axis=-3")
804+
805+
def test_concatenate_5d(cpp, dtype):
806+
"""5D concatenate along various axes."""
807+
arrays = [random_array((2, 3, 2, 3, 2), seed=i, dtype=dtype) for i in range(2)]
808+
assert_bit_aligned(cpp.concatenate(arrays, 0), np.concatenate(arrays, axis=0), "concat 5d axis=0")
809+
assert_bit_aligned(cpp.concatenate(arrays, 2), np.concatenate(arrays, axis=2), "concat 5d axis=2")
810+
assert_bit_aligned(cpp.concatenate(arrays, -1), np.concatenate(arrays, axis=-1), "concat 5d axis=-1")
692811

693812
def test_where_scalar(cpp, dtype):
694813
cond = np.array([True, False, True, False, True])

0 commit comments

Comments
 (0)