Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 116 additions & 0 deletions compute/gpu_bf16.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,122 @@ func gpuBinaryOpBF16[T tensor.Numeric](
return makeGPUResult[T](e, a.Shape(), devC, n, dst...)
}

// execBroadcast2D runs a 2D broadcast binary kernel and builds the result.
//
// For f32 it calls the kernel directly (prior behavior). For bf16 -- which has
// no native broadcast kernel -- it converts both operands to f32 on-device,
// runs the existing f32 broadcast kernel, then converts the result back to bf16,
// all on the engine stream. This keeps bf16 broadcast ops on the GPU and
// CUDA-graph-capturable, instead of the host CPU fallback whose D2H/H2D copies
// break stream capture (e.g. QKL2Norm's Mul(x, inv)). Computing in f32 also
// matches the bf16 GEMM/reduction convention (f32 accumulation, bf16 storage).
// devA/devB are T-typed device pointers; nA/nB are the element counts of a/b.
func execBroadcast2D[T tensor.Numeric](
e *GPUEngine[T], outShape []int,
devA unsafe.Pointer, nA int, devB unsafe.Pointer, nB int, outElems int,
saRow, saCol, sbRow, sbCol, mDim, dDim int,
kernelFn func(devA, devB, devC unsafe.Pointer, saRow, saCol, sbRow, sbCol, M, D int, stream gpuapi.Stream) error,
dst ...*tensor.TensorNumeric[T],
) (*tensor.TensorNumeric[T], error) {
if !isBFloat16[T]() {
devC, err := e.pool.Alloc(e.deviceID, outElems*f32Size)
if err != nil {
return nil, err
}
if err := kernelFn(devA, devB, devC, saRow, saCol, sbRow, sbCol, mDim, dDim, e.stream); err != nil {
e.pool.Free(e.deviceID, devC, outElems*f32Size)
return nil, err
}
return makeGPUResult[T](e, outShape, devC, outElems, dst...)
}

// bf16: a,b -> f32 scratch, f32 broadcast, result -> bf16. Scratch is freed
// after the kernels are enqueued on the engine stream (arena frees are
// stream-ordered, matching the getDevicePtr FP16->F32 scratch pattern).
aF, err := e.pool.Alloc(e.deviceID, nA*f32Size)
if err != nil {
return nil, err
}
defer e.pool.Free(e.deviceID, aF, nA*f32Size)
if err := e.kernels.BF16ToF32(devA, aF, nA, e.stream); err != nil {
return nil, err
}
bF, err := e.pool.Alloc(e.deviceID, nB*f32Size)
if err != nil {
return nil, err
}
defer e.pool.Free(e.deviceID, bF, nB*f32Size)
if err := e.kernels.BF16ToF32(devB, bF, nB, e.stream); err != nil {
return nil, err
}
cF, err := e.pool.Alloc(e.deviceID, outElems*f32Size)
if err != nil {
return nil, err
}
defer e.pool.Free(e.deviceID, cF, outElems*f32Size)
if err := kernelFn(aF, bF, cF, saRow, saCol, sbRow, sbCol, mDim, dDim, e.stream); err != nil {
return nil, err
}
cB, err := e.pool.Alloc(e.deviceID, outElems*bf16Size)
if err != nil {
return nil, err
}
if err := e.kernels.F32ToBF16(cF, cB, outElems, e.stream); err != nil {
e.pool.Free(e.deviceID, cB, outElems*bf16Size)
return nil, err
}
return makeGPUResult[T](e, outShape, cB, outElems, dst...)
}

// bf16ScalarToF32 converts a bf16 scalar (carried as T) to float32. Used by the
// bf16 scalar-op path, where toFloat32 (an any.(float32) assertion) would panic.
func bf16ScalarToF32[T tensor.Numeric](v T) float32 {
return any(v).(float16.BFloat16).ToFloat32()
}

