diff --git a/compute/gpu_bf16_transpose_gemm_test.go b/compute/gpu_bf16_transpose_gemm_test.go
index 0c6d47f..a4ca7b2 100644
--- a/compute/gpu_bf16_transpose_gemm_test.go
+++ b/compute/gpu_bf16_transpose_gemm_test.go
@@ -18,6 +18,7 @@ import (
 	"testing"
 
 	"github.com/zerfoo/float16"
+	"github.com/zerfoo/ztensor/tensor"
 )
 
 func TestGPUBF16_MatMulTransposeBParity(t *testing.T) {
@@ -193,6 +194,52 @@ func TestGPUBF16_TransposeParity(t *testing.T) {
 	})
 }
 
+// TestGPUBF16_ReshapeStaysOnDevice checks that reshaping a GPU-resident bf16
+// tensor returns a GPUStorage-backed result with the requested shape and
+// unchanged element data.
+func TestGPUBF16_ReshapeStaysOnDevice(t *testing.T) {
+	eng := newTestGPUBF16Engine(t)
+	ctx := context.Background()
+
+	// A GPU-resident bf16 tensor (output of a GPU op) reshaped must stay a
+	// GPUStorage view, not bounce to the CPU engine -- otherwise the next op
+	// (e.g. Transpose feeding QKL2Norm) is forced onto the CPU and breaks
+	// CUDA-graph capture.
+	vals := ramp(2 * 3 * 4)
+	x := bf16Tensor(t, []int{2, 3, 4}, vals)
+	// Mul produces a device-resident GPUStorage[bf16] result.
+	gpuRes, err := eng.Mul(ctx, x, x)
+	if err != nil {
+		t.Fatalf("Mul: %v", err)
+	}
+	if _, ok := gpuRes.GetStorage().(*tensor.GPUStorage[float16.BFloat16]); !ok {
+		t.Fatalf("precondition: Mul result storage = %T, want *GPUStorage[bf16]", gpuRes.GetStorage())
+	}
+
+	r, err := eng.Reshape(ctx, gpuRes, []int{6, 4}, nil)
+	if err != nil {
+		t.Fatalf("Reshape: %v", err)
+	}
+	if want := []int{6, 4}; !shapeEq(r.Shape(), want) {
+		t.Fatalf("Reshape shape = %v, want %v", r.Shape(), want)
+	}
+	if _, ok := r.GetStorage().(*tensor.GPUStorage[float16.BFloat16]); !ok {
+		t.Fatalf("Reshape result storage = %T, want *GPUStorage[bf16] (must stay on device)", r.GetStorage())
+	}
+	// Data preserved (x*x). Check the length first so a short result is
+	// reported as a test failure instead of an index-out-of-range panic.
+	rd := bf16ToF32(r.Data())
+	if len(rd) != len(vals) {
+		t.Fatalf("Reshape data len = %d, want %d", len(rd), len(vals))
+	}
+	for i := range vals {
+		want := float16.BFloat16FromFloat32(vals[i] * vals[i]).ToFloat32()
+		if rd[i] != want {
+			t.Fatalf("Reshape[%d] = %g, want %g", i, rd[i], want)
+		}
+	}
+}
+
 // ramp returns n small, bf16-exact values centered near zero so GEMM sums stay
 // well-conditioned (no catastrophic cancellation).
 func ramp(n int) []float32 {
diff --git a/compute/gpu_engine_memory.go b/compute/gpu_engine_memory.go
index 42c62af..14e5aee 100644
--- a/compute/gpu_engine_memory.go
+++ b/compute/gpu_engine_memory.go
@@ -712,8 +712,12 @@ func (e *GPUEngine[T]) Reshape(ctx context.Context, a *tensor.TensorNumeric[T],
 		}
 	}
 
-	// GPUStorage[T]: zero-copy reshape.
-	if gs, ok := a.GetStorage().(*tensor.GPUStorage[T]); ok && isFloat32[T]() && newSize == currentSize {
+	// GPUStorage[T]: zero-copy reshape. Reshape is a pure metadata/view operation
+	// (no data movement), so it is valid for any element type whose data lives in
+	// a GPUStorage[T]. bf16 must stay on this GPU view path: routing it to the CPU
+	// reshape produces a host tensor that then forces the next op (e.g. Transpose
+	// feeding QKL2Norm) onto the CPU engine, breaking CUDA-graph capture. ADR-075 L4.
+	if gs, ok := a.GetStorage().(*tensor.GPUStorage[T]); ok && (isFloat32[T]() || isBFloat16[T]()) && newSize == currentSize {
 		view := gs.View(gs.Len())
 		if len(dst) > 0 && dst[0] != nil {
 			aliasReshapeDst(dst[0], inferredShape, view)