Skip to content

Commit 77dda68

Browse files
author
peng.li24
committed
fix: rewrite einsum SIMD reduction to match numpy's SSE non-FMA impl exactly
- Replaced FMA intrinsics with separate mul+add (mulpd/mulps + addpd/addps)
- Changed to forward accumulation chain matching numpy's disassembled .so
- Used SSE-width vectors (2-wide f64/4-wide f32) with hadd reduction
- Fixed tail handling using movd/movq/insert_epi32 per numpy's npyv_load_tillz
- All 336 tests now pass with bit-level alignment
1 parent 58de15c commit 77dda68

7 files changed

Lines changed: 606 additions & 56 deletions

File tree

numpy/blas_bridge.h

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
// OpenBLAS bridge — resolve cblas_*dot from numpy's loaded modules.
2+
// Same approach as svml_bridge.h: resolve symbols from numpy's
3+
// already-loaded _multiarray_umath.so. dlsym searches transitive
4+
// dependencies, so cblas_*dot symbols from OpenBLAS are found.
5+
// No hardcoded .so paths.
6+
//
7+
// Call blas_init(path_to_multiarray_umath_so) before first use.
8+
9+
#pragma once
10+
11+
#include <cstdint>
12+
#include <dlfcn.h>
13+
#include <cstdio>
14+
15+
namespace numpy {
namespace blas {

/// Handle that the cblas_* entry points are resolved against.
///
/// Written by blas_init(). A null value does NOT by itself mean "BLAS is
/// unavailable": on some platforms (glibc among them) RTLD_DEFAULT is itself a
/// null pointer, so always use the return value of blas_init() to test for
/// availability.
inline void* g_blas_handle = nullptr;

namespace detail {

/// Plain strided dot product used when no BLAS symbol could be resolved.
///
/// Follows the BLAS increment convention: for a negative increment the vector
/// is traversed backwards starting from its last logical element, so `x`/`y`
/// always point at the lowest address of the operand. Returns 0 for n <= 0.
template <typename T>
inline T strided_dot_fallback(int64_t n, const T* x, int64_t incx, const T* y, int64_t incy) {
    if (n <= 0) return T(0);
    // For negative strides start at the far end so every index stays in range.
    const T* x_first = incx < 0 ? x - (n - 1) * incx : x;
    const T* y_first = incy < 0 ? y - (n - 1) * incy : y;
    T sum = T(0);
    for (int64_t i = 0; i < n; ++i) sum += x_first[i * incx] * y_first[i * incy];
    return sum;
}

}  // namespace detail

/// Locates the ILP64 cblas dot symbols and records the handle to use.
///
/// @param umath_path  Path of numpy's already-loaded _multiarray_umath shared
///                    object; may be null or empty, in which case only the
///                    global symbol scope is searched.
/// @return true when cblas_sdot64_ or cblas_ddot64_ was found. The work is done
///         once; every later call returns the outcome of the first call.
///
/// Progress is reported on stderr.
inline bool blas_init(const char* umath_path) {
    // The outcome is cached in its own flag instead of being derived from
    // g_blas_handle: RTLD_DEFAULT may compare equal to nullptr, and a handle
    // may be open even though none of the probed symbols exist in it.
    static bool initialized = false;
    static bool available = false;
    if (initialized) return available;
    initialized = true;

    // dlerror() returns NULL when no error is pending; passing NULL to %s is
    // undefined, so substitute a printable placeholder.
    const auto pending_error = []() -> const char* {
        const char* message = dlerror();
        return message ? message : "(null)";
    };

    if (umath_path && umath_path[0]) {
        // RTLD_NOLOAD: only obtain a handle if the library is already loaded.
        // dlsym on this handle also searches its dependencies (OpenBLAS).
        g_blas_handle = dlopen(umath_path, RTLD_NOLOAD | RTLD_LAZY);
        fprintf(stderr, "[BLAS] blas_init(%s) -> handle=%p dlerror=%s\n",
                umath_path, g_blas_handle, pending_error());
        if (g_blas_handle) {
            const char* const probes[] = {"cblas_sdot64_", "cblas_ddot64_"};
            for (const char* symbol : probes) {
                void* found = dlsym(g_blas_handle, symbol);
                fprintf(stderr, "[BLAS] dlsym(%s) -> %p dlerror=%s\n",
                        symbol, found, pending_error());
                if (found) {
                    available = true;
                    return true;
                }
            }
        }
    }

    // Fallback: search the global symbol scope.
    if (dlsym(RTLD_DEFAULT, "cblas_sdot64_")) {
        g_blas_handle = RTLD_DEFAULT;
        fprintf(stderr, "[BLAS] found cblas_sdot64_ via RTLD_DEFAULT\n");
        available = true;
        return true;
    }
    fprintf(stderr, "[BLAS] BLAS not available; dot will use sequential fallback\n");
    return false;
}

/// Single-precision dot product through cblas_sdot64_ (ILP64: 64-bit sizes).
///
/// The symbol is looked up once, on the first call, against the value
/// g_blas_handle holds at that moment; call blas_init() beforehand. When the
/// symbol is missing a sequential strided loop is used instead.
inline float cblas_sdot(int64_t n, const float* x, int64_t incx, const float* y, int64_t incy) {
    using SdotFn = float (*)(int64_t, const float*, int64_t, const float*, int64_t);
    static const SdotFn fn = reinterpret_cast<SdotFn>(dlsym(g_blas_handle, "cblas_sdot64_"));
    if (fn) return fn(n, x, incx, y, incy);
    return detail::strided_dot_fallback(n, x, incx, y, incy);
}

/// Double-precision dot product through cblas_ddot64_ (ILP64: 64-bit sizes).
///
/// Same lookup and fallback rules as cblas_sdot().
inline double cblas_ddot(int64_t n, const double* x, int64_t incx, const double* y, int64_t incy) {
    using DdotFn = double (*)(int64_t, const double*, int64_t, const double*, int64_t);
    static const DdotFn fn = reinterpret_cast<DdotFn>(dlsym(g_blas_handle, "cblas_ddot64_"));
    if (fn) return fn(n, x, incx, y, incy);
    return detail::strided_dot_fallback(n, x, incx, y, incy);
}

} // namespace blas
} // namespace numpy

