Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions compute/gpu_bf16_transpose_gemm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,65 @@ func TestGPUBF16_MatMulTransposeBBatchedParity(t *testing.T) {
}
}

// TestGPUBF16_TransposeParity verifies that Transpose on the bf16 test engine
// yields the exact permuted values for a 2D axis swap and a 3D [0,2,1]
// permutation.
//
// A transpose only relocates elements (no arithmetic is performed), so every
// output element is compared with the reference for strict equality rather
// than within a tolerance. The intent is to cover the native bf16 transpose
// kernels (Transpose2DBF16 / TransposeNDBF16) that keep the operation on the
// device instead of the CPU fallback.
//
// NOTE(review): the engine falls back to the CPU transpose when the input is
// not GPU-resident, so this only reaches the GPU kernels if bf16Tensor builds
// GPU-backed storage -- confirm against that helper.
func TestGPUBF16_TransposeParity(t *testing.T) {
	ctx := context.Background()
	eng := newTestGPUBF16Engine(t)

	// 2D case: a [rows, cols] input permuted with axes [1,0] becomes [cols, rows].
	t.Run("2D", func(t *testing.T) {
		const rows, cols = 5, 7
		src := ramp(rows * cols)
		in := bf16Tensor(t, []int{rows, cols}, src)

		out, err := eng.Transpose(ctx, in, []int{1, 0})
		if err != nil {
			t.Fatalf("Transpose 2D: %v", err)
		}
		wantShape := []int{cols, rows}
		if !shapeEq(out.Shape(), wantShape) {
			t.Fatalf("2D transpose shape = %v, want %v", out.Shape(), wantShape)
		}

		flat := bf16ToF32(out.Data())
		// Walk the output in row-major order; k is the linear offset of
		// output[i,j], which must hold input[j,i].
		for k := 0; k < cols*rows; k++ {
			i, j := k/rows, k%rows
			if want := src[j*cols+i]; flat[k] != want {
				t.Fatalf("2D[%d,%d] = %g, want %g", i, j, flat[k], want)
			}
		}
	})

	// 3D case: [d0,d1,d2] permuted with axes [0,2,1] becomes [d0,d2,d1]
	// (the QKL2Norm-style transpose).
	t.Run("3D_021", func(t *testing.T) {
		const d0, d1, d2 = 4, 12, 64
		src := ramp(d0 * d1 * d2)
		in := bf16Tensor(t, []int{d0, d1, d2}, src)

		out, err := eng.Transpose(ctx, in, []int{0, 2, 1})
		if err != nil {
			t.Fatalf("Transpose 3D: %v", err)
		}
		wantShape := []int{d0, d2, d1}
		if !shapeEq(out.Shape(), wantShape) {
			t.Fatalf("3D transpose shape = %v, want %v", out.Shape(), wantShape)
		}

		flat := bf16ToF32(out.Data())
		// k is the row-major offset of output[a,c,b]; the matching source
		// element is input[a,b,c].
		for k := 0; k < d0*d2*d1; k++ {
			a := k / (d2 * d1)
			c := (k / d1) % d2
			b := k % d1
			want := src[a*d1*d2+b*d2+c]
			if flat[k] != want {
				t.Fatalf("3D[%d,%d,%d] = %g, want %g", a, c, b, flat[k], want)
			}
		}
	})
}

