Skip to content

Commit d9d5a55

Browse files
author
peng.li24
committed
refactor: eliminate stack/heap patterns in einsum scalar path and fast path (-55 lines)
1 parent c984454 commit d9d5a55

1 file changed

Lines changed: 20 additions & 61 deletions

File tree

numpy/einsum.h

Lines changed: 20 additions & 61 deletions
Original file line number | Diff line number | Diff line change
@@ -346,14 +346,7 @@ void einsum(const std::string& subscripts,
346346
auto b_str = compute_strides(shapes[1]);
347347
int n_out = static_cast<int>(output_labels.size());
348348

349-
// Stack for small dimension counts (common: coupling vectors 1-3 dims)
350-
#define D8 8
351-
ptrdiff_t stk_stepA[D8], stk_stepB[D8], stk_ostride[D8];
352-
vector<ptrdiff_t> heap_stepA, heap_stepB, heap_ostride;
353-
ptrdiff_t* stepA = (n_out <= D8) ? stk_stepA : (heap_stepA.resize(n_out), heap_stepA.data());
354-
ptrdiff_t* stepB = (n_out <= D8) ? stk_stepB : (heap_stepB.resize(n_out), heap_stepB.data());
355-
ptrdiff_t* ostride = (n_out <= D8) ? stk_ostride : (heap_ostride.resize(n_out), heap_ostride.data());
356-
349+
vector<ptrdiff_t> stepA(n_out), stepB(n_out), ostride(n_out);
357350
for (int d = 0; d < n_out; ++d) {
358351
char c = output_labels[d];
359352
int aa = -1, ba = -1;
@@ -388,15 +381,12 @@ void einsum(const std::string& subscripts,
388381
else
389382
result_ptr[oi] = einsum_reduce_f32(a_ptr + a_off, b_ptr + b_off, csize_u);
390383
}
391-
#undef D8
392384
return;
393385
}
394386
}
395387