numpy/core.h

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <stdexcept>
1818

1919
#include "svml_bridge.h"
20+
#include "blas_bridge.h"
2021

2122
namespace numpy {
2223

@@ -792,12 +793,24 @@ inline T norm_sq(const T* data, size_t n) {
792793
}
793794

794795
/// numpy.dot(a, b, out=None) — 1D vector dot product
796+
// The float and double specializations below go through numpy's bundled OpenBLAS
// (via blas_bridge) for bit-exact results; every other type uses a sequential sum.
795797
template<typename T>
796798
inline T dot(const T* a, const T* b, size_t n) {
797-
std::vector<T> products(n);
798-
for (size_t i = 0; i < n; ++i)
799-
products[i] = a[i] * b[i];
800-
return pairwise_sum(products.data(), n);
799+
T sum = T(0);
800+
for (size_t i = 0; i < n; ++i) sum += a[i] * b[i];
801+
return sum;
802+
}
803+
804+
// float32 specialization: use OpenBLAS sdot
805+
template<>
806+
inline float dot<float>(const float* a, const float* b, size_t n) {
807+
return blas::cblas_sdot(static_cast<int64_t>(n), a, 1, b, 1);
808+
}
809+
810+
// float64 specialization: use OpenBLAS ddot
811+
template<>
812+
inline double dot<double>(const double* a, const double* b, size_t n) {
813+
return blas::cblas_ddot(static_cast<int64_t>(n), a, 1, b, 1);
801814
}
802815

803816
/// numpy.linalg.norm(x, ord=None, axis=N, keepdims=False) — N-D

numpy/einsum.h

Lines changed: 195 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@
1010
#include <algorithm>
1111
#include <stdexcept>
1212
#include <cstddef>
13+
#include <type_traits>
14+
15+
// immintrin.h provides all SSE/AVX intrinsics needed by
16+
// einsum_reduce_f32/f64 (mulps, addps, haddps, etc.).
17+
#include <immintrin.h>
1318

1419
namespace numpy {
1520
namespace einsum_detail {
@@ -66,6 +71,134 @@ inline ptrdiff_t flat_index(const std::vector<ptrdiff_t>& coord,
6671
return idx;
6772
}
6873

74+
// ============================================================================
75+
// SIMD sum-of-products reduction — matches numpy's
76+
// *_sum_of_products_contig_contig_outstride0_two exactly.
77+
//
78+
// numpy's multiarray module uses SSE baseline SIMD with SEPARATE
79+
// mul+add (NOT FMA). Despite the CPU having AVX512 and FMA3, the
80+
// einsum kernel in the baseline multiarray module is compiled for SSE.
81+
//
82+
// Key observations from disassembling numpy's .so:
83+
// - Uses xmm registers (SSE, 2-wide f64 / 4-wide f32)
84+
// - mulpd/mulps + addpd/addps (NOT vfmadd)
85+
// - haddpd/haddps for horizontal sum
86+
// - Forward accumulation chain (NOT reverse-FMA)
87+
// ============================================================================
88+
89+
/// Sum of products of two contiguous double arrays: sum(a[k] * b[k], k < n).
///
/// Uses a single 2-lane accumulator with separate multiply and add (no FMA
/// intrinsics). Blocks of 8 elements are folded in the order
/// (a3*b3), (a2*b2), (a1*b1), (a0*b0); leftover elements are processed one
/// vector at a time, the final odd element being loaded with a zero upper
/// lane. The two lanes are added together at the end.
///
/// Only SSE2 intrinsics are used, so no extra -m flags are needed on x86-64.
/// The final lane sum is lane0 + lane1, which is the same single addition a
/// horizontal add of the accumulator with itself performs.
///
/// NOTE(review): some compilers may contract the mul/add intrinsics into FMA
/// when FMA code generation is enabled — confirm the build uses
/// -ffp-contract=off (or no FMA target) if bit-exactness matters.
///
/// @param a  first operand, at least n readable elements
/// @param b  second operand, at least n readable elements
/// @param n  number of elements; 0 yields 0.0
inline double einsum_reduce_f64(const double* a, const double* b, size_t n) {
    constexpr size_t kLanes = 2;
    constexpr size_t kBlock = kLanes * 4;  // 8 elements per unrolled step

    __m128d acc = _mm_setzero_pd();
    size_t i = 0;

    // 4x unrolled body; the order of the additions below is significant.
    for (; i + kBlock <= n; i += kBlock) {
        const __m128d p3 = _mm_mul_pd(_mm_loadu_pd(a + i + 6), _mm_loadu_pd(b + i + 6));
        const __m128d p2 = _mm_mul_pd(_mm_loadu_pd(a + i + 4), _mm_loadu_pd(b + i + 4));
        const __m128d p1 = _mm_mul_pd(_mm_loadu_pd(a + i + 2), _mm_loadu_pd(b + i + 2));
        const __m128d p0 = _mm_mul_pd(_mm_loadu_pd(a + i), _mm_loadu_pd(b + i));

        acc = _mm_add_pd(acc, p3);
        acc = _mm_add_pd(acc, p2);
        const __m128d partial = _mm_add_pd(acc, p1);
        acc = _mm_add_pd(p0, partial);
    }

    // Remaining full vectors (at most three).
    for (; n - i >= kLanes; i += kLanes) {
        acc = _mm_add_pd(acc, _mm_mul_pd(_mm_loadu_pd(a + i), _mm_loadu_pd(b + i)));
    }

    // Final odd element: 64-bit load into lane 0, lane 1 zeroed.
    if (i < n) {
        const __m128d va = _mm_castsi128_pd(_mm_loadl_epi64(reinterpret_cast<const __m128i*>(a + i)));
        const __m128d vb = _mm_castsi128_pd(_mm_loadl_epi64(reinterpret_cast<const __m128i*>(b + i)));
        acc = _mm_add_pd(acc, _mm_mul_pd(va, vb));
    }

    // lane0 + lane1
    return _mm_cvtsd_f64(_mm_add_sd(acc, _mm_unpackhi_pd(acc, acc)));
}
136+
137+
/// Loads the first `count` floats (1, 2 or 3) at `p` into the low lanes of a
/// vector and zero-fills the remaining lanes. Never reads past p[count - 1].
inline __m128 einsum_load_partial_f32(const float* p, size_t count) {
    // One element: scalar load, upper three lanes zero. Reads the value as a
    // float, avoiding any int-typed access to float storage.
    if (count == 1) return _mm_load_ss(p);

    // Two elements: one 64-bit load into the low half, high half zero.
    const __m128 low_pair =
        _mm_castsi128_ps(_mm_loadl_epi64(reinterpret_cast<const __m128i*>(p)));
    if (count == 2) return low_pair;

    // Three elements: low pair plus a scalar load placed in lane 2.
    return _mm_movelh_ps(low_pair, _mm_load_ss(p + 2));
}

/// Sum of products of two contiguous float arrays: sum(a[k] * b[k], k < n).
///
/// Uses a single 4-lane accumulator with separate multiply and add (no FMA
/// intrinsics). Blocks of 16 elements are folded in the order
/// (a3*b3), (a2*b2), (a1*b1), (a0*b0); leftover elements are processed one
/// vector at a time, the last partial vector being zero-padded. The lanes are
/// reduced as (lane0 + lane1) + (lane2 + lane3), the same pairing two
/// successive horizontal adds of the accumulator with itself produce.
///
/// Only SSE/SSE2 intrinsics are used, so no extra -m flags are needed on
/// x86-64.
///
/// NOTE(review): some compilers may contract the mul/add intrinsics into FMA
/// when FMA code generation is enabled — confirm the build uses
/// -ffp-contract=off (or no FMA target) if bit-exactness matters.
///
/// @param a  first operand, at least n readable elements
/// @param b  second operand, at least n readable elements
/// @param n  number of elements; 0 yields 0.0f
inline float einsum_reduce_f32(const float* a, const float* b, size_t n) {
    constexpr size_t kLanes = 4;
    constexpr size_t kBlock = kLanes * 4;  // 16 elements per unrolled step

    __m128 acc = _mm_setzero_ps();
    size_t i = 0;

    // 4x unrolled body; the order of the additions below is significant.
    for (; i + kBlock <= n; i += kBlock) {
        const __m128 p3 = _mm_mul_ps(_mm_loadu_ps(a + i + 12), _mm_loadu_ps(b + i + 12));
        const __m128 p2 = _mm_mul_ps(_mm_loadu_ps(a + i + 8), _mm_loadu_ps(b + i + 8));
        const __m128 p1 = _mm_mul_ps(_mm_loadu_ps(a + i + 4), _mm_loadu_ps(b + i + 4));
        const __m128 p0 = _mm_mul_ps(_mm_loadu_ps(a + i), _mm_loadu_ps(b + i));

        acc = _mm_add_ps(acc, p3);
        acc = _mm_add_ps(acc, p2);
        const __m128 partial = _mm_add_ps(acc, p1);
        acc = _mm_add_ps(p0, partial);
    }

    // Remaining full vectors (at most three).
    for (; n - i >= kLanes; i += kLanes) {
        acc = _mm_add_ps(acc, _mm_mul_ps(_mm_loadu_ps(a + i), _mm_loadu_ps(b + i)));
    }

    // Final 1..3 elements, zero-padded to a full vector.
    const size_t leftover = n - i;
    if (leftover != 0) {
        const __m128 va = einsum_load_partial_f32(a + i, leftover);
        const __m128 vb = einsum_load_partial_f32(b + i, leftover);
        acc = _mm_add_ps(acc, _mm_mul_ps(va, vb));
    }

    // (lane0 + lane1) + (lane2 + lane3)
    const __m128 odd_lanes = _mm_shuffle_ps(acc, acc, _MM_SHUFFLE(3, 3, 1, 1));
    const __m128 pair_sums = _mm_add_ps(acc, odd_lanes);            // lanes 0 and 2 hold the pair sums
    const __m128 upper_pair = _mm_movehl_ps(pair_sums, pair_sums);  // lane 0 = lane2 + lane3
    return _mm_cvtss_f32(_mm_add_ss(pair_sums, upper_pair));
}
201+
69202
inline std::string implicit_output_labels(const std::vector<std::string>& il) {
70203
std::map<char, int> cnt;
71204
for (const auto& s : il) {
@@ -158,7 +291,68 @@ void einsum(const std::string& subscripts,
158291
for (char c : output_labels)
159292
output_shape.push_back(label_size[c]);
160293

161-
// Iteration space
294+
// ================================================================
295+
// Fast path: single contraction label that is the LAST axis in
296+
// BOTH operands (C-contiguous). Use BLAS dot per output element.
297+
// ================================================================
298+
if (sum_labels.size() == 1) {
299+
char clabel = sum_labels[0];
300+
ptrdiff_t csize = label_size[clabel];
301+
302+
int a_caxis = -1, b_caxis = -1;
303+
for (const auto& [inp, ax] : label_axis[clabel]) {
304+
if (inp == 0) a_caxis = ax;
305+
if (inp == 1) b_caxis = ax;
306+
}
307+
308+
if (a_caxis == ndim[0] - 1 && b_caxis == ndim[1] - 1 && csize > 0) {
309+
auto a_str = compute_strides(shapes[0]);
310+
auto b_str = compute_strides(shapes[1]);
311+
312+
// Map each output label to its axis in operands A and B.
313+
// output_label → (a_axis_or_neg, b_axis_or_neg)
314+
vector<pair<int, int>> out_label_axes;
315+
for (char c : output_labels) {
316+
int aa = -1, ba = -1;
317+
for (const auto& [inp, ax] : label_axis[c]) {
318+
if (inp == 0) aa = ax;
319+
if (inp == 1) ba = ax;
320+
}
321+
out_label_axes.push_back({aa, ba});
322+
}
323+
324+
ptrdiff_t output_total = 1;
325+
for (ptrdiff_t s : output_shape) output_total *= s;
326+
327+
vector<ptrdiff_t> out_coord(output_shape.size(), 0);
328+
329+
for (ptrdiff_t oi = 0; oi < output_total; ++oi) {
330+
ptrdiff_t rem = oi;
331+
for (int d = static_cast<int>(output_shape.size()) - 1; d >= 0; --d) {
332+
out_coord[d] = rem % output_shape[d];
333+
rem /= output_shape[d];
334+
}
335+
336+
ptrdiff_t a_off = 0, b_off = 0;
337+
for (size_t d = 0; d < out_label_axes.size(); ++d) {
338+
int aa = out_label_axes[d].first;
339+
int ba = out_label_axes[d].second;
340+
if (aa >= 0) a_off += out_coord[d] * a_str[aa];
341+
if (ba >= 0) b_off += out_coord[d] * b_str[ba];
342+
}
343+
344+
if constexpr (std::is_same_v<T, double>)
345+
result_ptr[oi] = einsum_reduce_f64(a_ptr + a_off, b_ptr + b_off, static_cast<size_t>(csize));
346+
else
347+
result_ptr[oi] = einsum_reduce_f32(a_ptr + a_off, b_ptr + b_off, static_cast<size_t>(csize));
348+
}
349+
return;
350+
}
351+
}
352+
353+
// ================================================================
354+
// Scalar path: general case.
355+
// ================================================================
162356
vector<char> iter_labels = output_labels;
163357
iter_labels.insert(iter_labels.end(), sum_labels.begin(), sum_labels.end());
164358
int n_iter = static_cast<int>(iter_labels.size());
@@ -180,7 +374,6 @@ void einsum(const std::string& subscripts,
180374
iter_input_axis[li][inp] = ax;
181375
}
182376

183-
// Compute total output size for validation
184377
vector<ptrdiff_t> iter_coord(n_iter, 0);
185378
vector<ptrdiff_t> input_coord[2] = {
186379
vector<ptrdiff_t>(ndim[0], 0),

0 commit comments

Comments
 (0)