From c1e6e58db1c293009d442fb3921a527f994c86d8 Mon Sep 17 00:00:00 2001 From: David Ndungu Date: Tue, 16 Jun 2026 23:05:26 -0700 Subject: [PATCH 1/2] test(oracle): add GroupNorm to gradcheck registry + torch oracle Extends the ADR-091 PyTorch-oracle / gradcheck harness to the GroupNorm op class (zerfoo E127/T127.1.0a, first of six new diffusion-DiT op classes). GroupNorm composes entirely from existing engine reduce/elementwise ops: reshape [N,C] -> [N*groups, C/groups], normalize the last axis exactly like the LayerNorm node, reshape back, apply a per-channel affine. No new engine kernel. Adds the node (gradcheck/ops.go), the registry entry + dispatch (registry.go, dim=4 groups=2), and the torch replay + tolerance (torchmap.go, torch.nn.functional.group_norm). Verified: TestRegistry/GroupNorm gradcheck passes (analytic backward vs finite-difference); full gradcheck + oracle registry<->torchmap lockstep green. Unlocks the convolutional-VAE/UNet GroupNorm primitive for the diffusion class. --- testing/gradcheck/ops.go | 213 ++++++++++++++++++++++++++++++++++ testing/gradcheck/registry.go | 10 ++ testing/oracle/torchmap.go | 6 + 3 files changed, 229 insertions(+) diff --git a/testing/gradcheck/ops.go b/testing/gradcheck/ops.go index bab4d11..4e2d2c9 100644 --- a/testing/gradcheck/ops.go +++ b/testing/gradcheck/ops.go @@ -656,3 +656,216 @@ func (n *layerNormNode[T]) Backward(ctx context.Context, _ types.BackwardMode, g } return []tn[T]{dx}, nil } + +// --- groupnorm (per-group normalize + per-channel affine) --------------------- + +// groupNormNode normalizes a 2D input [N, C] in `groups` channel-groups and +// then applies a trainable per-channel affine (gamma, beta of shape [1, C]). +// It reshapes [N, C] -> [N*groups, C/groups], normalizes the last axis exactly +// like layerNormNode (so the per-group statistics fall out of the same engine +// reduce/elementwise ops), and reshapes back -- needing NO new engine kernel. 
+// This is the canonical convolutional-VAE/UNet normalization +// (torch.nn.functional.group_norm); landing it here extends the ADR-091 +// PyTorch-oracle harness to the GroupNorm op class (E127/T127.1.0a) and +// unlocks the whole Stable-Diffusion-family VAE/UNet primitive set. +type groupNormNode[T tensor.Float] struct { + engine compute.Engine[T] + gamma *graph.Parameter[T] + beta *graph.Parameter[T] + groups int + eps float64 + + xhatR tn[T] // normalized input in grouped shape [N*groups, C/groups] + inv tn[T] // inverse stddev per group, [N*groups, 1] + nRows int // N + chans int // C + saver graph.Saver[T] +} + +func newGroupNormNode[T tensor.Float](e compute.Engine[T], dim, groups int) (*groupNormNode[T], error) { + if groups <= 0 || dim%groups != 0 { + return nil, fmt.Errorf("GroupNorm: dim %d not divisible by groups %d", dim, groups) + } + gammaData := make([]T, dim) + betaData := make([]T, dim) + for i := 0; i < dim; i++ { + // Deterministic, non-uniform initial values so parameter gradients + // are structurally informative (mirrors newLayerNormNode). 
+ gammaData[i] = T(0.8 + 0.1*float64(i)) + betaData[i] = T(-0.2 + 0.15*float64(i)) + } + gv, err := newTensorOf([]int{1, dim}, gammaData) + if err != nil { + return nil, err + } + bv, err := newTensorOf([]int{1, dim}, betaData) + if err != nil { + return nil, err + } + gamma, err := graph.NewParameter[T]("gamma", gv, newTensorOf[T]) + if err != nil { + return nil, err + } + beta, err := graph.NewParameter[T]("beta", bv, newTensorOf[T]) + if err != nil { + return nil, err + } + return &groupNormNode[T]{engine: e, gamma: gamma, beta: beta, groups: groups, eps: 1e-5}, nil +} + +func (n *groupNormNode[T]) OpType() string { return "GroupNorm" } +func (n *groupNormNode[T]) Attributes() map[string]interface{} { + return map[string]interface{}{"epsilon": n.eps, "groups": n.groups} +} +func (n *groupNormNode[T]) Parameters() []*graph.Parameter[T] { + return []*graph.Parameter[T]{n.gamma, n.beta} +} +func (n *groupNormNode[T]) SetSaver(s graph.Saver[T]) { n.saver = s } + +func (n *groupNormNode[T]) OutputShape() []int { + if n.xhatR == nil { + return nil + } + return []int{n.nRows, n.chans} +} + +func (n *groupNormNode[T]) Forward(ctx context.Context, inputs ...tn[T]) (tn[T], error) { + if len(inputs) != 1 { + return nil, fmt.Errorf("GroupNorm: want 1 input, got %d", len(inputs)) + } + x := inputs[0] + shape := x.Shape() + if len(shape) != 2 { + return nil, fmt.Errorf("GroupNorm: want 2D input [N, C], got %v", shape) + } + n.nRows, n.chans = shape[0], shape[1] + if n.chans%n.groups != 0 { + return nil, fmt.Errorf("GroupNorm: C %d not divisible by groups %d", n.chans, n.groups) + } + gw := n.chans / n.groups + e := n.engine + // Reshape [N, C] -> [N*groups, C/groups]; normalize the last axis. 
+ xr, err := e.Reshape(ctx, x, []int{n.nRows * n.groups, gw}) + if err != nil { + return nil, err + } + mean, err := e.ReduceMean(ctx, xr, 1, true) + if err != nil { + return nil, err + } + xc, err := e.Sub(ctx, xr, mean) + if err != nil { + return nil, err + } + sq, err := e.Mul(ctx, xc, xc) + if err != nil { + return nil, err + } + variance, err := e.ReduceMean(ctx, sq, 1, true) + if err != nil { + return nil, err + } + veps, err := e.AddScalar(ctx, variance, T(n.eps)) + if err != nil { + return nil, err + } + inv, err := e.Rsqrt(ctx, veps) + if err != nil { + return nil, err + } + xhatR, err := e.Mul(ctx, xc, inv) + if err != nil { + return nil, err + } + n.xhatR = xhatR + n.inv = inv + if n.saver != nil { + n.saver.SaveForBackward(xhatR, inv) + } + // Reshape back to [N, C] and apply the per-channel affine. + xhat, err := e.Reshape(ctx, xhatR, []int{n.nRows, n.chans}) + if err != nil { + return nil, err + } + scaled, err := e.Mul(ctx, xhat, n.gamma.Value) + if err != nil { + return nil, err + } + return e.Add(ctx, scaled, n.beta.Value) +} + +func (n *groupNormNode[T]) Backward(ctx context.Context, _ types.BackwardMode, g tn[T], _ ...tn[T]) ([]tn[T], error) { + if n.xhatR == nil || n.inv == nil { + return nil, errors.New("GroupNorm: Backward called before Forward") + } + e := n.engine + gw := n.chans / n.groups + // xhat in [N, C] form for the per-channel parameter gradients. + xhat, err := e.Reshape(ctx, n.xhatR, []int{n.nRows, n.chans}) + if err != nil { + return nil, err + } + // dGamma = sum_batch(g * xhat); dBeta = sum_batch(g). Both [1, C]. 
+ gx, err := e.Mul(ctx, g, xhat) + if err != nil { + return nil, err + } + dgamma, err := e.ReduceSum(ctx, gx, 0, true) + if err != nil { + return nil, err + } + if err := n.gamma.AddGradient(dgamma); err != nil { + return nil, err + } + dbeta, err := e.ReduceSum(ctx, g, 0, true) + if err != nil { + return nil, err + } + if err := n.beta.AddGradient(dbeta); err != nil { + return nil, err + } + // dxhat = g * gamma in [N, C]; reshape to grouped form so the per-group + // normalization backward (means over the C/groups axis) reuses the exact + // layerNorm dX formula. + dxhat, err := e.Mul(ctx, g, n.gamma.Value) + if err != nil { + return nil, err + } + dxhatR, err := e.Reshape(ctx, dxhat, []int{n.nRows * n.groups, gw}) + if err != nil { + return nil, err + } + m1, err := e.ReduceMean(ctx, dxhatR, 1, true) + if err != nil { + return nil, err + } + dxx, err := e.Mul(ctx, dxhatR, n.xhatR) + if err != nil { + return nil, err + } + m2, err := e.ReduceMean(ctx, dxx, 1, true) + if err != nil { + return nil, err + } + t1, err := e.Sub(ctx, dxhatR, m1) + if err != nil { + return nil, err + } + xm2, err := e.Mul(ctx, n.xhatR, m2) + if err != nil { + return nil, err + } + t2, err := e.Sub(ctx, t1, xm2) + if err != nil { + return nil, err + } + dxR, err := e.Mul(ctx, t2, n.inv) + if err != nil { + return nil, err + } + dx, err := e.Reshape(ctx, dxR, []int{n.nRows, n.chans}) + if err != nil { + return nil, err + } + return []tn[T]{dx}, nil +} diff --git a/testing/gradcheck/registry.go b/testing/gradcheck/registry.go index fd58075..049be55 100644 --- a/testing/gradcheck/registry.go +++ b/testing/gradcheck/registry.go @@ -67,6 +67,8 @@ func NewRegistryNode[T tensor.Float](name string, e compute.Engine[T]) (graph.No return newReduceMaxNode(e), nil case "LayerNorm": return newLayerNormNode(e, 4) + case "GroupNorm": + return newGroupNormNode(e, 4, 2) default: return nil, fmt.Errorf("gradcheck: no registry op named %q", name) } @@ -239,5 +241,13 @@ func Registry() []OpInfo { Make: 
registryMake("LayerNorm"), InputShapes: [][]int{{3, 4}}, }, + // GroupNorm (per-group normalize + per-channel affine); dim=4, groups=2 + // so the [3,4] input reshapes to [6,2] groups. Canonical VAE/UNet norm + // (E127/T127.1.0a -- extends the oracle to the GroupNorm op class). + { + Name: "GroupNorm", Seed: 27, + Make: registryMake("GroupNorm"), + InputShapes: [][]int{{3, 4}}, + }, } } diff --git a/testing/oracle/torchmap.go b/testing/oracle/torchmap.go index c46abe7..edd8c72 100644 --- a/testing/oracle/torchmap.go +++ b/testing/oracle/torchmap.go @@ -62,6 +62,11 @@ var torchMap = map[string]torchOp{ // expression reshapes the leaf inside the graph so torch records the // gradient on the (1, dim) leaf, matching the recorded ztensor shapes. "LayerNorm": {Expr: "torch.nn.functional.layer_norm(x0, (4,), weight=gamma.reshape(4), bias=beta.reshape(4), eps=1e-05)"}, + + // GroupNorm: 2D input [N, C] with num_groups=2 over C=4 (groups of 2), + // per-channel affine. Matches gradcheck.newGroupNormNode(e, 4, 2); gamma/beta + // leaves stay (1, 4) and reshape to (4,) inside the graph, like LayerNorm. 
+ "GroupNorm": {Expr: "torch.nn.functional.group_norm(x0, 2, weight=gamma.reshape(4), bias=beta.reshape(4), eps=1e-05)"}, } // defaultTolerance is the first-cut f32 comparison bar: ztensor CPU/GPU f32 @@ -80,6 +85,7 @@ var defaultTolerance = Tolerance{ var toleranceOverrides = map[string]Tolerance{ "Softmax": {FwdAtol: 1e-6, FwdRtol: 1e-4, GradAtol: 1e-6, GradRtol: 1e-3}, "LayerNorm": {FwdAtol: 1e-5, FwdRtol: 1e-4, GradAtol: 1e-5, GradRtol: 1e-3}, + "GroupNorm": {FwdAtol: 1e-5, FwdRtol: 1e-4, GradAtol: 1e-5, GradRtol: 1e-3}, "MatMul": {FwdAtol: 1e-6, FwdRtol: 1e-4, GradAtol: 1e-6, GradRtol: 1e-3}, } From 73d18120bca058a778b90de7c3864056f623ac6c Mon Sep 17 00:00:00 2001 From: David Ndungu Date: Tue, 16 Jun 2026 23:37:39 -0700 Subject: [PATCH 2/2] feat(compute): native bf16 GPU transpose kernels (capture-safe) GPUEngine.Transpose routed every non-float32 type to the CPU engine, whose host memcpy breaks CUDA-graph capture -- so any bf16 transpose under capture failed ("operation would make the legacy stream depend on a capturing blocking stream", e.g. node QKL2Norm's Transpose). This forced the bf16 CrossAsset GPU bench to run with capture DISABLED (~190 s/epoch). Add native bf16 (16-bit) transpose kernels. Transpose is pure data movement, so the kernels operate on `unsigned short` -- a bitwise element copy independent of the bf16 numeric interpretation (no bf16 math, no new headers): - transpose.cu: kernel_transpose_2d_bf16 + kernel_transpose_nd_bf16 (+ launchers) - cuda/kernels: Transpose2DBF16 / TransposeNDBF16 (cuda + purego builds) and dlopen symbol registration - gpuapi: optional BFloat16Transposer extension + CUDAKernels impl - GPUEngine.Transpose: for bf16 with GPU-resident input and a backend that implements BFloat16Transposer, transpose on-device via the bf16 kernels; otherwise fall back to the CPU engine as before. The f32 path is byte-for-byte unchanged (element size threaded only through the byteSize/kernel-select). 
CUDA-gated parity tests: 2D + 3D[0,2,1] (the QKL2Norm shape), exact-match. Lets the bf16 CrossAsset GPU bench run with CUDA-graph capture ON. ADR-075 L4. --- compute/gpu_bf16_transpose_gemm_test.go | 59 +++++++++++++++++++ compute/gpu_engine_memory.go | 40 +++++++++++-- internal/cuda/kernels/purego.go | 5 +- internal/cuda/kernels/transpose.cu | 72 +++++++++++++++++++++++ internal/cuda/kernels/transpose.go | 29 +++++++++ internal/cuda/kernels/transpose_purego.go | 31 ++++++++++ internal/gpuapi/cuda_kernels.go | 9 +++ internal/gpuapi/kernels.go | 12 ++++ 8 files changed, 250 insertions(+), 7 deletions(-) diff --git a/compute/gpu_bf16_transpose_gemm_test.go b/compute/gpu_bf16_transpose_gemm_test.go index b02ebaf..0c6d47f 100644 --- a/compute/gpu_bf16_transpose_gemm_test.go +++ b/compute/gpu_bf16_transpose_gemm_test.go @@ -134,6 +134,65 @@ func TestGPUBF16_MatMulTransposeBBatchedParity(t *testing.T) { } } +func TestGPUBF16_TransposeParity(t *testing.T) { + eng := newTestGPUBF16Engine(t) + ctx := context.Background() + + // bf16 transpose is a pure bitwise element move (no arithmetic), so the + // output values must match the reference EXACTLY. This exercises the native + // bf16 GPU transpose kernels (Transpose2DBF16/NDBF16) that keep bf16 + // transposes on-device (CUDA-graph-capturable) instead of the CPU fallback. + + // 2D: [rows, cols] -> [cols, rows] via axes [1,0]. 
+ t.Run("2D", func(t *testing.T) { + const rows, cols = 5, 7 + vals := ramp(rows * cols) + x := bf16Tensor(t, []int{rows, cols}, vals) + got, err := eng.Transpose(ctx, x, []int{1, 0}) + if err != nil { + t.Fatalf("Transpose 2D: %v", err) + } + if want := []int{cols, rows}; !shapeEq(got.Shape(), want) { + t.Fatalf("2D transpose shape = %v, want %v", got.Shape(), want) + } + gd := bf16ToF32(got.Data()) + for i := 0; i < cols; i++ { + for j := 0; j < rows; j++ { + want := vals[j*cols+i] // input[j,i] -> output[i,j] + if gd[i*rows+j] != want { + t.Fatalf("2D[%d,%d] = %g, want %g", i, j, gd[i*rows+j], want) + } + } + } + }) + + // 3D: [d0,d1,d2] -> [d0,d2,d1] via axes [0,2,1] (the QKL2Norm-style case). + t.Run("3D_021", func(t *testing.T) { + const d0, d1, d2 = 4, 12, 64 + vals := ramp(d0 * d1 * d2) + x := bf16Tensor(t, []int{d0, d1, d2}, vals) + got, err := eng.Transpose(ctx, x, []int{0, 2, 1}) + if err != nil { + t.Fatalf("Transpose 3D: %v", err) + } + if want := []int{d0, d2, d1}; !shapeEq(got.Shape(), want) { + t.Fatalf("3D transpose shape = %v, want %v", got.Shape(), want) + } + gd := bf16ToF32(got.Data()) + for a := 0; a < d0; a++ { + for c := 0; c < d2; c++ { + for b := 0; b < d1; b++ { + want := vals[a*d1*d2+b*d2+c] // input[a,b,c] + gotV := gd[a*d2*d1+c*d1+b] // output[a,c,b] + if gotV != want { + t.Fatalf("3D[%d,%d,%d] = %g, want %g", a, c, b, gotV, want) + } + } + } + } + }) +} + // ramp returns n small, bf16-exact values centered near zero so GEMM sums stay // well-conditioned (no catastrophic cancellation). func ramp(n int) []float32 { diff --git a/compute/gpu_engine_memory.go b/compute/gpu_engine_memory.go index 6c352bb..42c62af 100644 --- a/compute/gpu_engine_memory.go +++ b/compute/gpu_engine_memory.go @@ -16,8 +16,20 @@ import ( // Transpose transposes a tensor along the given axes. 
func (e *GPUEngine[T]) Transpose(ctx context.Context, a *tensor.TensorNumeric[T], axes []int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + // bf16 stays on-device when the kernel backend provides native bf16 transpose + // kernels AND the input is GPU-resident; otherwise (and for every non-f32, + // non-bf16 type) fall back to the CPU engine. Keeping bf16 transposes on the + // device is required for CUDA-graph capture: the CPU fallback's host memcpy + // breaks stream capture (e.g. QKL2Norm's transpose). ADR-075 lever L4. + var bf16Transposer gpuapi.BFloat16Transposer + bf16Path := false if !isFloat32[T]() { - return e.cpu.Transpose(ctx, a, axes, dst...) + bt, ok := e.kernels.(gpuapi.BFloat16Transposer) + _, isGPUStore := a.GetStorage().(*tensor.GPUStorage[T]) + if !isBFloat16[T]() || !ok || !isGPUStore { + return e.cpu.Transpose(ctx, a, axes, dst...) + } + bf16Transposer, bf16Path = bt, true } // Only use GPU path for GPU-resident tensors (Phase 6 behavior). @@ -131,7 +143,11 @@ func (e *GPUEngine[T]) Transpose(ctx context.Context, a *tensor.TensorNumeric[T] fmt.Fprintf(os.Stderr, "TRANSPOSE getDevicePtr OK: ptr=%p\n", devIn) } - byteSize := total * f32Size + elemSize := f32Size + if bf16Path { + elemSize = 2 // bf16 is 16-bit + } + byteSize := total * elemSize // Reuse dst's existing GPU memory when possible (#84). 
devOut, reused := tryReuseDstPtr[T](total, dst) @@ -149,11 +165,17 @@ func (e *GPUEngine[T]) Transpose(ctx context.Context, a *tensor.TensorNumeric[T] "rows", fmt.Sprintf("%d", shape[0]), "cols", fmt.Sprintf("%d", shape[1])) } - if err := e.kernels.Transpose2D(devIn, devOut, shape[0], shape[1], e.stream); err != nil { + var t2dErr error + if bf16Path { + t2dErr = bf16Transposer.Transpose2DBF16(devIn, devOut, shape[0], shape[1], e.stream) + } else { + t2dErr = e.kernels.Transpose2D(devIn, devOut, shape[0], shape[1], e.stream) + } + if t2dErr != nil { if !reused { e.pool.Free(e.deviceID, devOut, byteSize) } - return nil, err + return nil, t2dErr } if reused { return finishReusedDst[T](dst[0], outShape), nil @@ -208,11 +230,17 @@ func (e *GPUEngine[T]) Transpose(ctx context.Context, a *tensor.TensorNumeric[T] devOS := unsafe.Slice((*int32)(unsafe.Add(devParams, 4*rank)), rank) devPerm := unsafe.Slice((*int32)(unsafe.Add(devParams, 8*rank)), rank) - if err := e.kernels.TransposeND(devIn, devOut, devIS, devOS, devPerm, rank, total, e.stream); err != nil { + var ndErr error + if bf16Path { + ndErr = bf16Transposer.TransposeNDBF16(devIn, devOut, devIS, devOS, devPerm, rank, total, e.stream) + } else { + ndErr = e.kernels.TransposeND(devIn, devOut, devIS, devOS, devPerm, rank, total, e.stream) + } + if ndErr != nil { if !reused { e.pool.Free(e.deviceID, devOut, byteSize) } - return nil, err + return nil, ndErr } if reused { diff --git a/internal/cuda/kernels/purego.go b/internal/cuda/kernels/purego.go index cbba14c..f2a9ebc 100644 --- a/internal/cuda/kernels/purego.go +++ b/internal/cuda/kernels/purego.go @@ -44,7 +44,8 @@ type KernelLib struct { launchGatherQ8F32 uintptr // transpose - launchTranspose2D, launchTransposeND uintptr + launchTranspose2D, launchTransposeND uintptr + launchTranspose2DBF16, launchTransposeNDBF16 uintptr // repeat launchRepeat uintptr @@ -284,6 +285,8 @@ func openKernelLib() (*KernelLib, error) { // transpose {"launch_transpose_2d", 
&k.launchTranspose2D}, {"launch_transpose_nd", &k.launchTransposeND}, + {"launch_transpose_2d_bf16", &k.launchTranspose2DBF16}, + {"launch_transpose_nd_bf16", &k.launchTransposeNDBF16}, // repeat {"launch_repeat", &k.launchRepeat}, // gemm_q4 diff --git a/internal/cuda/kernels/transpose.cu b/internal/cuda/kernels/transpose.cu index 1921333..e98bf68 100644 --- a/internal/cuda/kernels/transpose.cu +++ b/internal/cuda/kernels/transpose.cu @@ -65,6 +65,58 @@ __global__ void kernel_transpose_nd(const float* __restrict__ input, output[idx] = input[src_idx]; } +// ---------- bf16 transpose ---------- +// Transpose is pure data movement (no arithmetic), so the bf16 variants operate +// on 16-bit elements via `unsigned short` -- a bitwise copy independent of the +// bf16 numeric interpretation. This keeps bf16 transposes on-device so they are +// CUDA-graph-capturable (the f32-only kernels forced bf16 onto the CPU engine, +// whose host memcpy breaks stream capture). ADR-075 lever L4. + +__global__ void kernel_transpose_2d_bf16(const unsigned short* __restrict__ input, + unsigned short* __restrict__ output, + int rows, int cols) { + __shared__ unsigned short tile[TILE_DIM][TILE_DIM + 1]; + + int xIdx = blockIdx.x * TILE_DIM + threadIdx.x; + int yIdx = blockIdx.y * TILE_DIM + threadIdx.y; + + for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) { + if ((yIdx + j) < rows && xIdx < cols) { + tile[threadIdx.y + j][threadIdx.x] = input[(yIdx + j) * cols + xIdx]; + } + } + __syncthreads(); + + int outX = blockIdx.y * TILE_DIM + threadIdx.x; + int outY = blockIdx.x * TILE_DIM + threadIdx.y; + + for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) { + if ((outY + j) < cols && outX < rows) { + output[(outY + j) * rows + outX] = tile[threadIdx.x][threadIdx.y + j]; + } + } +} + +__global__ void kernel_transpose_nd_bf16(const unsigned short* __restrict__ input, + unsigned short* __restrict__ output, + const int* __restrict__ in_strides, + const int* __restrict__ out_strides, + const int* 
__restrict__ perm, + int ndim, int total) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total) return; + + int remaining = idx; + int src_idx = 0; + for (int d = 0; d < ndim; d++) { + int coord = remaining / out_strides[d]; + remaining = remaining % out_strides[d]; + src_idx += coord * in_strides[perm[d]]; + } + + output[idx] = input[src_idx]; +} + // ---------- Launcher functions (extern "C" for CGO) ---------- extern "C" { @@ -89,4 +141,24 @@ cudaError_t launch_transpose_nd(const float* input, float* output, return cudaGetLastError(); } +cudaError_t launch_transpose_2d_bf16(const unsigned short* input, unsigned short* output, + int rows, int cols, cudaStream_t stream) { + dim3 grid((cols + TILE_DIM - 1) / TILE_DIM, + (rows + TILE_DIM - 1) / TILE_DIM); + dim3 block(TILE_DIM, BLOCK_ROWS); + kernel_transpose_2d_bf16<<<grid, block, 0, stream>>>(input, output, rows, cols); + return cudaGetLastError(); +} + +cudaError_t launch_transpose_nd_bf16(const unsigned short* input, unsigned short* output, + const int* in_strides, const int* out_strides, + const int* perm, int ndim, int total, + cudaStream_t stream) { + int block = 256; + int grid = (total + block - 1) / block; + kernel_transpose_nd_bf16<<<grid, block, 0, stream>>>(input, output, in_strides, + out_strides, perm, ndim, total); + return cudaGetLastError(); +} + } // extern "C" diff --git a/internal/cuda/kernels/transpose.go b/internal/cuda/kernels/transpose.go index 95f3d77..52cb8dd 100644 --- a/internal/cuda/kernels/transpose.go +++ b/internal/cuda/kernels/transpose.go @@ -12,6 +12,12 @@ extern cudaError_t launch_transpose_nd(const float* input, float* output, const int* in_strides, const int* out_strides, const int* perm, int ndim, int total, cudaStream_t stream); +extern cudaError_t launch_transpose_2d_bf16(const unsigned short* input, unsigned short* output, + int rows, int cols, cudaStream_t stream); +extern cudaError_t launch_transpose_nd_bf16(const unsigned short* input, unsigned short* output, + const int* in_strides, const int* 
out_strides, + const int* perm, int ndim, int total, + cudaStream_t stream); */ import "C" @@ -43,3 +49,26 @@ func TransposeND(input, output unsafe.Pointer, C.int(ndim), C.int(total), stream(s), ), "transpose_nd") } + +// Transpose2DBF16 launches the tiled 2D transpose kernel for bf16 (16-bit) +// elements. Input: [rows, cols] -> Output: [cols, rows]. +func Transpose2DBF16(input, output unsafe.Pointer, rows, cols int, s unsafe.Pointer) error { + return checkCUDA(C.launch_transpose_2d_bf16( + (*C.ushort)(input), (*C.ushort)(output), + C.int(rows), C.int(cols), stream(s), + ), "transpose_2d_bf16") +} + +// TransposeNDBF16 launches the general N-D transpose kernel for bf16 (16-bit) +// elements. +func TransposeNDBF16(input, output unsafe.Pointer, + inStrides, outStrides, perm []int32, + ndim, total int, s unsafe.Pointer) error { + return checkCUDA(C.launch_transpose_nd_bf16( + (*C.ushort)(input), (*C.ushort)(output), + (*C.int)(unsafe.Pointer(&inStrides[0])), + (*C.int)(unsafe.Pointer(&outStrides[0])), + (*C.int)(unsafe.Pointer(&perm[0])), + C.int(ndim), C.int(total), stream(s), + ), "transpose_nd_bf16") +} diff --git a/internal/cuda/kernels/transpose_purego.go b/internal/cuda/kernels/transpose_purego.go index 6a2b24a..4f592f3 100644 --- a/internal/cuda/kernels/transpose_purego.go +++ b/internal/cuda/kernels/transpose_purego.go @@ -38,3 +38,34 @@ func TransposeND(input, output unsafe.Pointer, uintptr(ndim), uintptr(total), uintptr(s)) return checkKernel(ret, "transpose_nd") } + +// Transpose2DBF16 launches the tiled 2D transpose kernel for bf16 (16-bit) +// elements. Input: [rows, cols] -> Output: [cols, rows]. 
+func Transpose2DBF16(input, output unsafe.Pointer, rows, cols int, s unsafe.Pointer) error { + k := klib() + if k == nil { + return fmt.Errorf("transpose_2d_bf16 kernel: kernels not available") + } + ret := cuda.Ccall(k.launchTranspose2DBF16, + uintptr(input), uintptr(output), + uintptr(rows), uintptr(cols), uintptr(s)) + return checkKernel(ret, "transpose_2d_bf16") +} + +// TransposeNDBF16 launches the general N-D transpose kernel for bf16 (16-bit) +// elements. +func TransposeNDBF16(input, output unsafe.Pointer, + inStrides, outStrides, perm []int32, + ndim, total int, s unsafe.Pointer) error { + k := klib() + if k == nil { + return fmt.Errorf("transpose_nd_bf16 kernel: kernels not available") + } + ret := cuda.Ccall(k.launchTransposeNDBF16, + uintptr(input), uintptr(output), + uintptr(unsafe.Pointer(&inStrides[0])), + uintptr(unsafe.Pointer(&outStrides[0])), + uintptr(unsafe.Pointer(&perm[0])), + uintptr(ndim), uintptr(total), uintptr(s)) + return checkKernel(ret, "transpose_nd_bf16") +} diff --git a/internal/gpuapi/cuda_kernels.go b/internal/gpuapi/cuda_kernels.go index 6a0dbcf..0060428 100644 --- a/internal/gpuapi/cuda_kernels.go +++ b/internal/gpuapi/cuda_kernels.go @@ -190,6 +190,14 @@ func (k *CUDAKernels) TransposeND(input, output unsafe.Pointer, inStrides, outSt return kernels.TransposeND(input, output, inStrides, outStrides, perm, ndim, total, streamPtr(s)) } +func (k *CUDAKernels) Transpose2DBF16(input, output unsafe.Pointer, rows, cols int, s Stream) error { + return kernels.Transpose2DBF16(input, output, rows, cols, streamPtr(s)) +} + +func (k *CUDAKernels) TransposeNDBF16(input, output unsafe.Pointer, inStrides, outStrides, perm []int32, ndim, total int, s Stream) error { + return kernels.TransposeNDBF16(input, output, inStrides, outStrides, perm, ndim, total, streamPtr(s)) +} + func (k *CUDAKernels) Gather(table, indices, output unsafe.Pointer, N, D, V int, s Stream) error { //nolint:gocritic // interface match return kernels.Gather(table, indices, 
output, N, D, V, streamPtr(s)) } @@ -412,3 +420,4 @@ func (k *CUDAKernels) FusedEncoderFwdAvailable() bool { // Compile-time interface assertion. var _ KernelRunner = (*CUDAKernels)(nil) +var _ BFloat16Transposer = (*CUDAKernels)(nil) diff --git a/internal/gpuapi/kernels.go b/internal/gpuapi/kernels.go index a76789a..b8f4c1a 100644 --- a/internal/gpuapi/kernels.go +++ b/internal/gpuapi/kernels.go @@ -285,3 +285,15 @@ type KernelRunner interface { // FusedEncoderFwdAvailable returns true if the fused encoder kernel is loaded. FusedEncoderFwdAvailable() bool } + +// BFloat16Transposer is an optional KernelRunner extension providing on-device +// bf16 (16-bit) transpose kernels. Without it, a GPU engine over bf16 must route +// transposes to the CPU engine, whose host memcpy breaks CUDA-graph capture +// (e.g. QKL2Norm's transpose). Only the CUDA backend implements it; callers +// type-assert and fall back to the CPU transpose when it is absent. ADR-075 L4. +type BFloat16Transposer interface { + // Transpose2DBF16 transposes a [rows, cols] bf16 matrix to [cols, rows]. + Transpose2DBF16(input, output unsafe.Pointer, rows, cols int, stream Stream) error + // TransposeNDBF16 permutes dimensions of an N-D bf16 tensor. + TransposeNDBF16(input, output unsafe.Pointer, inStrides, outStrides, perm []int32, ndim, total int, stream Stream) error +}