perf(bb/msm): Jacobian (S, W) tree bucket reduction #23526

Draft
AztecBot wants to merge 7 commits into zw/msm-webgpu-experiments-v2 from cb/cc66f83b5de3

@AztecBot
Collaborator

Motivation

The previous bucket-reduction stage runs one workgroup per window with
batch-affine Montgomery-trick adds. At small c (e.g. c=8) each
window has only 2^(c-1)=128 buckets to drain and the workgroup is
mostly idle, while batch-affine forces each thread to gather S=8
points before any work happens.

This PR replaces that stage with a fully parallel Jacobian tree
reduction that treats all NW * 2^(c-1) weighted buckets as one block
of work and dispatches every merge in parallel. Each thread does a
small constant-time merge with no batched inversion and no
S-point gather, so GPU saturation no longer depends on per-window
fan-out.

Algorithm

For each window w, each tree node summarises a contiguous range of
buckets as a pair (S, W):

  • S = Σ B[k] over the range,
  • W = Σ (pos · B[k]) with pos = 1..h relative to the range start.

Merging two adjacent (S_L, W_L) and (S_R, W_R) of size h each into
one (S, W) of size 2h:

S    = S_L + S_R                       (1 group add)
hS_R = double S_R, log2(h) times       (log2(h) Jacobian doublings)
W    = W_L + hS_R + W_R                (2 group adds)

At the root, W is L_w = Σ k · B_w[k] for k = 1..N, where N = 2^(c-1) is
the number of buckets per window. Per window the total is ≈ 5N/2 adds +
≈ N doublings, with fewer rounds than the old 4-phase reduction and full
inter-window parallelism throughout.
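
The merge rule can be checked in a few lines. This is a minimal sketch, not the PR's jbr_reference.test.mjs: it assumes plain integers under addition as a stand-in for the curve group, and the names mergeSW, treeReduce and naiveWeightedSum are hypothetical.

```typescript
// Stand-in group: plain integers under addition (an assumption for
// illustration; the real kernels operate on Jacobian curve points).
type SW = { S: number; W: number };

const groupAdd = (a: number, b: number): number => a + b;
const groupDouble = (a: number): number => a + a;

// Merge two adjacent ranges of size h (a power of two) into one of size 2h.
function mergeSW(left: SW, right: SW, h: number): SW {
  const S = groupAdd(left.S, right.S); // 1 group add
  let hSR = right.S;
  for (let step = 1; step < h; step *= 2) {
    hSR = groupDouble(hSR); // runs log2(h) times
  }
  const W = groupAdd(groupAdd(left.W, hSR), right.W); // 2 group adds
  return { S, W };
}

// Reduce a power-of-two-length bucket array to its root (S, W).
function treeReduce(buckets: number[]): SW {
  // Leaf: a single bucket at relative position 1, so S = W = B[k].
  let nodes: SW[] = buckets.map((b) => ({ S: b, W: b }));
  for (let h = 1; nodes.length > 1; h *= 2) {
    const next: SW[] = [];
    for (let i = 0; i < nodes.length; i += 2) {
      next.push(mergeSW(nodes[i], nodes[i + 1], h));
    }
    nodes = next;
  }
  return nodes[0];
}

// Naive reference: sum of k * B[k] for k = 1..N.
function naiveWeightedSum(buckets: number[]): number {
  return buckets.reduce((acc, b, i) => acc + (i + 1) * b, 0);
}
```

For the buckets [3, 1, 4, 1, 5, 9, 2, 6], treeReduce(...).W and naiveWeightedSum(...) both come out to 162, and the root S is the plain sum 31.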

Round 0 (AA -> J, leaf) and rounds 1..c-2 (JJ -> J) are
implemented as separate shaders, because the first level can use the
cheaper mmadd/madd formulas (Z = 1 on both inputs) and skip Jacobian
loads entirely. As required by the design, neither shader checks the
x1 == x2 case: SRS bases are independent random generators, so
coordinate collisions in the reduction are vanishingly improbable.

Layout

  • (S, W) pairs live in a 6-plane SoA buffer (S.X, S.Y, S.Z, W.X, W.Y, W.Z).
  • Two ping-pong buffers cover all rounds; the final round writes to a
    dedicated jbrFinalBuf whose plane stride equals NUM_WINDOWS, so
    the per-window L_w gather is a contiguous 3-plane copy into
    redStaging.
  • The host receives NW Jacobian points and Horner-combines them in
    Jacobian (one inversion at the very end). The C++ bridge consumer
    still sees per-window affine sums via windowSums (one inversion
    per window) for combineOnHost = false.
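
To illustrate why a plane stride of NUM_WINDOWS makes the L_w gather contiguous, here is a small indexing sketch. It assumes 8 u32 limbs per field element, planes stored back to back in the order listed above, and element-major storage inside each plane; planeWordOffset and finalWRange are hypothetical helpers, not functions from msm_v2.ts.

```typescript
// Assumed constants: 8 u32 limbs per field element, 6 planes per (S, W) node.
const LIMBS_PER_FIELD = 8;
const PLANES = ["S.X", "S.Y", "S.Z", "W.X", "W.Y", "W.Z"] as const;

// Offset (in u32 words) of element `index` inside plane `plane`, assuming
// each plane holds `stride` field elements stored back to back.
function planeWordOffset(plane: number, index: number, stride: number): number {
  return (plane * stride + index) * LIMBS_PER_FIELD;
}

// Word range covering W.X, W.Y and W.Z for every window when the plane
// stride equals the window count: the three W planes are adjacent.
function finalWRange(numWindows: number): { start: number; length: number } {
  const start = planeWordOffset(3, 0, numWindows);
  const end = planeWordOffset(PLANES.length, 0, numWindows);
  return { start, length: end - start };
}
```

Under these assumptions the three W planes form one unbroken word range, so a single buffer-to-buffer copy of that range would pick up every window's L_w.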

Files

  • wgsl/cuzk/jbr_aa_to_jj.template.wgsl — round 0 (AA -> J leaf merge)
  • wgsl/cuzk/jbr_jj_to_jj.template.wgsl — rounds 1..c-2 (JJ -> J merge)
  • cuzk/shader_manager.ts — two new gen_jbr_* shader builders
  • msm_v2.ts — replaces reduceInit + reduceLevel* dispatch with
    the AA->J + (c-2)×JJ->J schedule, rewires the L_w gather to the
    Jacobian buffer, and converts hostWindowCombine to take Jacobian
    inputs (the existing Horner already ran in Jacobian internally).
  • jbr_reference.test.mjs — standalone bigint reference that verifies
    the (S, W) merge formula matches Σ k · B[k] for c = 2..10.
  • dev/msm-webgpu/main.ts + scripts/run-browserstack.mjs — adds the
    ?autorun=msm-gpu-noble mode (and an --autorun flag on the
    driver) so a WebGPU-only correctness check against noble can run on
    BrowserStack without paying for the bb.js WASM bootstrap.

Test plan

  • node src/msm_webgpu/jbr_reference.test.mjs passes for c=2..10
    (bigint reference matches a naive Σ k · B[k] MSM).
  • Browser correctness via ?autorun=msm-gpu-noble&logn=14 on a
    BrowserStack macOS Sequoia / Chrome target (M2 Mini).
  • Bench ?autorun=msm-cross-check once a slot frees up to measure
    c=8 and c=13 against the existing batch-affine reduction.

Created by claudebox · group: slackbot

@AztecBot added the claudebox label ("Owned by claudebox. it can push to this PR.") May 23, 2026
AztecBot added 6 commits May 23, 2026 10:10
bucket_result stores (0,0) for empty buckets; the JBR formulas were
feeding those non-points into mmadd/madd/add-2007-bl and propagating
garbage through the tree. Adds per-node `meta = is_present | unitp<<1`:
empty leaves fall through case (0,0); at round 0, cases (1,0) and (0,1)
lift the single bucket to Jacobian with unitp = 1 and unitp = 2
respectively.
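
The meta word described above could be packed and unpacked roughly as follows. This is a sketch under assumptions: packMeta and unpackMeta are hypothetical names, and unitp = 0 is assumed to mean "not a single-bucket subtree", which the commit message does not spell out.

```typescript
// meta = is_present | unitp << 1, kept as an unsigned 32-bit value.
function packMeta(isPresent: boolean, unitp: number): number {
  return ((isPresent ? 1 : 0) | (unitp << 1)) >>> 0;
}

// Inverse of packMeta: bit 0 is the presence flag, the rest is unitp.
function unpackMeta(meta: number): { isPresent: boolean; unitp: number } {
  return { isPresent: (meta & 1) === 1, unitp: meta >>> 1 };
}
```

With this encoding an empty node is simply meta = 0, and a round-0 node holding only its right-hand bucket is packMeta(true, 2), which equals 5.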

Round r >= 1 case (0,1) needed extra care — when the right child is a
single-bucket subtree with R.unitp == h (= 2^r), `h*R.S` and `R.W` are
both the Jacobian form of the SAME 2^r·B[k] and the standard jacAdd
hits the doubling case. Detect via meta and shortcut to jacDouble(R.W);
the standard formula remains safe everywhere else (multi-bucket R mixes
distinct generators; unit R with R.unitp != h gives distinct group
elements).
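
The case split can be sketched on a stand-in group to show where the doubling shortcut is needed. The code below is illustrative only: integers under addition replace curve points, strictAdd throws on equal operands to mimic an add formula with no x1 == x2 check, unitp = 0 is assumed to mark multi-bucket nodes, and all names are hypothetical. Because the stand-in is so small, unrelated sums can coincide by accident, so the bucket values must be chosen to avoid that.

```typescript
type JNode = { present: boolean; unitp: number; S: number; W: number };
const EMPTY: JNode = { present: false, unitp: 0, S: 0, W: 0 };

// Stand-in for the incomplete Jacobian add: refuses equal operands.
function strictAdd(a: number, b: number): number {
  if (a === b) throw new Error("doubling case hit in incomplete add");
  return a + b;
}
const dbl = (a: number): number => a + a;

// h * s via log2(h) doublings (h is a power of two).
function timesH(s: number, h: number): number {
  let out = s;
  for (let step = 1; step < h; step *= 2) out = dbl(out);
  return out;
}

function mergeWithMeta(L: JNode, R: JNode, h: number): JNode {
  if (!L.present && !R.present) return EMPTY; // case (0,0)
  if (L.present && !R.present) return L; // case (1,0): positions unchanged
  if (!L.present) {
    // case (0,1): W = h*R.S + R.W. If R is a single bucket at position h,
    // both terms are the same element, so double instead of adding.
    const W = R.unitp === h ? dbl(R.W) : strictAdd(timesH(R.S, h), R.W);
    const unitp = R.unitp === 0 ? 0 : R.unitp + h;
    return { present: true, unitp, S: R.S, W };
  }
  // case (1,1): adding W_L first keeps the operands distinct.
  const S = strictAdd(L.S, R.S);
  const wTmp = strictAdd(L.W, timesH(R.S, h));
  return { present: true, unitp: 0, S, W: strictAdd(wTmp, R.W) };
}

// Reduce a power-of-two-length array; null marks an empty bucket.
function reduceSparse(buckets: (number | null)[]): JNode {
  let nodes: JNode[] = buckets.map((b) =>
    b === null ? EMPTY : { present: true, unitp: 1, S: b, W: b });
  for (let h = 1; nodes.length > 1; h *= 2) {
    const next: JNode[] = [];
    for (let i = 0; i < nodes.length; i += 2) {
      next.push(mergeWithMeta(nodes[i], nodes[i + 1], h));
    }
    nodes = next;
  }
  return nodes[0];
}
```

With a lone bucket at position 4 of 8, the shortcut fires in two consecutive rounds (unitp 1 -> 2 -> 4), which is the kind of collision chain the sparse-mid case is described as exercising.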

Host-side hostWindowCombine now uses jacAddSafe / jacDoubleSafe + a
JAC_INF sentinel so all-empty windows propagate as point-at-infinity
through the Horner.
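
A sentinel-aware Horner combine might look like the sketch below. It is not the PR's hostWindowCombine: it assumes the usual window combine Σ 2^(c·w) · L_w, uses integers as a stand-in group with null playing the role of JAC_INF, and the names addSafe, doubleSafe and hornerCombine are hypothetical.

```typescript
type Pt = number | null; // null stands in for the JAC_INF sentinel
const INF: Pt = null;

// Infinity is the identity. A real jacAddSafe would also have to route
// equal operands to doubling; the integer stand-in does not need that.
function addSafe(a: Pt, b: Pt): Pt {
  if (a === null) return b;
  if (b === null) return a;
  return a + b;
}

function doubleSafe(a: Pt): Pt {
  return a === null ? null : a + a;
}

// Horner from the top window down: acc = 2^c * acc + L_w.
function hornerCombine(windowSums: Pt[], c: number): Pt {
  let acc: Pt = INF;
  for (let w = windowSums.length - 1; w >= 0; w--) {
    for (let i = 0; i < c; i++) acc = doubleSafe(acc);
    acc = addSafe(acc, windowSums[w]);
  }
  return acc;
}
```

An all-empty window contributes nothing here, and if every window is empty the result stays at the sentinel instead of turning into a bogus point.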

`jbr_reference.test.mjs` extended with empty-bucket cases
(top-empty, first-empty, sparse-top, sparse-mid, alternating,
all-empty). sparse-mid exercises the unitp==h collision chain that
the BrowserStack run originally exposed.

#2 from the audit: manually inline jac_add / jac_double in the AA->J and
JJ->J shaders, break the case-(1,1) hot path into scoped stages, and
defer loads to first use. WGSL doesn't guarantee that function calls
with array<u32,8> by-value parameters inline cleanly; the previous
shader passed 6 such arrays per jac_add call and made three jac_add calls
per merge, easily exceeding the per-thread vector-register budget on
Adreno (and squeezing it on M2).

The JJ hot path now has four scoped stages — S, doublings, W_tmp, W_new
— with sl*/sr*/wl*/wr* loaded inside the stage that consumes them.
Stage outputs that bridge stages (dx/dy/dz, wtx/wty/wtz) are outer-scope
vars rewritten in place; stage-internal intermediates fall out of the
live set at the closing brace.

#1 from the audit: pickReduceWg is now a flat 128 regardless of c. The
old c-tiered table (32/64/128) was tuned for the batch-affine kernel
where workgroup size capped at 2^(c-1); the new flat-tree dispatch
doesn't have that constraint, and 128-thread WGs occupy a core fully
without leaving simdgroups idle on late, sparse rounds.
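
As a rough illustration of what a flat workgroup size means for the per-round dispatch, the sketch below assumes one thread per merge and NW · 2^(c-2-r) merges in round r. REDUCE_WG_SIZE and workgroupsForRound are hypothetical names, and the actual dispatch code in msm_v2.ts may differ.

```typescript
const REDUCE_WG_SIZE = 128; // flat workgroup size, independent of c

// Workgroups needed for round `round` (0 .. c-2), assuming one thread per
// merge and numWindows * 2^(c-2-round) merges in that round.
function workgroupsForRound(numWindows: number, c: number, round: number): number {
  const mergesPerWindow = 1 << (c - 2 - round);
  return Math.ceil((numWindows * mergesPerWindow) / REDUCE_WG_SIZE);
}
```

For example, with a hypothetical 32 windows at c = 8, round 0 has 2048 merges and needs 16 workgroups, while the last round has 32 merges and still fits in one.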

Reference test (jbr_reference.test.mjs) unchanged and still passes.
GPU correctness was last verified on Apple M2; an S25 (Adreno 750)
bench follows.

Also brings the autorun=msm-gpu-bench page mode in dev/msm-webgpu so
the next sweep can be driven from a single BrowserStack URL.

…iler limit

The previous commit inlined the Jacobian formulas straight into the AA and
JJ merge kernels. On the Apple M2 this compiles fine; on a Samsung
Galaxy S25 (Adreno 750 / Snapdragon 8 Elite) the Vulkan driver returns
VK_ERROR_UNKNOWN from CreateComputePipelines — the post-mustache shader
exceeds Adreno's per-kernel size/code-cache budget. Reverting both
shaders to the function-call form; pickReduceWg flat 128 stays.

S25 with the function-call form runs the bench end-to-end at logN=14, c=8:
wall_min 21.4 ms / wall_mean 29.3 ms (thermal throttling visible across
the 20-sample run). redLevel = 2.08 ms (min), redInit = 0.07 ms (min).

Will attack register pressure via the field-element representation
instead — vec4<u32>×2 in place of array<u32,8> doesn't grow the shader
text, and it gives the compiler 2 vector-register slots per field
instead of 8 scalar slots.

…loads

Two pieces aimed at reducing per-thread register pressure and the
per-round kernel-dispatch overhead:

1. jbr_window_coop kernel. For c ≤ 8 the per-window (S, W) tree needs
   N/2 ≤ 64 nodes × 196 B per node = 12,544 B ≈ 12.25 KiB, which fits
   inside a workgroup's threadgroup memory (under the 16 KiB WebGPU spec
   minimum). One
   workgroup owns one window: thread tid does the AA→J pair-merge,
   workgroupBarrier, then each subsequent JJ→J sub-round halves the
   active thread count with another barrier in between. Replaces the
   ~7 individual dispatches the c=8 redLevel previously needed with a
   single dispatch (workgroup_size = N/2, dispatch_count = NW).
   MsmV2.create gates on c ≤ 8; larger c stays on the multi-dispatch
   JJ path, because the cooperative kernel would breach the 16 KiB TG
   budget there (c = 9 already needs 128 × 196 B = 24.5 KiB).

2. JJ shader — defer wl* / wr* loads to the scoped block that
   consumes them. With every load up-front the wl* and wr* fields
   stayed live throughout the S-stage jac_add, inflating the
   per-thread peak live-set. They are now loaded inside a `{}` block so
   the compiler can drop the temporaries at scope close. (Inlining failed
   Adreno's shader compiler on the last attempt, so the shader stays in
   function-call form here.)
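
The threadgroup-memory gate from item 1 comes down to simple arithmetic. The sketch below uses the 196 B node size and 16 KiB limit quoted above; coopNodes, coopBytes and fitsCoop are hypothetical names rather than the actual check in MsmV2.create.

```typescript
const NODE_BYTES = 196; // per-node size quoted in the commit message
const WORKGROUP_STORAGE_LIMIT = 16384; // 16 KiB

// Nodes held in workgroup memory: N/2 with N = 2^(c-1) buckets per window.
function coopNodes(c: number): number {
  return 1 << (c - 2);
}

function coopBytes(c: number): number {
  return coopNodes(c) * NODE_BYTES;
}

// True when one window's tree fits in a single workgroup's storage.
function fitsCoop(c: number): boolean {
  return coopBytes(c) <= WORKGROUP_STORAGE_LIMIT;
}
```

At c = 8 this gives 64 nodes and 12,544 bytes, which fits; at c = 9 the footprint doubles to 25,088 bytes and no longer does.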

Both changes preserve the existing case-split for empty buckets and
single-bucket subtree positions; the jbr_reference.test.mjs cases
(sparse-top/mid, alternating, all-empty) still pass.

Apple M2 (Chrome 148) noble check at logN=14 still matches.
Samsung S25 (Adreno 750) bench pending.

The S25 run flagged that 'active' is a reserved word in current WGSL.
Trivial rename of the loop counter in jbr_window_coop.

S25 bench (logN=14, c=8) ran the WG-coop kernel and hit wall_min 293 ms
(redInit alone 257 ms) vs 21.5 ms baseline. The combined single-dispatch
shader with 6 sub-rounds + 6 case-splits compiles cleanly but Adreno's
register allocator clearly spills almost everything to global memory.

Keep the shader + plumbing; just gate the host pipeline compilation off
(). The multi-dispatch JJ path remains the active code
path for every c. The deferred wl* / wr* loads from the previous commit
benefit that path directly and weren't being exercised because c=8 was
taking the WG-coop branch.