// gpuScalarOpBF16 runs a scalar kernel (c = op(a, scalar)) for bf16 by converting
// a to f32 on-device, running the f32 scalar kernel, and converting the result
// back to bf16 -- keeping the op on the GPU and capture-safe (the CPU fallback's
// host copies break CUDA-graph capture, e.g. QKL2Norm's AddScalar(eps)).
func gpuScalarOpBF16[T tensor.Numeric](
e *GPUEngine[T], a *tensor.TensorNumeric[T], scalar float32,
kernelFn func(devA unsafe.Pointer, scalar float32, devC unsafe.Pointer, n int, stream gpuapi.Stream) error,
dst ...*tensor.TensorNumeric[T],
) (*tensor.TensorNumeric[T], error) {
e.setDevice()
n := a.GetStorage().Len()
devA, cleanupA, err := getDevicePtr(e, a)
if err != nil {
return nil, err
}
defer cleanupA()
aF, err := e.pool.Alloc(e.deviceID, n*f32Size)
if err != nil {
return nil, err
}
defer e.pool.Free(e.deviceID, aF, n*f32Size)
if err := e.kernels.BF16ToF32(devA, aF, n, e.stream); err != nil {
return nil, err
}
cF, err := e.pool.Alloc(e.deviceID, n*f32Size)
if err != nil {
return nil, err
}
defer e.pool.Free(e.deviceID, cF, n*f32Size)
if err := kernelFn(aF, scalar, cF, n, e.stream); err != nil {
return nil, err
}
cB, err := e.pool.Alloc(e.deviceID, n*bf16Size)
if err != nil {
return nil, err
}
if err := e.kernels.F32ToBF16(cF, cB, n, e.stream); err != nil {
e.pool.Free(e.deviceID, cB, n*bf16Size)
return nil, err
}
return makeGPUResult[T](e, a.Shape(), cB, n, dst...)
}

// gpuSoftmaxBF16 runs a native bf16 softmax along the given axis using the
// fused scaled-softmax kernel (scale = 1.0) with FP32 max/sum accumulation.
func gpuSoftmaxBF16[T tensor.Numeric](
Expand Down
44 changes: 44 additions & 0 deletions compute/gpu_bf16_transpose_gemm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,50 @@ func TestGPUBF16_ReshapeStaysOnDevice(t *testing.T) {
}
}

func TestGPUBF16_BroadcastAndScalarParity(t *testing.T) {
eng := newTestGPUBF16Engine(t)
ctx := context.Background()

// These are the QKL2Norm ops that broke CUDA-graph capture by falling to the
// CPU engine: a column-broadcast Mul (x[M,D] * inv[M,1]) and AddScalar.
const M, D = 6, 4
xv := ramp(M * D)
invv := make([]float32, M) // [M,1] column vector
for i := range invv {
invv[i] = float32((i%5)+1) * 0.25 // bf16-exact
}
x := bf16Tensor(t, []int{M, D}, xv)
inv := bf16Tensor(t, []int{M, 1}, invv)

// Broadcast Mul: out[i,j] = x[i,j] * inv[i].
got, err := eng.Mul(ctx, x, inv)
if err != nil {
t.Fatalf("broadcast Mul: %v", err)
}
if want := []int{M, D}; !shapeEq(got.Shape(), want) {
t.Fatalf("broadcast Mul shape = %v, want %v", got.Shape(), want)
}
gd := bf16ToF32(got.Data())
for i := 0; i < M; i++ {
for j := 0; j < D; j++ {
want := float16.BFloat16FromFloat32(xv[i*D+j] * invv[i]).ToFloat32()
assertBF16Close(t, "broadcastMul", i*D+j, gd[i*D+j], want, 2.0)
}
}

// AddScalar: out[i] = x[i] + eps.
eps := float16.BFloat16FromFloat32(0.125)
gotS, err := eng.AddScalar(ctx, x, eps)
if err != nil {
t.Fatalf("AddScalar: %v", err)
}
sd := bf16ToF32(gotS.Data())
for i := range xv {
want := float16.BFloat16FromFloat32(xv[i] + 0.125).ToFloat32()
assertBF16Close(t, "addScalar", i, sd[i], want, 2.0)
}
}

