@@ -346,14 +346,7 @@ void einsum(const std::string& subscripts,
346346 auto b_str = compute_strides (shapes[1 ]);
347347 int n_out = static_cast <int >(output_labels.size ());
348348
349- // Stack for small dimension counts (common: coupling vectors 1-3 dims)
350- #define D8 8
351- ptrdiff_t stk_stepA[D8 ], stk_stepB[D8 ], stk_ostride[D8 ];
352- vector<ptrdiff_t > heap_stepA, heap_stepB, heap_ostride;
353- ptrdiff_t * stepA = (n_out <= D8 ) ? stk_stepA : (heap_stepA.resize (n_out), heap_stepA.data ());
354- ptrdiff_t * stepB = (n_out <= D8 ) ? stk_stepB : (heap_stepB.resize (n_out), heap_stepB.data ());
355- ptrdiff_t * ostride = (n_out <= D8 ) ? stk_ostride : (heap_ostride.resize (n_out), heap_ostride.data ());
356-
349+ vector<ptrdiff_t > stepA (n_out), stepB (n_out), ostride (n_out);
357350 for (int d = 0 ; d < n_out; ++d) {
358351 char c = output_labels[d];
359352 int aa = -1 , ba = -1 ;
@@ -388,15 +381,12 @@ void einsum(const std::string& subscripts,
388381 else
389382 result_ptr[oi] = einsum_reduce_f32 (a_ptr + a_off, b_ptr + b_off, csize_u);
390383 }
391- #undef D8
392384 return ;
393385 }
394386 }
395387
396388 // ================================================================
397- // Scalar path: general case.
398- // - Iterative stride walk (no division per flat index)
399- // - Stack allocation for small dimension counts
389+ // Scalar path: general case — iterative stride walk, no division.
400390 // ================================================================
401391 vector<char > iter_labels = output_labels;
402392 iter_labels.insert (iter_labels.end (), sum_labels.begin (), sum_labels.end ());
@@ -407,47 +397,22 @@ void einsum(const std::string& subscripts,
407397 for (int i = 0 ; i < n_iter; ++i)
408398 iter_sizes[i] = label_size[iter_labels[i]];
409399
410- auto iter_strides = compute_strides (iter_sizes);
411- auto strides_a = compute_strides (shapes[0 ]);
412- auto strides_b = compute_strides (shapes[1 ]);
400+ auto strides_a = compute_strides (shapes[0 ]);
401+ auto strides_b = compute_strides (shapes[1 ]);
413402
414- // iter_input_axis: [li][inp] = axis index or -1
415- #define MAX_DIM 16
416- int stack_iax[MAX_DIM ][2 ];
417- vector<int > heap_iax;
418- int (*iax)[2 ];
419- if (n_iter <= MAX_DIM ) {
420- for (int li = 0 ; li < n_iter; ++li)
421- stack_iax[li][0 ] = stack_iax[li][1 ] = -1 ;
422- for (int li = 0 ; li < n_iter; ++li) {
423- char l = iter_labels[li];
424- for (const auto & [inp, ax] : label_axis[l])
425- stack_iax[li][inp] = ax;
426- }
427- iax = stack_iax;
428- } else {
429- heap_iax.assign (n_iter * 2 , -1 );
430- for (int li = 0 ; li < n_iter; ++li) {
431- char l = iter_labels[li];
432- for (const auto & [inp, ax] : label_axis[l])
433- heap_iax[li * 2 + inp] = ax;
434- }
435- iax = reinterpret_cast <int (*)[2 ]>(heap_iax.data ());
403+ // iter_input_axis: iax[li][inp] = axis index or -1
404+ vector<int > iax_flat (n_iter * 2 , -1 );
405+ for (int li = 0 ; li < n_iter; ++li) {
406+ char l = iter_labels[li];
407+ for (const auto & [inp, ax] : label_axis[l])
408+ iax_flat[li * 2 + inp] = ax;
436409 }
410+ int (*iax)[2 ] = reinterpret_cast <int (*)[2 ]>(iax_flat.data ());
437411
438- // iter_coord
439- ptrdiff_t stack_coord[MAX_DIM ] = {};
440- vector<ptrdiff_t > heap_coord;
441- ptrdiff_t * iter_coord = (n_iter <= MAX_DIM ) ? stack_coord
442- : (heap_coord.assign (n_iter, 0 ), heap_coord.data ());
443-
444- // input_coord for each operand
445- ptrdiff_t stack_ic[2 ][MAX_DIM ] = {};
446- vector<ptrdiff_t > heap_ic[2 ];
447- ptrdiff_t * ic[2 ];
448- for (int inp = 0 ; inp < 2 ; ++inp)
449- ic[inp] = (ndim[inp] <= MAX_DIM ) ? stack_ic[inp]
450- : (heap_ic[inp].assign (ndim[inp], 0 ), heap_ic[inp].data ());
412+ // iter_coord and input_coord
413+ vector<ptrdiff_t > iter_coord (n_iter, 0 );
414+ vector<ptrdiff_t > ic0 (ndim[0 ], 0 ), ic1 (ndim[1 ], 0 );
415+ ptrdiff_t * ic[2 ] = {ic0.data (), ic1.data ()};
451416
452417 ptrdiff_t iter_total = 1 ;
453418 for (ptrdiff_t s : iter_sizes) iter_total *= s;
@@ -456,13 +421,13 @@ void einsum(const std::string& subscripts,
456421 T accumulator = T (0 );
457422
458423 for (ptrdiff_t flat = 0 ; flat < iter_total; ++flat) {
459- // Iterative coordinate update (faster than division)
424+ // Iterative coordinate update
460425 if (flat > 0 )
461426 for (int d = n_iter - 1 ; d >= 0 ; --d)
462427 if (++iter_coord[d] < iter_sizes[d]) break ;
463428 else iter_coord[d] = 0 ;
464429
465- // Check if this is the start of a new output element
430+ // Start of new output element?
466431 bool is_new_output = true ;
467432 for (int i = n_output; i < n_iter; ++i)
468433 if (iter_coord[i] != 0 ) { is_new_output = false ; break ; }
@@ -480,14 +445,9 @@ void einsum(const std::string& subscripts,
480445 accumulator = T (0 );
481446 }
482447
483- // Reset input_coord (memset is fast for small ndim — common case)
484- for (int inp = 0 ; inp < 2 ; ++inp)
485- if (ndim[inp] <= MAX_DIM )
486- memset (ic[inp], 0 , static_cast <size_t >(ndim[inp]) * sizeof (ptrdiff_t ));
487- else
488- std::fill_n (ic[inp], ndim[inp], ptrdiff_t (0 ));
489-
490- // Map iter_coord → input_coord
448+ // Reset input_coord and map iter_coord → input_coord
449+ memset (ic[0 ], 0 , static_cast <size_t >(ndim[0 ]) * sizeof (ptrdiff_t ));
450+ memset (ic[1 ], 0 , static_cast <size_t >(ndim[1 ]) * sizeof (ptrdiff_t ));
491451 for (int li = 0 ; li < n_iter; ++li)
492452 for (int inp = 0 ; inp < 2 ; ++inp) {
493453 int ax = iax[li][inp];
@@ -497,7 +457,6 @@ void einsum(const std::string& subscripts,
497457 accumulator += a_ptr[flat_index (ic[0 ], strides_a.data (), ndim[0 ])]
498458 * b_ptr[flat_index (ic[1 ], strides_b.data (), ndim[1 ])];
499459 }
500- #undef MAX_DIM
501460
502461 if (current_output_idx >= 0 )
503462 result_ptr[current_output_idx] = accumulator;
0 commit comments