1010#include < algorithm>
1111#include < stdexcept>
1212#include < cstddef>
13+ #include < type_traits>
14+
15+ // immintrin.h provides all SSE/AVX intrinsics needed by
16+ // einsum_reduce_f32/f64 (mulps, addps, haddps, etc.).
17+ #include < immintrin.h>
1318
1419namespace numpy {
1520namespace einsum_detail {
@@ -66,6 +71,134 @@ inline ptrdiff_t flat_index(const std::vector<ptrdiff_t>& coord,
6671 return idx;
6772}
6873
74+ // ============================================================================
75+ // SIMD sum-of-products reduction — matches numpy's
76+ // *_sum_of_products_contig_contig_outstride0_two exactly.
77+ //
78+ // numpy's multiarray module uses SSE baseline SIMD with SEPARATE
79+ // mul+add (NOT FMA). Despite the CPU having AVX512 and FMA3, the
80+ // einsum kernel in the baseline multiarray module is compiled for SSE.
81+ //
82+ // Key observations from disassembling numpy's .so:
83+ // - Uses xmm registers (SSE, 2-wide f64 / 4-wide f32)
84+ // - mulpd/mulps + addpd/addps (NOT vfmadd)
85+ // - haddpd/haddps for horizontal sum
86+ // - Forward accumulation chain (NOT reverse-FMA)
87+ // ============================================================================
88+
/// Sum of products  a[0]*b[0] + ... + a[n-1]*b[n-1]  over two contiguous
/// double arrays, evaluated with 128-bit SSE vectors (2 lanes) using a
/// separate multiply and add per step (no FMA).
///
/// The accumulation order is fixed and intentionally not "optimised":
/// per the notes at the top of this section it is meant to reproduce the
/// order of numpy's contiguous sum-of-products kernel, so reordering the
/// adds would change the rounding of the result.
///
/// Only SSE2 is required. When SSE3 is enabled the final horizontal add
/// uses haddpd; otherwise an SSE2 sequence performing the same single
/// addition (lane0 + lane1) is used, so the header also builds for a plain
/// x86-64 target.
///
/// @param a  first operand, at least n readable elements (may be unaligned)
/// @param b  second operand, at least n readable elements (may be unaligned)
/// @param n  number of elements to reduce; 0 yields 0.0
/// @return   the accumulated sum of products
inline double einsum_reduce_f64(const double* a, const double* b, size_t n) {
    constexpr size_t kLanes = 2;           // doubles per __m128d
    constexpr size_t kBlock = kLanes * 4;  // 4x unrolled main loop -> 8 elements

    __m128d v_accum = _mm_setzero_pd();
    size_t i = 0;

    // 4x unrolled main loop.
    for (; i + kBlock <= n; i += kBlock) {
        const __m128d a3 = _mm_loadu_pd(a + i + 6);
        const __m128d b3 = _mm_loadu_pd(b + i + 6);
        const __m128d a2 = _mm_loadu_pd(a + i + 4);
        const __m128d b2 = _mm_loadu_pd(b + i + 4);
        const __m128d a1 = _mm_loadu_pd(a + i + 2);
        const __m128d b1 = _mm_loadu_pd(b + i + 2);
        const __m128d a0 = _mm_loadu_pd(a + i);
        const __m128d b0 = _mm_loadu_pd(b + i);

        // Order matters for rounding:
        //   accum += a3*b3
        //   accum += a2*b2
        //   temp   = accum + a1*b1
        //   accum  = a0*b0 + temp
        v_accum = _mm_add_pd(v_accum, _mm_mul_pd(a3, b3));
        v_accum = _mm_add_pd(v_accum, _mm_mul_pd(a2, b2));
        const __m128d t1 = _mm_add_pd(v_accum, _mm_mul_pd(a1, b1));
        v_accum = _mm_add_pd(_mm_mul_pd(a0, b0), t1);
    }

    // Tail: whole vectors first ...
    for (; i + kLanes <= n; i += kLanes) {
        const __m128d va = _mm_loadu_pd(a + i);
        const __m128d vb = _mm_loadu_pd(b + i);
        v_accum = _mm_add_pd(v_accum, _mm_mul_pd(va, vb));
    }

    // ... then at most one leftover element, loaded into lane 0 with lane 1
    // zeroed so the unused lane contributes 0*0 to the accumulator.
    if (i < n) {
        const __m128d va = _mm_load_sd(a + i);
        const __m128d vb = _mm_load_sd(b + i);
        v_accum = _mm_add_pd(v_accum, _mm_mul_pd(va, vb));
    }

    // Horizontal sum: lane0 + lane1.
#if defined(__SSE3__) || defined(__AVX__)
    return _mm_cvtsd_f64(_mm_hadd_pd(v_accum, v_accum));
#else
    // SSE2-only build: same single addition without haddpd.
    return _mm_cvtsd_f64(_mm_add_sd(v_accum, _mm_unpackhi_pd(v_accum, v_accum)));
#endif
}
136+
/// Sum of products  a[0]*b[0] + ... + a[n-1]*b[n-1]  over two contiguous
/// float arrays, evaluated with 128-bit SSE vectors (4 lanes) using a
/// separate multiply and add per step (no FMA).
///
/// The accumulation order is fixed and intentionally not "optimised":
/// per the notes at the top of this section it is meant to reproduce the
/// order of numpy's contiguous sum-of-products kernel, so reordering the
/// adds would change the rounding of the result.
///
/// Only SSE2 is required:
///  - the 3-element tail is assembled with movq + movss + punpcklqdq
///    instead of the SSE4.1 pinsrd, producing the same vector [p0,p1,p2,0];
///  - partial loads go through float/vector intrinsics only, so no float is
///    read through an `int*` (which would violate strict aliasing);
///  - the horizontal sum uses haddps when SSE3 is enabled and otherwise an
///    SSE2 sequence that keeps the identical (v0+v1)+(v2+v3) order.
///
/// @param a  first operand, at least n readable elements (may be unaligned)
/// @param b  second operand, at least n readable elements (may be unaligned)
/// @param n  number of elements to reduce; 0 yields 0.0f
/// @return   the accumulated sum of products
inline float einsum_reduce_f32(const float* a, const float* b, size_t n) {
    constexpr size_t kLanes = 4;           // floats per __m128
    constexpr size_t kBlock = kLanes * 4;  // 4x unrolled main loop -> 16 elements

    // Loads `count` (1..3) floats from p into the low lanes and zero-fills
    // the rest; never touches memory past p[count-1].
    const auto load_partial = [](const float* p, size_t count) -> __m128 {
        if (count == 1) {
            return _mm_load_ss(p);  // [p0, 0, 0, 0]
        }
        const __m128i lo = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(p));
        if (count == 2) {
            return _mm_castsi128_ps(lo);  // [p0, p1, 0, 0]
        }
        // count == 3: [p0, p1] ++ [p2, 0]
        const __m128i hi = _mm_castps_si128(_mm_load_ss(p + 2));
        return _mm_castsi128_ps(_mm_unpacklo_epi64(lo, hi));
    };

    __m128 v_accum = _mm_setzero_ps();
    size_t i = 0;

    // 4x unrolled main loop.
    for (; i + kBlock <= n; i += kBlock) {
        const __m128 a3 = _mm_loadu_ps(a + i + 12);
        const __m128 b3 = _mm_loadu_ps(b + i + 12);
        const __m128 a2 = _mm_loadu_ps(a + i + 8);
        const __m128 b2 = _mm_loadu_ps(b + i + 8);
        const __m128 a1 = _mm_loadu_ps(a + i + 4);
        const __m128 b1 = _mm_loadu_ps(b + i + 4);
        const __m128 a0 = _mm_loadu_ps(a + i);
        const __m128 b0 = _mm_loadu_ps(b + i);

        // Order matters for rounding:
        //   accum += a3*b3
        //   accum += a2*b2
        //   temp   = accum + a1*b1
        //   accum  = a0*b0 + temp
        v_accum = _mm_add_ps(v_accum, _mm_mul_ps(a3, b3));
        v_accum = _mm_add_ps(v_accum, _mm_mul_ps(a2, b2));
        const __m128 t1 = _mm_add_ps(v_accum, _mm_mul_ps(a1, b1));
        v_accum = _mm_add_ps(_mm_mul_ps(a0, b0), t1);
    }

    // Tail: whole vectors first ...
    for (; i + kLanes <= n; i += kLanes) {
        const __m128 va = _mm_loadu_ps(a + i);
        const __m128 vb = _mm_loadu_ps(b + i);
        v_accum = _mm_add_ps(v_accum, _mm_mul_ps(va, vb));
    }

    // ... then one zero-padded partial vector for the last 1..3 elements.
    if (i < n) {
        const size_t left = n - i;
        const __m128 va = load_partial(a + i, left);
        const __m128 vb = load_partial(b + i, left);
        v_accum = _mm_add_ps(v_accum, _mm_mul_ps(va, vb));
    }

    // Horizontal sum in the order (v0 + v1) + (v2 + v3).
#if defined(__SSE3__) || defined(__AVX__)
    const __m128 sum_halves = _mm_hadd_ps(v_accum, v_accum);
    return _mm_cvtss_f32(_mm_hadd_ps(sum_halves, sum_halves));
#else
    // SSE2-only build: pairwise add neighbours, then add the two pair sums.
    const __m128 swapped = _mm_shuffle_ps(v_accum, v_accum, _MM_SHUFFLE(2, 3, 0, 1));
    const __m128 pairs = _mm_add_ps(v_accum, swapped);      // [v0+v1, ., v2+v3, .]
    const __m128 upper = _mm_movehl_ps(pairs, pairs);       // lane0 = v2+v3
    return _mm_cvtss_f32(_mm_add_ss(pairs, upper));
#endif
}
201+
69202inline std::string implicit_output_labels (const std::vector<std::string>& il) {
70203 std::map<char , int > cnt;
71204 for (const auto & s : il) {
@@ -158,7 +291,68 @@ void einsum(const std::string& subscripts,
158291 for (char c : output_labels)
159292 output_shape.push_back (label_size[c]);
160293
161- // Iteration space
294+ // ================================================================
295+ // Fast path: single contraction label that is the LAST axis in
296+ // BOTH operands (C-contiguous). Use einsum_reduce_f32/f64 per output element.
297+ // ================================================================
298+ if (sum_labels.size () == 1 ) {
299+ char clabel = sum_labels[0 ];
300+ ptrdiff_t csize = label_size[clabel];
301+
302+ int a_caxis = -1 , b_caxis = -1 ;
303+ for (const auto & [inp, ax] : label_axis[clabel]) {
304+ if (inp == 0 ) a_caxis = ax;
305+ if (inp == 1 ) b_caxis = ax;
306+ }
307+
308+ if (a_caxis == ndim[0 ] - 1 && b_caxis == ndim[1 ] - 1 && csize > 0 ) {
309+ auto a_str = compute_strides (shapes[0 ]);
310+ auto b_str = compute_strides (shapes[1 ]);
311+
312+ // Map each output label to its axis in operands A and B.
313+ // output_label → (a_axis_or_neg, b_axis_or_neg)
314+ vector<pair<int , int >> out_label_axes;
315+ for (char c : output_labels) {
316+ int aa = -1 , ba = -1 ;
317+ for (const auto & [inp, ax] : label_axis[c]) {
318+ if (inp == 0 ) aa = ax;
319+ if (inp == 1 ) ba = ax;
320+ }
321+ out_label_axes.push_back ({aa, ba});
322+ }
323+
324+ ptrdiff_t output_total = 1 ;
325+ for (ptrdiff_t s : output_shape) output_total *= s;
326+
327+ vector<ptrdiff_t > out_coord (output_shape.size (), 0 );
328+
329+ for (ptrdiff_t oi = 0 ; oi < output_total; ++oi) {
330+ ptrdiff_t rem = oi;
331+ for (int d = static_cast <int >(output_shape.size ()) - 1 ; d >= 0 ; --d) {
332+ out_coord[d] = rem % output_shape[d];
333+ rem /= output_shape[d];
334+ }
335+
336+ ptrdiff_t a_off = 0 , b_off = 0 ;
337+ for (size_t d = 0 ; d < out_label_axes.size (); ++d) {
338+ int aa = out_label_axes[d].first ;
339+ int ba = out_label_axes[d].second ;
340+ if (aa >= 0 ) a_off += out_coord[d] * a_str[aa];
341+ if (ba >= 0 ) b_off += out_coord[d] * b_str[ba];
342+ }
343+
344+ if constexpr (std::is_same_v<T, double >)
345+ result_ptr[oi] = einsum_reduce_f64 (a_ptr + a_off, b_ptr + b_off, static_cast <size_t >(csize));
346+ else
347+ result_ptr[oi] = einsum_reduce_f32 (a_ptr + a_off, b_ptr + b_off, static_cast <size_t >(csize));
348+ }
349+ return ;
350+ }
351+ }
352+
353+ // ================================================================
354+ // Scalar path: general case.
355+ // ================================================================
162356 vector<char > iter_labels = output_labels;
163357 iter_labels.insert (iter_labels.end (), sum_labels.begin (), sum_labels.end ());
164358 int n_iter = static_cast <int >(iter_labels.size ());
@@ -180,7 +374,6 @@ void einsum(const std::string& subscripts,
180374 iter_input_axis[li][inp] = ax;
181375 }
182376
183- // Compute total output size for validation
184377 vector<ptrdiff_t > iter_coord (n_iter, 0 );
185378 vector<ptrdiff_t > input_coord[2 ] = {
186379 vector<ptrdiff_t >(ndim[0 ], 0 ),
0 commit comments