// ramp returns n small, bf16-exact values centered near zero so GEMM sums stay
// well-conditioned (no catastrophic cancellation).
func ramp(n int) []float32 {
Expand Down
51 changes: 24 additions & 27 deletions compute/gpu_kernels.go
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,13 @@ func gpuBroadcast4DOp[T tensor.Numeric](
if !ok {
return nil, nil // signal: not handled
}
// The f32 4D broadcast kernel cannot run on bf16 (2-byte) device pointers.
// Signal "not handled" so the caller takes the CPU fallback. (The B=1
// CrossAsset graph uses only 2D broadcasts, handled on-device via
// execBroadcast2D; bf16 4D broadcast support is a follow-up.)
if isBFloat16[T]() {
return nil, nil
}

outShape := broadcastShape(aShape, bShape)
outElems := 1
Expand Down Expand Up @@ -485,18 +492,8 @@ func gpuBroadcastOp[T tensor.Numeric](
defer cleanupB()

outElems := M * D
byteSize := outElems * f32Size
devC, err := e.pool.Alloc(e.deviceID, byteSize)
if err != nil {
return nil, err
}

if err := kernelFn(devA, devB, devC, saRow, saCol, sbRow, sbCol, M, D, e.stream); err != nil {
e.pool.Free(e.deviceID, devC, byteSize)
return nil, err
}

return makeGPUResult[T](e, outShape, devC, outElems, dst...)
return execBroadcast2D(e, outShape, devA, aTotal, devB, bTotal,
outElems, saRow, saCol, sbRow, sbCol, M, D, kernelFn, dst...)
}
}

Expand Down Expand Up @@ -597,18 +594,9 @@ func gpuBroadcastOp[T tensor.Numeric](
return nil, err
}
defer cleanupB()
byteSize := outElems * f32Size
devC, err := e.pool.Alloc(e.deviceID, byteSize)
if err != nil {
return nil, err
}

if err := kernelFn(devA, devB, devC, saRow, saCol, sbRow, sbCol, M, D, e.stream); err != nil {
e.pool.Free(e.deviceID, devC, byteSize)
return nil, err
}

return makeGPUResult[T](e, outShape, devC, outElems, dst...)
return execBroadcast2D(e, outShape, devA, totalElements(aShape), devB, totalElements(bShape),
outElems, saRow, saCol, sbRow, sbCol, M, D, kernelFn, dst...)
}