// ramp returns n small, bf16-exact values centered near zero so GEMM sums stay
// well-conditioned (no catastrophic cancellation).
func ramp(n int) []float32 {
Expand Down
40 changes: 34 additions & 6 deletions compute/gpu_engine_memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,20 @@ import (

// Transpose transposes a tensor along the given axes.
func (e *GPUEngine[T]) Transpose(ctx context.Context, a *tensor.TensorNumeric[T], axes []int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) {
// bf16 stays on-device when the kernel backend provides native bf16 transpose
// kernels AND the input is GPU-resident; otherwise (and for every non-f32,
// non-bf16 type) fall back to the CPU engine. Keeping bf16 transposes on the
// device is required for CUDA-graph capture: the CPU fallback's host memcpy
// breaks stream capture (e.g. QKL2Norm's transpose). ADR-075 lever L4.
var bf16Transposer gpuapi.BFloat16Transposer
bf16Path := false
if !isFloat32[T]() {
return e.cpu.Transpose(ctx, a, axes, dst...)
bt, ok := e.kernels.(gpuapi.BFloat16Transposer)
_, isGPUStore := a.GetStorage().(*tensor.GPUStorage[T])
if !isBFloat16[T]() || !ok || !isGPUStore {
return e.cpu.Transpose(ctx, a, axes, dst...)
}
bf16Transposer, bf16Path = bt, true
}

// Only use GPU path for GPU-resident tensors (Phase 6 behavior).
Expand Down Expand Up @@ -131,7 +143,11 @@ func (e *GPUEngine[T]) Transpose(ctx context.Context, a *tensor.TensorNumeric[T]
fmt.Fprintf(os.Stderr, "TRANSPOSE getDevicePtr OK: ptr=%p\n", devIn)
}

byteSize := total * f32Size
elemSize := f32Size
if bf16Path {
elemSize = 2 // bf16 is 16-bit
}
byteSize := total * elemSize

// Reuse dst's existing GPU memory when possible (#84).
devOut, reused := tryReuseDstPtr[T](total, dst)
Expand All @@ -149,11 +165,17 @@ func (e *GPUEngine[T]) Transpose(ctx context.Context, a *tensor.TensorNumeric[T]
"rows", fmt.Sprintf("%d", shape[0]),
"cols", fmt.Sprintf("%d", shape[1]))
}
if err := e.kernels.Transpose2D(devIn, devOut, shape[0], shape[1], e.stream); err != nil {
var t2dErr error
if bf16Path {
t2dErr = bf16Transposer.Transpose2DBF16(devIn, devOut, shape[0], shape[1], e.stream)
} else {
t2dErr = e.kernels.Transpose2D(devIn, devOut, shape[0], shape[1], e.stream)
}
if t2dErr != nil {
if !reused {
e.pool.Free(e.deviceID, devOut, byteSize)
}
return nil, err
return nil, t2dErr
}
if reused {
return finishReusedDst[T](dst[0], outShape), nil
Expand Down Expand Up @@ -208,11 +230,17 @@ func (e *GPUEngine[T]) Transpose(ctx context.Context, a *tensor.TensorNumeric[T]
devOS := unsafe.Slice((*int32)(unsafe.Add(devParams, 4*rank)), rank)
devPerm := unsafe.Slice((*int32)(unsafe.Add(devParams, 8*rank)), rank)

if err := e.kernels.TransposeND(devIn, devOut, devIS, devOS, devPerm, rank, total, e.stream); err != nil {
var ndErr error
if bf16Path {
ndErr = bf16Transposer.TransposeNDBF16(devIn, devOut, devIS, devOS, devPerm, rank, total, e.stream)
} else {
ndErr = e.kernels.TransposeND(devIn, devOut, devIS, devOS, devPerm, rank, total, e.stream)
}
if ndErr != nil {
if !reused {
e.pool.Free(e.deviceID, devOut, byteSize)
}
return nil, err
return nil, ndErr
}

if reused {
Expand Down
5 changes: 4 additions & 1 deletion internal/cuda/kernels/purego.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ type KernelLib struct {
launchGatherQ8F32 uintptr

// transpose
launchTranspose2D, launchTransposeND uintptr
launchTranspose2D, launchTransposeND uintptr
launchTranspose2DBF16, launchTransposeNDBF16 uintptr

// repeat
launchRepeat uintptr
Expand Down Expand Up @@ -284,6 +285,8 @@ func openKernelLib() (*KernelLib, error) {
// transpose
{"launch_transpose_2d", &k.launchTranspose2D},
{"launch_transpose_nd", &k.launchTransposeND},
{"launch_transpose_2d_bf16", &k.launchTranspose2DBF16},
{"launch_transpose_nd_bf16", &k.launchTransposeNDBF16},
// repeat
{"launch_repeat", &k.launchRepeat},
// gemm_q4
Expand Down
72 changes: 72 additions & 0 deletions internal/cuda/kernels/transpose.cu
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,58 @@ __global__ void kernel_transpose_nd(const float* __restrict__ input,
output[idx] = input[src_idx];
}

// ---------- bf16 transpose ----------
// Transpose is pure data movement (no arithmetic), so the bf16 variants operate
// on 16-bit elements via `unsigned short` -- a bitwise copy independent of the
// bf16 numeric interpretation. This keeps bf16 transposes on-device so they are
// CUDA-graph-capturable (the f32-only kernels forced bf16 onto the CPU engine,
// whose host memcpy breaks stream capture). ADR-075 lever L4.

// kernel_transpose_2d_bf16 transposes a row-major [rows, cols] matrix of 16-bit
// elements into a row-major [cols, rows] matrix, staging each block's elements
// through a shared-memory tile.
//
//   input  - source matrix, indexed as input[r * cols + c]
//   output - destination matrix, indexed as output[c * rows + r]
//   rows   - number of rows in the input
//   cols   - number of columns in the input
//
// Elements are copied as raw `unsigned short` values; no arithmetic is applied.
// Each thread handles TILE_DIM / BLOCK_ROWS tile rows, stepping by BLOCK_ROWS.
// NOTE(review): TILE_DIM and BLOCK_ROWS are defined outside this view; the
// loops assume the block is launched as (TILE_DIM, BLOCK_ROWS) threads, as the
// launcher in this file does.
__global__ void kernel_transpose_2d_bf16(const unsigned short* __restrict__ input,
unsigned short* __restrict__ output,
int rows, int cols) {
// One extra column of padding per tile row -- presumably to avoid
// shared-memory bank conflicts on the transposed read below (TODO confirm).
__shared__ unsigned short tile[TILE_DIM][TILE_DIM + 1];

// Input coordinates: blockIdx.x selects the tile column, blockIdx.y the tile row.
int xIdx = blockIdx.x * TILE_DIM + threadIdx.x;
int yIdx = blockIdx.y * TILE_DIM + threadIdx.y;

// Load phase: copy the in-bounds part of the input tile into shared memory.
for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
if ((yIdx + j) < rows && xIdx < cols) {
tile[threadIdx.y + j][threadIdx.x] = input[(yIdx + j) * cols + xIdx];
}
}
// All loads must be complete before any thread reads the tile transposed.
__syncthreads();

// Output coordinates: the block indices swap roles so the tile lands at the
// mirrored position in the [cols, rows] output.
int outX = blockIdx.y * TILE_DIM + threadIdx.x;
int outY = blockIdx.x * TILE_DIM + threadIdx.y;

// Store phase: read the tile with its two indices swapped and write the
// in-bounds elements to the output.
for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
if ((outY + j) < cols && outX < rows) {
output[(outY + j) * rows + outX] = tile[threadIdx.x][threadIdx.y + j];
}
}
}

// kernel_transpose_nd_bf16 permutes the dimensions of an N-D tensor of 16-bit
// elements, one output element per thread.
//
//   input       - source elements
//   output      - destination elements, written at the thread's linear index
//   in_strides  - element strides of the input, indexed by input dimension
//   out_strides - element strides of the output, indexed by output dimension
//   perm        - perm[d] is the input dimension that feeds output dimension d
//   ndim        - number of dimensions
//   total       - number of output elements; threads past it do nothing
//
// Elements are copied as raw `unsigned short` values; no arithmetic is applied.
__global__ void kernel_transpose_nd_bf16(const unsigned short* __restrict__ input,
                                         unsigned short* __restrict__ output,
                                         const int* __restrict__ in_strides,
                                         const int* __restrict__ out_strides,
                                         const int* __restrict__ perm,
                                         int ndim, int total) {
    const int out_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (out_idx < total) {
        // Peel the output coordinates off the linear index one dimension at a
        // time and accumulate the matching source offset.
        int rest = out_idx;
        int in_idx = 0;
        for (int d = 0; d < ndim; ++d) {
            const int stride = out_strides[d];
            in_idx += (rest / stride) * in_strides[perm[d]];
            rest %= stride;
        }
        output[out_idx] = input[in_idx];
    }
}

// ---------- Launcher functions (extern "C" for CGO) ----------

extern "C" {
Expand All @@ -89,4 +141,24 @@ cudaError_t launch_transpose_nd(const float* input, float* output,
return cudaGetLastError();
}

// launch_transpose_2d_bf16 enqueues kernel_transpose_2d_bf16 on `stream` to
// transpose a [rows, cols] matrix of 16-bit elements into [cols, rows].
//
// The grid has one block per TILE_DIM x TILE_DIM tile (rounded up) and each
// block has TILE_DIM x BLOCK_ROWS threads. Returns the result of
// cudaGetLastError() after the launch.
//
// An empty matrix (rows <= 0 or cols <= 0) is a successful no-op: it would
// otherwise produce a grid dimension of 0, which CUDA rejects as an invalid
// launch configuration.
cudaError_t launch_transpose_2d_bf16(const unsigned short* input, unsigned short* output,
                                     int rows, int cols, cudaStream_t stream) {
    if (rows <= 0 || cols <= 0) {
        return cudaSuccess;  // nothing to move
    }
    dim3 grid((cols + TILE_DIM - 1) / TILE_DIM,
              (rows + TILE_DIM - 1) / TILE_DIM);
    dim3 block(TILE_DIM, BLOCK_ROWS);
    kernel_transpose_2d_bf16<<<grid, block, 0, stream>>>(input, output, rows, cols);
    return cudaGetLastError();
}

// launch_transpose_nd_bf16 enqueues kernel_transpose_nd_bf16 on `stream` with
// one thread per output element, in blocks of 256 threads.
//
// in_strides, out_strides and perm are forwarded unchanged to the kernel,
// which dereferences them on the device. Returns the result of
// cudaGetLastError() after the launch.
//
// An empty tensor (total <= 0) is a successful no-op: it would otherwise
// produce a grid size of 0, which CUDA rejects as an invalid launch
// configuration.
cudaError_t launch_transpose_nd_bf16(const unsigned short* input, unsigned short* output,
                                     const int* in_strides, const int* out_strides,
                                     const int* perm, int ndim, int total,
                                     cudaStream_t stream) {
    if (total <= 0) {
        return cudaSuccess;  // nothing to move
    }
    int block = 256;
    int grid = (total + block - 1) / block;
    kernel_transpose_nd_bf16<<<grid, block, 0, stream>>>(input, output, in_strides,
                                                         out_strides, perm, ndim, total);
    return cudaGetLastError();
}

} // extern "C"
29 changes: 29 additions & 0 deletions internal/cuda/kernels/transpose.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ extern cudaError_t launch_transpose_nd(const float* input, float* output,
const int* in_strides, const int* out_strides,
const int* perm, int ndim, int total,
cudaStream_t stream);
extern cudaError_t launch_transpose_2d_bf16(const unsigned short* input, unsigned short* output,
int rows, int cols, cudaStream_t stream);
extern cudaError_t launch_transpose_nd_bf16(const unsigned short* input, unsigned short* output,
const int* in_strides, const int* out_strides,
const int* perm, int ndim, int total,
cudaStream_t stream);
*/
import "C"

Expand Down Expand Up @@ -43,3 +49,26 @@ func TransposeND(input, output unsafe.Pointer,
C.int(ndim), C.int(total), stream(s),
), "transpose_nd")
}

// Transpose2DBF16 launches the tiled 2D transpose kernel for bf16 (16-bit)
// elements. Input: [rows, cols] -> Output: [cols, rows].
func Transpose2DBF16(input, output unsafe.Pointer, rows, cols int, s unsafe.Pointer) error {
return checkCUDA(C.launch_transpose_2d_bf16(
(*C.ushort)(input), (*C.ushort)(output),
C.int(rows), C.int(cols), stream(s),
), "transpose_2d_bf16")
}

// TransposeNDBF16 launches the general N-D transpose kernel for bf16 (16-bit)
// elements.
func TransposeNDBF16(input, output unsafe.Pointer,
inStrides, outStrides, perm []int32,
ndim, total int, s unsafe.Pointer) error {
return checkCUDA(C.launch_transpose_nd_bf16(
(*C.ushort)(input), (*C.ushort)(output),
(*C.int)(unsafe.Pointer(&inStrides[0])),
(*C.int)(unsafe.Pointer(&outStrides[0])),
(*C.int)(unsafe.Pointer(&perm[0])),
C.int(ndim), C.int(total), stream(s),
), "transpose_nd_bf16")
}
31 changes: 31 additions & 0 deletions internal/cuda/kernels/transpose_purego.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,34 @@ func TransposeND(input, output unsafe.Pointer,
uintptr(ndim), uintptr(total), uintptr(s))
return checkKernel(ret, "transpose_nd")
}

// Transpose2DBF16 runs the tiled 2D transpose kernel over bf16 (16-bit)
// elements through the dynamically loaded kernel library, turning a
// [rows, cols] input into a [cols, rows] output.
//
// It returns an error when the kernel library is unavailable (klib returns
// nil); otherwise the launcher's return code is translated by checkKernel
// under the label "transpose_2d_bf16".
func Transpose2DBF16(input, output unsafe.Pointer, rows, cols int, s unsafe.Pointer) error {
	lib := klib()
	if lib == nil {
		return fmt.Errorf("transpose_2d_bf16 kernel: kernels not available")
	}
	// The pointer-to-uintptr conversions are deliberately kept inside the call
	// expression rather than hoisted into locals.
	status := cuda.Ccall(
		lib.launchTranspose2DBF16,
		uintptr(input),
		uintptr(output),
		uintptr(rows),
		uintptr(cols),
		uintptr(s),
	)
	return checkKernel(status, "transpose_2d_bf16")
}

// TransposeNDBF16 runs the general N-D transpose kernel over bf16 (16-bit)
// elements through the dynamically loaded kernel library.
//
// It returns an error when the kernel library is unavailable (klib returns
// nil). Only the address of the first element of inStrides, outStrides and
// perm is passed to the launcher, so each slice must be non-empty (indexing
// element 0 of an empty slice panics). The launcher's return code is
// translated by checkKernel under the label "transpose_nd_bf16".
func TransposeNDBF16(input, output unsafe.Pointer,
	inStrides, outStrides, perm []int32,
	ndim, total int, s unsafe.Pointer) error {
	lib := klib()
	if lib == nil {
		return fmt.Errorf("transpose_nd_bf16 kernel: kernels not available")
	}
	// The pointer-to-uintptr conversions are deliberately kept inside the call
	// expression rather than hoisted into locals.
	status := cuda.Ccall(
		lib.launchTransposeNDBF16,
		uintptr(input),
		uintptr(output),
		uintptr(unsafe.Pointer(&inStrides[0])),
		uintptr(unsafe.Pointer(&outStrides[0])),
		uintptr(unsafe.Pointer(&perm[0])),
		uintptr(ndim),
		uintptr(total),
		uintptr(s),
	)
	return checkKernel(status, "transpose_nd_bf16")
}
9 changes: 9 additions & 0 deletions internal/gpuapi/cuda_kernels.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,14 @@ func (k *CUDAKernels) TransposeND(input, output unsafe.Pointer, inStrides, outSt
return kernels.TransposeND(input, output, inStrides, outStrides, perm, ndim, total, streamPtr(s))
}

// Transpose2DBF16 implements BFloat16Transposer by delegating to
// kernels.Transpose2DBF16, passing the stream in the raw form produced by
// streamPtr. All other arguments are forwarded unchanged.
func (k *CUDAKernels) Transpose2DBF16(input, output unsafe.Pointer, rows, cols int, s Stream) error {
	rawStream := streamPtr(s)
	return kernels.Transpose2DBF16(input, output, rows, cols, rawStream)
}

// TransposeNDBF16 implements BFloat16Transposer by delegating to
// kernels.TransposeNDBF16, passing the stream in the raw form produced by
// streamPtr. All other arguments are forwarded unchanged.
func (k *CUDAKernels) TransposeNDBF16(input, output unsafe.Pointer, inStrides, outStrides, perm []int32, ndim, total int, s Stream) error {
	rawStream := streamPtr(s)
	return kernels.TransposeNDBF16(
		input, output,
		inStrides, outStrides, perm,
		ndim, total,
		rawStream,
	)
}

func (k *CUDAKernels) Gather(table, indices, output unsafe.Pointer, N, D, V int, s Stream) error { //nolint:gocritic // interface match
return kernels.Gather(table, indices, output, N, D, V, streamPtr(s))
}
Expand Down Expand Up @@ -412,3 +420,4 @@ func (k *CUDAKernels) FusedEncoderFwdAvailable() bool {

// Compile-time interface assertion.
var _ KernelRunner = (*CUDAKernels)(nil)
var _ BFloat16Transposer = (*CUDAKernels)(nil)
12 changes: 12 additions & 0 deletions internal/gpuapi/kernels.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,3 +285,15 @@ type KernelRunner interface {
// FusedEncoderFwdAvailable returns true if the fused encoder kernel is loaded.
FusedEncoderFwdAvailable() bool
}

// BFloat16Transposer is an optional KernelRunner extension providing on-device
// bf16 (16-bit) transpose kernels. Without it, a GPU engine over bf16 must route
// transposes to the CPU engine, whose host memcpy breaks CUDA-graph capture
// (e.g. QKL2Norm's transpose). Only the CUDA backend implements it; callers
// type-assert and fall back to the CPU transpose when it is absent. ADR-075 L4.
//
// Both methods treat elements as opaque 16-bit values: a transpose only moves
// data, so no bf16 arithmetic is involved. input and output are presumably
// device pointers to the source and destination buffers (TODO confirm against
// the backend); stream is the stream the kernel is enqueued on.
type BFloat16Transposer interface {
// Transpose2DBF16 transposes a [rows, cols] bf16 matrix to [cols, rows].
// rows and cols describe the input; the output has the two swapped.
Transpose2DBF16(input, output unsafe.Pointer, rows, cols int, stream Stream) error
// TransposeNDBF16 permutes dimensions of an N-D bf16 tensor.
// inStrides and outStrides are the element strides of the input and output,
// perm[d] names the input dimension that feeds output dimension d, ndim is
// the number of dimensions, and total is the number of elements to write.
TransposeNDBF16(input, output unsafe.Pointer, inStrides, outStrides, perm []int32, ndim, total int, stream Stream) error
}
Loading
Loading