Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions compute/gpu_bf16_transpose_gemm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"testing"

"github.com/zerfoo/float16"
"github.com/zerfoo/ztensor/tensor"
)

func TestGPUBF16_MatMulTransposeBParity(t *testing.T) {
Expand Down Expand Up @@ -193,6 +194,45 @@ func TestGPUBF16_TransposeParity(t *testing.T) {
})
}

func TestGPUBF16_ReshapeStaysOnDevice(t *testing.T) {
eng := newTestGPUBF16Engine(t)
ctx := context.Background()

// A GPU-resident bf16 tensor (output of a GPU op) reshaped must stay a
// GPUStorage view, not bounce to the CPU engine -- otherwise the next op
// (e.g. Transpose feeding QKL2Norm) is forced onto the CPU and breaks
// CUDA-graph capture.
vals := ramp(2 * 3 * 4)
x := bf16Tensor(t, []int{2, 3, 4}, vals)
// Mul produces a device-resident GPUStorage[bf16] result.
gpuRes, err := eng.Mul(ctx, x, x)
if err != nil {
t.Fatalf("Mul: %v", err)
}
if _, ok := gpuRes.GetStorage().(*tensor.GPUStorage[float16.BFloat16]); !ok {
t.Fatalf("precondition: Mul result storage = %T, want *GPUStorage[bf16]", gpuRes.GetStorage())
}

r, err := eng.Reshape(ctx, gpuRes, []int{6, 4}, nil)
if err != nil {
t.Fatalf("Reshape: %v", err)
}
if want := []int{6, 4}; !shapeEq(r.Shape(), want) {
t.Fatalf("Reshape shape = %v, want %v", r.Shape(), want)
}
if _, ok := r.GetStorage().(*tensor.GPUStorage[float16.BFloat16]); !ok {
t.Fatalf("Reshape result storage = %T, want *GPUStorage[bf16] (must stay on device)", r.GetStorage())
}
// Data preserved (x*x).
rd := bf16ToF32(r.Data())
for i := range vals {
want := float16.BFloat16FromFloat32(vals[i] * vals[i]).ToFloat32()
if rd[i] != want {
t.Fatalf("Reshape[%d] = %g, want %g", i, rd[i], want)
}
}
}

// ramp returns n small, bf16-exact values centered near zero so GEMM sums stay
// well-conditioned (no catastrophic cancellation).
func ramp(n int) []float32 {
Expand Down
8 changes: 6 additions & 2 deletions compute/gpu_engine_memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -712,8 +712,12 @@ func (e *GPUEngine[T]) Reshape(ctx context.Context, a *tensor.TensorNumeric[T],
}
}

// GPUStorage[T]: zero-copy reshape.
if gs, ok := a.GetStorage().(*tensor.GPUStorage[T]); ok && isFloat32[T]() && newSize == currentSize {
// GPUStorage[T]: zero-copy reshape. Reshape is a pure metadata/view operation
// (no data movement), so it is valid for any element type whose data lives in
// a GPUStorage[T]. bf16 must stay on this GPU view path: routing it to the CPU
// reshape produces a host tensor that then forces the next op (e.g. Transpose
// feeding QKL2Norm) onto the CPU engine, breaking CUDA-graph capture. ADR-075 L4.
if gs, ok := a.GetStorage().(*tensor.GPUStorage[T]); ok && (isFloat32[T]() || isBFloat16[T]()) && newSize == currentSize {
view := gs.View(gs.Len())
if len(dst) > 0 && dst[0] != nil {
aliasReshapeDst(dst[0], inferredShape, view)
Expand Down
Loading