// flattenTo2D flattens an N-D shape to [M, D] where M = product of all dims except last, D = last dim.
Expand Down Expand Up @@ -810,7 +798,7 @@ func (e *GPUEngine[T]) gpuAdd(ctx context.Context, a, b *tensor.TensorNumeric[T]
if sameShape(a, b) {
return gpuBinaryOpBF16(e, a, b, e.kernels.AddBF16, dst...)
}
return e.cpu.Add(ctx, a, b, dst...)
return gpuBroadcastOp(e, ctx, a, b, e.kernels.AddBroadcast, e.kernels.AddBroadcast4D, e.cpu.Add, dst...)
}
if !isFloat32[T]() {
return e.cpu.Add(ctx, a, b, dst...)
Expand All @@ -835,7 +823,7 @@ func (e *GPUEngine[T]) gpuSub(ctx context.Context, a, b *tensor.TensorNumeric[T]
if sameShape(a, b) {
return gpuBinaryOpBF16(e, a, b, e.kernels.SubBF16, dst...)
}
return e.cpu.Sub(ctx, a, b, dst...)
return gpuBroadcastOp(e, ctx, a, b, e.kernels.SubBroadcast, e.kernels.SubBroadcast4D, e.cpu.Sub, dst...)
}
if !isFloat32[T]() {
return e.cpu.Sub(ctx, a, b, dst...)
Expand All @@ -860,7 +848,7 @@ func (e *GPUEngine[T]) gpuMul(ctx context.Context, a, b *tensor.TensorNumeric[T]
if sameShape(a, b) {
return gpuBinaryOpBF16(e, a, b, e.kernels.MulBF16, dst...)
}
return e.cpu.Mul(ctx, a, b, dst...)
return gpuBroadcastOp(e, ctx, a, b, e.kernels.MulBroadcast, e.kernels.MulBroadcast4D, e.cpu.Mul, dst...)
}
if !isFloat32[T]() {
return e.cpu.Mul(ctx, a, b, dst...)
Expand All @@ -885,7 +873,7 @@ func (e *GPUEngine[T]) gpuDiv(ctx context.Context, a, b *tensor.TensorNumeric[T]
if sameShape(a, b) {
return gpuBinaryOpBF16(e, a, b, e.kernels.DivBF16, dst...)
}
return e.cpu.Div(ctx, a, b, dst...)
return gpuBroadcastOp(e, ctx, a, b, e.kernels.DivBroadcast, e.kernels.DivBroadcast4D, e.cpu.Div, dst...)
}
if !isFloat32[T]() {
return e.cpu.Div(ctx, a, b, dst...)
Expand Down Expand Up @@ -1045,6 +1033,9 @@ func (e *GPUEngine[T]) gpuTanhPrime(ctx context.Context, a, upstream *tensor.Ten
}

func (e *GPUEngine[T]) gpuAddScalar(ctx context.Context, a *tensor.TensorNumeric[T], scalar T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) {
if isBFloat16[T]() {
return gpuScalarOpBF16(e, a, bf16ScalarToF32(scalar), e.kernels.AddScalar, dst...)
}
if !isFloat32[T]() {
return e.cpu.AddScalar(ctx, a, scalar, dst...)
}
Expand All @@ -1055,6 +1046,9 @@ func (e *GPUEngine[T]) gpuAddScalar(ctx context.Context, a *tensor.TensorNumeric
}

func (e *GPUEngine[T]) gpuMulScalar(ctx context.Context, a *tensor.TensorNumeric[T], scalar T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) {
if isBFloat16[T]() {
return gpuScalarOpBF16(e, a, bf16ScalarToF32(scalar), e.kernels.MulScalar, dst...)
}
if !isFloat32[T]() {
return e.cpu.MulScalar(ctx, a, scalar, dst...)
}
Expand All @@ -1065,6 +1059,9 @@ func (e *GPUEngine[T]) gpuMulScalar(ctx context.Context, a *tensor.TensorNumeric
}

func (e *GPUEngine[T]) gpuDivScalar(ctx context.Context, a *tensor.TensorNumeric[T], scalar T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) {
if isBFloat16[T]() {
return gpuScalarOpBF16(e, a, bf16ScalarToF32(scalar), e.kernels.DivScalar, dst...)
}
if !isFloat32[T]() {
return e.cpu.DivScalar(ctx, a, scalar, dst...)
}
Expand Down
28 changes: 22 additions & 6 deletions testing/gradcheck/ops.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ type opNode[T tensor.Float] struct {
bwd func(ctx context.Context, g tn[T], inputs []tn[T], out tn[T]) ([]tn[T], error)
out tn[T] // cached forward output, saved for backward
saver graph.Saver[T]
// extraSaves, when set, returns forward intermediates (beyond the output)
// that Backward reads via closure capture. They are registered with the
// Saver so they survive an arena Reset between Forward and Backward (the
// testing/parity reset-between-fwd-bwd schedule); without this they are
// arena-freed and Backward reads garbage (max_abs=+Inf). Evaluated after
// fwd has populated the captured variables.
extraSaves func() []tn[T]
}

func (n *opNode[T]) OpType() string { return n.opType }
Expand All @@ -61,6 +68,9 @@ func (n *opNode[T]) Forward(ctx context.Context, inputs ...tn[T]) (tn[T], error)
n.out = y
if n.saver != nil {
n.saver.SaveForBackward(y)
if n.extraSaves != nil {
n.saver.SaveForBackward(n.extraSaves()...)
}
}
return y, nil
}
Expand Down Expand Up @@ -886,9 +896,12 @@ func (n *groupNormNode[T]) Backward(ctx context.Context, _ types.BackwardMode, g
// 1/sqrt(E), E = query last dim = d -- matches). No trainable parameters; the
// three inputs are Q, K, V and the gradient flows to all three.
func newCrossAttentionNode[T tensor.Float](e compute.Engine[T]) *opNode[T] {
// Intermediates captured across Forward/Backward. gradcheck builds a fresh
// node per evaluation and never resets an arena between the passes, so
// closure capture is sufficient (no Saver needed -- see the file header).
// Intermediates captured across Forward/Backward via closure. Backward reads
// attn (softmax weights) and the Q/K/V inputs; under the testing/parity
// reset-between-fwd-bwd schedule the arena is Reset between the passes, so
// these must be pinned via the Saver (extraSaves below) -- otherwise they are
// arena-freed and Backward reads garbage (max_abs=+Inf). gradcheck itself
// never resets between passes, so it is unaffected either way.
var (
attn tn[T] // softmax weights A [Lq, Lk]
qIn tn[T]
Expand All @@ -898,6 +911,9 @@ func newCrossAttentionNode[T tensor.Float](e compute.Engine[T]) *opNode[T] {
)
return &opNode[T]{
opType: "CrossAttention",
extraSaves: func() []tn[T] {
return []tn[T]{attn, qIn, kIn, vIn}
},
fwd: func(ctx context.Context, in []tn[T]) (tn[T], error) {
if len(in) != 3 {
return nil, fmt.Errorf("CrossAttention: want 3 inputs (Q,K,V), got %d", len(in))
Expand Down Expand Up @@ -1151,9 +1167,9 @@ type timestepEmbedNode[T tensor.Float] struct {
freqs *graph.Parameter[T] // [1, H]
half int // H

sinv tn[T] // sin(arg) [N, H]
cosv tn[T] // cos(arg) [N, H]
tIn tn[T]
sinv tn[T] // sin(arg) [N, H]
cosv tn[T] // cos(arg) [N, H]
tIn tn[T]
saver graph.Saver[T]
}

Expand Down
Loading