396388
// ================================================================
397-
// Scalar path: general case.
398-
// - Iterative stride walk (no division per flat index)
399-
// - Stack allocation for small dimension counts
389+
// Scalar path: general case — iterative stride walk, no division.
400390
// ================================================================
401391
vector<char> iter_labels = output_labels;
402392
iter_labels.insert(iter_labels.end(), sum_labels.begin(), sum_labels.end());
@@ -407,47 +397,22 @@ void einsum(const std::string& subscripts,
407397
for (int i = 0; i < n_iter; ++i)
408398
iter_sizes[i] = label_size[iter_labels[i]];
409399

410-
auto iter_strides = compute_strides(iter_sizes);
411-
auto strides_a = compute_strides(shapes[0]);
412-
auto strides_b = compute_strides(shapes[1]);
400+
auto strides_a = compute_strides(shapes[0]);
401+
auto strides_b = compute_strides(shapes[1]);
413402

414-
// iter_input_axis: [li][inp] = axis index or -1
415-
#define MAX_DIM 16
416-
int stack_iax[MAX_DIM][2];
417-
vector<int> heap_iax;
418-
int (*iax)[2];
419-
if (n_iter <= MAX_DIM) {
420-
for (int li = 0; li < n_iter; ++li)
421-
stack_iax[li][0] = stack_iax[li][1] = -1;
422-
for (int li = 0; li < n_iter; ++li) {
423-
char l = iter_labels[li];
424-
for (const auto& [inp, ax] : label_axis[l])
425-
stack_iax[li][inp] = ax;
426-
}
427-
iax = stack_iax;
428-
} else {
429-
heap_iax.assign(n_iter * 2, -1);
430-
for (int li = 0; li < n_iter; ++li) {
431-
char l = iter_labels[li];
432-
for (const auto& [inp, ax] : label_axis[l])
433-
heap_iax[li * 2 + inp] = ax;
434-
}
435-
iax = reinterpret_cast<int(*)[2]>(heap_iax.data());
403+
// iter_input_axis: iax[li][inp] = axis index or -1
404+
vector<int> iax_flat(n_iter * 2, -1);
405+
for (int li = 0; li < n_iter; ++li) {
406+
char l = iter_labels[li];
407+
for (const auto& [inp, ax] : label_axis[l])
408+
iax_flat[li * 2 + inp] = ax;
436409
}
410+
int (*iax)[2] = reinterpret_cast<int(*)[2]>(iax_flat.data());
437411

438-
// iter_coord
439-
ptrdiff_t stack_coord[MAX_DIM] = {};
440-
vector<ptrdiff_t> heap_coord;
441-
ptrdiff_t* iter_coord = (n_iter <= MAX_DIM) ? stack_coord
442-
: (heap_coord.assign(n_iter, 0), heap_coord.data());
443-
444-
// input_coord for each operand
445-
ptrdiff_t stack_ic[2][MAX_DIM] = {};
446-
vector<ptrdiff_t> heap_ic[2];
447-
ptrdiff_t* ic[2];
448-
for (int inp = 0; inp < 2; ++inp)
449-
ic[inp] = (ndim[inp] <= MAX_DIM) ? stack_ic[inp]
450-
: (heap_ic[inp].assign(ndim[inp], 0), heap_ic[inp].data());
412+
// iter_coord and input_coord
413+
vector<ptrdiff_t> iter_coord(n_iter, 0);
414+
vector<ptrdiff_t> ic0(ndim[0], 0), ic1(ndim[1], 0);
415+
ptrdiff_t* ic[2] = {ic0.data(), ic1.data()};
451416

452417
ptrdiff_t iter_total = 1;
453418
for (ptrdiff_t s : iter_sizes) iter_total *= s;
@@ -456,13 +421,13 @@ void einsum(const std::string& subscripts,
456421
T accumulator = T(0);
457422

458423
for (ptrdiff_t flat = 0; flat < iter_total; ++flat) {
459-
// Iterative coordinate update (faster than division)
424+
// Iterative coordinate update
460425
if (flat > 0)
461426
for (int d = n_iter - 1; d >= 0; --d)
462427
if (++iter_coord[d] < iter_sizes[d]) break;
463428
else iter_coord[d] = 0;
464429

465-
// Check if this is the start of a new output element
430+
// Start of new output element?
466431
bool is_new_output = true;
467432
for (int i = n_output; i < n_iter; ++i)
468433
if (iter_coord[i] != 0) { is_new_output = false; break; }
@@ -480,14 +445,9 @@ void einsum(const std::string& subscripts,
480445
accumulator = T(0);
481446
}
482447

483-
// Reset input_coord (memset is fast for small ndim — common case)
484-
for (int inp = 0; inp < 2; ++inp)
485-
if (ndim[inp] <= MAX_DIM)
486-
memset(ic[inp], 0, static_cast<size_t>(ndim[inp]) * sizeof(ptrdiff_t));
487-
else
488-
std::fill_n(ic[inp], ndim[inp], ptrdiff_t(0));
489-
490-
// Map iter_coord → input_coord
448+
// Reset input_coord and map iter_coord → input_coord
449+
memset(ic[0], 0, static_cast<size_t>(ndim[0]) * sizeof(ptrdiff_t));
450+
memset(ic[1], 0, static_cast<size_t>(ndim[1]) * sizeof(ptrdiff_t));
491451
for (int li = 0; li < n_iter; ++li)
492452
for (int inp = 0; inp < 2; ++inp) {
493453
int ax = iax[li][inp];
@@ -497,7 +457,6 @@ void einsum(const std::string& subscripts,
497457
accumulator += a_ptr[flat_index(ic[0], strides_a.data(), ndim[0])]
498458
* b_ptr[flat_index(ic[1], strides_b.data(), ndim[1])];
499459
}
500-
#undef MAX_DIM
501460

502461
if (current_output_idx >= 0)
503462
result_ptr[current_output_idx] = accumulator;

0 commit comments

Comments
 (0)