From 9e36d793d30af4e72e0384cd7429e40e04996b05 Mon Sep 17 00:00:00 2001 From: David Ndungu Date: Wed, 17 Jun 2026 00:24:11 -0700 Subject: [PATCH 1/2] feat(compute): on-device bf16 broadcast + scalar ops (capture-safe) bf16 broadcast binary ops (Add/Sub/Mul/Div with mismatched shapes) and scalar ops (AddScalar/MulScalar/DivScalar) fell back to the CPU engine, whose host memcpy breaks CUDA-graph capture. This blocked capture-ON bf16 CrossAsset training at QKL2Norm, which does a column-broadcast Mul(x, inv) and AddScalar(eps). Route these bf16 ops through on-device f32 conversion instead of the CPU: convert operands bf16->f32 (existing BF16ToF32 kernel), run the existing f32 broadcast/ scalar kernel, convert the result f32->bf16 -- all on the engine stream, so the op stays on the GPU and capturable. Computing in f32 matches the bf16 GEMM/reduction convention (f32 accumulation, bf16 storage). No new CUDA kernels and no .so change: reuses the f32 kernels (raw-pointer ABI) + the existing conversion kernels. Scratch is arena-allocated and freed stream-ordered, the same pattern as getDevicePtr's FP16->F32 scratch. - gpuBroadcastOp: both kernel-exec sites go through execBroadcast2D (bf16 -> convert-via-f32, f32 -> direct). - gpuBroadcast4DOp: signals not-handled for bf16 (B=1 CrossAsset uses only 2D broadcasts; bf16 4D broadcast is a follow-up) so it never runs the f32 4D kernel on 2-byte data. - gpuAddScalar/gpuMulScalar/gpuDivScalar: bf16 -> gpuScalarOpBF16. CUDA-gated parity test: bf16 column-broadcast Mul + AddScalar vs f32 reference. Completes the bf16 GPU op surface for capture-ON CrossAsset training. ADR-075 L4. --- compute/gpu_bf16.go | 116 ++++++++++++++++++++++++ compute/gpu_bf16_transpose_gemm_test.go | 44 +++++++++ compute/gpu_kernels.go | 51 +++++------ 3 files changed, 184 insertions(+), 27 deletions(-) diff --git a/compute/gpu_bf16.go b/compute/gpu_bf16.go index fe0e121..6858c26 100644 --- a/compute/gpu_bf16.go +++ b/compute/gpu_bf16.go @@ -82,6 +82,122 @@ func gpuBinaryOpBF16[T tensor.Numeric]( return makeGPUResult[T](e, a.Shape(), devC, n, dst...) } +// execBroadcast2D runs a 2D broadcast binary kernel and builds the result. +// +// For f32 it calls the kernel directly (prior behavior). For bf16 -- which has +// no native broadcast kernel -- it converts both operands to f32 on-device, +// runs the existing f32 broadcast kernel, then converts the result back to bf16, +// all on the engine stream. This keeps bf16 broadcast ops on the GPU and +// CUDA-graph-capturable, instead of the host CPU fallback whose D2H/H2D copies +// break stream capture (e.g. QKL2Norm's Mul(x, inv)). Computing in f32 also +// matches the bf16 GEMM/reduction convention (f32 accumulation, bf16 storage). +// devA/devB are T-typed device pointers; nA/nB are the element counts of a/b. +func execBroadcast2D[T tensor.Numeric]( + e *GPUEngine[T], outShape []int, + devA unsafe.Pointer, nA int, devB unsafe.Pointer, nB int, outElems int, + saRow, saCol, sbRow, sbCol, mDim, dDim int, + kernelFn func(devA, devB, devC unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, stream gpuapi.Stream) error, + dst ...*tensor.TensorNumeric[T], +) (*tensor.TensorNumeric[T], error) { + if !isBFloat16[T]() { + devC, err := e.pool.Alloc(e.deviceID, outElems*f32Size) + if err != nil { + return nil, err + } + if err := kernelFn(devA, devB, devC, saRow, saCol, sbRow, sbCol, mDim, dDim, e.stream); err != nil { + e.pool.Free(e.deviceID, devC, outElems*f32Size) + return nil, err + } + return makeGPUResult[T](e, outShape, devC, outElems, dst...) + } + + // bf16: a,b -> f32 scratch, f32 broadcast, result -> bf16. Scratch is freed + // after the kernels are enqueued on the engine stream (arena frees are + // stream-ordered, matching the getDevicePtr FP16->F32 scratch pattern). + aF, err := e.pool.Alloc(e.deviceID, nA*f32Size) + if err != nil { + return nil, err + } + defer e.pool.Free(e.deviceID, aF, nA*f32Size) + if err := e.kernels.BF16ToF32(devA, aF, nA, e.stream); err != nil { + return nil, err + } + bF, err := e.pool.Alloc(e.deviceID, nB*f32Size) + if err != nil { + return nil, err + } + defer e.pool.Free(e.deviceID, bF, nB*f32Size) + if err := e.kernels.BF16ToF32(devB, bF, nB, e.stream); err != nil { + return nil, err + } + cF, err := e.pool.Alloc(e.deviceID, outElems*f32Size) + if err != nil { + return nil, err + } + defer e.pool.Free(e.deviceID, cF, outElems*f32Size) + if err := kernelFn(aF, bF, cF, saRow, saCol, sbRow, sbCol, mDim, dDim, e.stream); err != nil { + return nil, err + } + cB, err := e.pool.Alloc(e.deviceID, outElems*bf16Size) + if err != nil { + return nil, err + } + if err := e.kernels.F32ToBF16(cF, cB, outElems, e.stream); err != nil { + e.pool.Free(e.deviceID, cB, outElems*bf16Size) + return nil, err + } + return makeGPUResult[T](e, outShape, cB, outElems, dst...) +} + +// bf16ScalarToF32 converts a bf16 scalar (carried as T) to float32. Used by the +// bf16 scalar-op path, where toFloat32 (an any.(float32) assertion) would panic. +func bf16ScalarToF32[T tensor.Numeric](v T) float32 { + return any(v).(float16.BFloat16).ToFloat32() +} + +// gpuScalarOpBF16 runs a scalar kernel (c = op(a, scalar)) for bf16 by converting +// a to f32 on-device, running the f32 scalar kernel, and converting the result +// back to bf16 -- keeping the op on the GPU and capture-safe (the CPU fallback's +// host copies break CUDA-graph capture, e.g. QKL2Norm's AddScalar(eps)). +func gpuScalarOpBF16[T tensor.Numeric]( + e *GPUEngine[T], a *tensor.TensorNumeric[T], scalar float32, + kernelFn func(devA unsafe.Pointer, scalar float32, devC unsafe.Pointer, n int, stream gpuapi.Stream) error, + dst ...*tensor.TensorNumeric[T], +) (*tensor.TensorNumeric[T], error) { + e.setDevice() + n := a.GetStorage().Len() + devA, cleanupA, err := getDevicePtr(e, a) + if err != nil { + return nil, err + } + defer cleanupA() + aF, err := e.pool.Alloc(e.deviceID, n*f32Size) + if err != nil { + return nil, err + } + defer e.pool.Free(e.deviceID, aF, n*f32Size) + if err := e.kernels.BF16ToF32(devA, aF, n, e.stream); err != nil { + return nil, err + } + cF, err := e.pool.Alloc(e.deviceID, n*f32Size) + if err != nil { + return nil, err + } + defer e.pool.Free(e.deviceID, cF, n*f32Size) + if err := kernelFn(aF, scalar, cF, n, e.stream); err != nil { + return nil, err + } + cB, err := e.pool.Alloc(e.deviceID, n*bf16Size) + if err != nil { + return nil, err + } + if err := e.kernels.F32ToBF16(cF, cB, n, e.stream); err != nil { + e.pool.Free(e.deviceID, cB, n*bf16Size) + return nil, err + } + return makeGPUResult[T](e, a.Shape(), cB, n, dst...) +} + // gpuSoftmaxBF16 runs a native bf16 softmax along the given axis using the // fused scaled-softmax kernel (scale = 1.0) with FP32 max/sum accumulation. func gpuSoftmaxBF16[T tensor.Numeric]( diff --git a/compute/gpu_bf16_transpose_gemm_test.go b/compute/gpu_bf16_transpose_gemm_test.go index a4ca7b2..accc402 100644 --- a/compute/gpu_bf16_transpose_gemm_test.go +++ b/compute/gpu_bf16_transpose_gemm_test.go @@ -233,6 +233,50 @@ func TestGPUBF16_ReshapeStaysOnDevice(t *testing.T) { } } +func TestGPUBF16_BroadcastAndScalarParity(t *testing.T) { + eng := newTestGPUBF16Engine(t) + ctx := context.Background() + + // These are the QKL2Norm ops that broke CUDA-graph capture by falling to the + // CPU engine: a column-broadcast Mul (x[M,D] * inv[M,1]) and AddScalar. + const M, D = 6, 4 + xv := ramp(M * D) + invv := make([]float32, M) // [M,1] column vector + for i := range invv { + invv[i] = float32((i%5)+1) * 0.25 // bf16-exact + } + x := bf16Tensor(t, []int{M, D}, xv) + inv := bf16Tensor(t, []int{M, 1}, invv) + + // Broadcast Mul: out[i,j] = x[i,j] * inv[i]. + got, err := eng.Mul(ctx, x, inv) + if err != nil { + t.Fatalf("broadcast Mul: %v", err) + } + if want := []int{M, D}; !shapeEq(got.Shape(), want) { + t.Fatalf("broadcast Mul shape = %v, want %v", got.Shape(), want) + } + gd := bf16ToF32(got.Data()) + for i := 0; i < M; i++ { + for j := 0; j < D; j++ { + want := float16.BFloat16FromFloat32(xv[i*D+j] * invv[i]).ToFloat32() + assertBF16Close(t, "broadcastMul", i*D+j, gd[i*D+j], want, 2.0) + } + } + + // AddScalar: out[i] = x[i] + eps. + eps := float16.BFloat16FromFloat32(0.125) + gotS, err := eng.AddScalar(ctx, x, eps) + if err != nil { + t.Fatalf("AddScalar: %v", err) + } + sd := bf16ToF32(gotS.Data()) + for i := range xv { + want := float16.BFloat16FromFloat32(xv[i] + 0.125).ToFloat32() + assertBF16Close(t, "addScalar", i, sd[i], want, 2.0) + } +} + // ramp returns n small, bf16-exact values centered near zero so GEMM sums stay // well-conditioned (no catastrophic cancellation). func ramp(n int) []float32 { diff --git a/compute/gpu_kernels.go b/compute/gpu_kernels.go index 60252c2..2f7bd1a 100644 --- a/compute/gpu_kernels.go +++ b/compute/gpu_kernels.go @@ -390,6 +390,13 @@ func gpuBroadcast4DOp[T tensor.Numeric]( if !ok { return nil, nil // signal: not handled } + // The f32 4D broadcast kernel cannot run on bf16 (2-byte) device pointers. + // Signal "not handled" so the caller takes the CPU fallback. (The B=1 + // CrossAsset graph uses only 2D broadcasts, handled on-device via + // execBroadcast2D; bf16 4D broadcast support is a follow-up.) + if isBFloat16[T]() { + return nil, nil + } outShape := broadcastShape(aShape, bShape) outElems := 1 @@ -485,18 +492,8 @@ func gpuBroadcastOp[T tensor.Numeric]( defer cleanupB() outElems := M * D - byteSize := outElems * f32Size - devC, err := e.pool.Alloc(e.deviceID, byteSize) - if err != nil { - return nil, err - } - - if err := kernelFn(devA, devB, devC, saRow, saCol, sbRow, sbCol, M, D, e.stream); err != nil { - e.pool.Free(e.deviceID, devC, byteSize) - return nil, err - } - - return makeGPUResult[T](e, outShape, devC, outElems, dst...) + return execBroadcast2D(e, outShape, devA, aTotal, devB, bTotal, + outElems, saRow, saCol, sbRow, sbCol, M, D, kernelFn, dst...) } } @@ -597,18 +594,9 @@ func gpuBroadcastOp[T tensor.Numeric]( return nil, err } defer cleanupB() - byteSize := outElems * f32Size - devC, err := e.pool.Alloc(e.deviceID, byteSize) - if err != nil { - return nil, err - } - if err := kernelFn(devA, devB, devC, saRow, saCol, sbRow, sbCol, M, D, e.stream); err != nil { - e.pool.Free(e.deviceID, devC, byteSize) - return nil, err - } - - return makeGPUResult[T](e, outShape, devC, outElems, dst...) + return execBroadcast2D(e, outShape, devA, totalElements(aShape), devB, totalElements(bShape), + outElems, saRow, saCol, sbRow, sbCol, M, D, kernelFn, dst...) } // flattenTo2D flattens an N-D shape to [M, D] where M = product of all dims except last, D = last dim. @@ -810,7 +798,7 @@ func (e *GPUEngine[T]) gpuAdd(ctx context.Context, a, b *tensor.TensorNumeric[T] if sameShape(a, b) { return gpuBinaryOpBF16(e, a, b, e.kernels.AddBF16, dst...) } - return e.cpu.Add(ctx, a, b, dst...) + return gpuBroadcastOp(e, ctx, a, b, e.kernels.AddBroadcast, e.kernels.AddBroadcast4D, e.cpu.Add, dst...) } if !isFloat32[T]() { return e.cpu.Add(ctx, a, b, dst...) @@ -835,7 +823,7 @@ func (e *GPUEngine[T]) gpuSub(ctx context.Context, a, b *tensor.TensorNumeric[T] if sameShape(a, b) { return gpuBinaryOpBF16(e, a, b, e.kernels.SubBF16, dst...) } - return e.cpu.Sub(ctx, a, b, dst...) + return gpuBroadcastOp(e, ctx, a, b, e.kernels.SubBroadcast, e.kernels.SubBroadcast4D, e.cpu.Sub, dst...) } if !isFloat32[T]() { return e.cpu.Sub(ctx, a, b, dst...) @@ -860,7 +848,7 @@ func (e *GPUEngine[T]) gpuMul(ctx context.Context, a, b *tensor.TensorNumeric[T] if sameShape(a, b) { return gpuBinaryOpBF16(e, a, b, e.kernels.MulBF16, dst...) } - return e.cpu.Mul(ctx, a, b, dst...) + return gpuBroadcastOp(e, ctx, a, b, e.kernels.MulBroadcast, e.kernels.MulBroadcast4D, e.cpu.Mul, dst...) } if !isFloat32[T]() { return e.cpu.Mul(ctx, a, b, dst...) @@ -885,7 +873,7 @@ func (e *GPUEngine[T]) gpuDiv(ctx context.Context, a, b *tensor.TensorNumeric[T] if sameShape(a, b) { return gpuBinaryOpBF16(e, a, b, e.kernels.DivBF16, dst...) } - return e.cpu.Div(ctx, a, b, dst...) + return gpuBroadcastOp(e, ctx, a, b, e.kernels.DivBroadcast, e.kernels.DivBroadcast4D, e.cpu.Div, dst...) } if !isFloat32[T]() { return e.cpu.Div(ctx, a, b, dst...) @@ -1045,6 +1033,9 @@ func (e *GPUEngine[T]) gpuTanhPrime(ctx context.Context, a, upstream *tensor.Ten } func (e *GPUEngine[T]) gpuAddScalar(ctx context.Context, a *tensor.TensorNumeric[T], scalar T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + if isBFloat16[T]() { + return gpuScalarOpBF16(e, a, bf16ScalarToF32(scalar), e.kernels.AddScalar, dst...) + } if !isFloat32[T]() { return e.cpu.AddScalar(ctx, a, scalar, dst...) } @@ -1055,6 +1046,9 @@ func (e *GPUEngine[T]) gpuAddScalar(ctx context.Context, a *tensor.TensorNumeric } func (e *GPUEngine[T]) gpuMulScalar(ctx context.Context, a *tensor.TensorNumeric[T], scalar T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + if isBFloat16[T]() { + return gpuScalarOpBF16(e, a, bf16ScalarToF32(scalar), e.kernels.MulScalar, dst...) + } if !isFloat32[T]() { return e.cpu.MulScalar(ctx, a, scalar, dst...) } @@ -1065,6 +1059,9 @@ func (e *GPUEngine[T]) gpuMulScalar(ctx context.Context, a *tensor.TensorNumeric } func (e *GPUEngine[T]) gpuDivScalar(ctx context.Context, a *tensor.TensorNumeric[T], scalar T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + if isBFloat16[T]() { + return gpuScalarOpBF16(e, a, bf16ScalarToF32(scalar), e.kernels.DivScalar, dst...) + } if !isFloat32[T]() { return e.cpu.DivScalar(ctx, a, scalar, dst...) } From 82468fec54a4393046fee4b597c1e24ef62ab723 Mon Sep 17 00:00:00 2001 From: David Ndungu Date: Wed, 17 Jun 2026 01:00:25 -0700 Subject: [PATCH 2/2] fix(gradcheck): pin CrossAttention fwd intermediates across arena reset The CrossAttention test op (added in cc6948a) captured its Forward intermediates (softmax weights attn + Q/K/V) via closure and assumed no arena reset between Forward and Backward. The testing/parity reset-between-fwd-bwd schedule DOES Reset the arena between the passes, so those arena-backed tensors were freed and Backward read garbage -- TestRun_HostArenaStress_RegistryGreen failed with CrossAttention bwd max_abs=+Inf (regressed on main; passes at v1.15.0..v1.17.1). Add an optional opNode.extraSaves hook: when set, Forward registers the returned intermediates with the Saver (graph.SaverAware, ADR 006) alongside the output, so they survive the reset (the same mechanism layerNorm/groupNorm/adaLN/timestepEmbed already use). CrossAttention sets it to pin attn + Q/K/V. gradcheck (no reset between passes) is unaffected. Unblocks ztensor main CI. Unrelated to the bf16 broadcast/scalar change in this branch's other commit; folded in to restore a green main in one cycle. --- testing/gradcheck/ops.go | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/testing/gradcheck/ops.go b/testing/gradcheck/ops.go index fc42159..98b9191 100644 --- a/testing/gradcheck/ops.go +++ b/testing/gradcheck/ops.go @@ -41,6 +41,13 @@ type opNode[T tensor.Float] struct { bwd func(ctx context.Context, g tn[T], inputs []tn[T], out tn[T]) ([]tn[T], error) out tn[T] // cached forward output, saved for backward saver graph.Saver[T] + // extraSaves, when set, returns forward intermediates (beyond the output) + // that Backward reads via closure capture. They are registered with the + // Saver so they survive an arena Reset between Forward and Backward (the + // testing/parity reset-between-fwd-bwd schedule); without this they are + // arena-freed and Backward reads garbage (max_abs=+Inf). Evaluated after + // fwd has populated the captured variables. + extraSaves func() []tn[T] } func (n *opNode[T]) OpType() string { return n.opType } @@ -61,6 +68,9 @@ func (n *opNode[T]) Forward(ctx context.Context, inputs ...tn[T]) (tn[T], error) n.out = y if n.saver != nil { n.saver.SaveForBackward(y) + if n.extraSaves != nil { + n.saver.SaveForBackward(n.extraSaves()...) + } } return y, nil } @@ -886,9 +896,12 @@ func (n *groupNormNode[T]) Backward(ctx context.Context, _ types.BackwardMode, g // 1/sqrt(E), E = query last dim = d -- matches). No trainable parameters; the // three inputs are Q, K, V and the gradient flows to all three. func newCrossAttentionNode[T tensor.Float](e compute.Engine[T]) *opNode[T] { - // Intermediates captured across Forward/Backward. gradcheck builds a fresh - // node per evaluation and never resets an arena between the passes, so - // closure capture is sufficient (no Saver needed -- see the file header). + // Intermediates captured across Forward/Backward via closure. Backward reads + // attn (softmax weights) and the Q/K/V inputs; under the testing/parity + // reset-between-fwd-bwd schedule the arena is Reset between the passes, so + // these must be pinned via the Saver (extraSaves below) -- otherwise they are + // arena-freed and Backward reads garbage (max_abs=+Inf). gradcheck itself + // never resets between passes, so it is unaffected either way. var ( attn tn[T] // softmax weights A [Lq, Lk] qIn tn[T] @@ -898,6 +911,9 @@ func newCrossAttentionNode[T tensor.Float](e compute.Engine[T]) *opNode[T] { ) return &opNode[T]{ opType: "CrossAttention", + extraSaves: func() []tn[T] { + return []tn[T]{attn, qIn, kIn, vIn} + }, fwd: func(ctx context.Context, in []tn[T]) (tn[T], error) { if len(in) != 3 { return nil, fmt.Errorf("CrossAttention: want 3 inputs (Q,K,V), got %d", len(in)) @@ -1151,9 +1167,9 @@ type timestepEmbedNode[T tensor.Float] struct { freqs *graph.Parameter[T] // [1, H] half int // H - sinv tn[T] // sin(arg) [N, H] - cosv tn[T] // cos(arg) [N, H] - tIn tn[T] + sinv tn[T] // sin(arg) [N, H] + cosv tn[T] // cos(arg) [N, H] + tIn tn[T] saver graph.Saver[T] }