diff --git a/.github/workflows/linux_cuda_ci.yml b/.github/workflows/linux_cuda_ci.yml index f50c0064dd956..89dcf718e8bf5 100644 --- a/.github/workflows/linux_cuda_ci.yml +++ b/.github/workflows/linux_cuda_ci.yml @@ -27,9 +27,9 @@ jobs: build_config: Release architecture: x64 dockerfile_path: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda - docker_build_args: '--build-arg BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc14:20251017.1' - docker_image_repo: onnxruntimecuda12manylinuxbuild - extra_build_flags: '--use_binskim_compliant_compile_flags --build_wheel --parallel --nvcc_threads 4 --flash_nvcc_threads 4 --cuda_version=12.8 --cuda_home=/usr/local/cuda-12.8 --cudnn_home=/usr/local/cuda-12.8 --enable_cuda_profiling --build_java --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 onnxruntime_BUILD_UNIT_TESTS=ON onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON' + docker_build_args: '--build-arg BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda13_x64_almalinux8_gcc14:20251107.1' + docker_image_repo: onnxruntimecuda13manylinuxbuild + extra_build_flags: '--use_binskim_compliant_compile_flags --build_wheel --parallel --nvcc_threads 4 --flash_nvcc_threads 4 --cuda_version=13.0 --cuda_home=/usr/local/cuda-13.0 --cudnn_home=/usr/local/cuda-13.0 --enable_cuda_profiling --build_java --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 onnxruntime_BUILD_UNIT_TESTS=ON onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON' python_path_prefix: 'PATH=/opt/python/cp310-cp310/bin:$PATH' run_tests: false # <<< Do not run tests in this job upload_build_output: true # <<< Upload the build/Release directory @@ -57,8 +57,8 @@ jobs: id: build_docker_image_step with: dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda - image-name: ghcr.io/microsoft/onnxruntime/onnxruntimecuda12manylinuxbuild - build-args: '--build-arg BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc14:20251017.1' + image-name: ghcr.io/microsoft/onnxruntime/onnxruntimecuda13manylinuxbuild + build-args: '--build-arg BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda13_x64_almalinux8_gcc14:20251107.1' push: true azure-container-registry-name: onnxruntimebuildcache env: @@ -91,6 +91,15 @@ jobs: echo "Warning: perms.txt not found in artifact." fi + # Verify the GPU is accessible inside Docker before running the full test suite. + # If the NVIDIA Container Toolkit fails to expose /dev/nvidia* devices, + # tests will fail with "CUDA failure 100" and waste 10+ minutes. + - name: Verify GPU access in Docker + run: | + docker run --rm --gpus all \ + "${{ steps.build_docker_image_step.outputs.full-image-name }}" \ + nvidia-smi + # --- Run Tests using the downloaded build --- # The run-build-script-in-docker action mounts ${{ runner.temp }} to /onnxruntime_src/build # So build.py --build_dir build/Release inside the container correctly finds the artifacts. @@ -102,5 +111,5 @@ jobs: build_config: Release mode: 'test' # Set mode to test execution_providers: 'cuda' - extra_build_flags: '--use_binskim_compliant_compile_flags --cuda_version=12.8 --cuda_home=/usr/local/cuda-12.8 --cudnn_home=/usr/local/cuda-12.8 --enable_cuda_profiling --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 onnxruntime_BUILD_UNIT_TESTS=ON onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON' + extra_build_flags: '--use_binskim_compliant_compile_flags --cuda_version=13.0 --cuda_home=/usr/local/cuda-13.0 --cudnn_home=/usr/local/cuda-13.0 --enable_cuda_profiling --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 onnxruntime_BUILD_UNIT_TESTS=ON onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON' python_path_prefix: 'PATH=/opt/python/cp310-cp310/bin:$PATH' diff --git a/.github/workflows/linux_cuda_plugin_ci.yml b/.github/workflows/linux_cuda_plugin_ci.yml index d2491f59812ab..a9197b3732dd8 100644 --- a/.github/workflows/linux_cuda_plugin_ci.yml +++ b/.github/workflows/linux_cuda_plugin_ci.yml @@ -26,17 +26,17 @@ jobs: build_config: Release architecture: x64 dockerfile_path: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda - docker_build_args: '--build-arg BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc14:20251017.1' - docker_image_repo: onnxruntimecuda12manylinuxbuild + docker_build_args: '--build-arg BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda13_x64_almalinux8_gcc14:20251107.1' + docker_image_repo: onnxruntimecuda13manylinuxbuild extra_build_flags: >- --use_binskim_compliant_compile_flags --build_wheel --parallel --nvcc_threads 4 --flash_nvcc_threads 4 - --cuda_version=12.8 - --cuda_home=/usr/local/cuda-12.8 - --cudnn_home=/usr/local/cuda-12.8 + --cuda_version=13.0 + --cuda_home=/usr/local/cuda-13.0 + --cudnn_home=/usr/local/cuda-13.0 --enable_cuda_profiling --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 --cmake_extra_defines onnxruntime_QUICK_BUILD=ON @@ -67,8 +67,8 @@ jobs: id: build_docker_image_step with: dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda - image-name: ghcr.io/microsoft/onnxruntime/onnxruntimecuda12manylinuxbuild - build-args: '--build-arg BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc14:20251017.1' + image-name: ghcr.io/microsoft/onnxruntime/onnxruntimecuda13manylinuxbuild + build-args: '--build-arg BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda13_x64_almalinux8_gcc14:20251107.1' push: true azure-container-registry-name: onnxruntimebuildcache env: @@ -100,6 +100,15 @@ jobs: echo "Warning: perms.txt not found in artifact." fi + # Verify the GPU is accessible inside Docker before running the full test suite. + # If the NVIDIA Container Toolkit fails to expose /dev/nvidia* devices, + # tests will fail with "CUDA failure 100" and waste 10+ minutes. + - name: Verify GPU access in Docker + run: | + docker run --rm --gpus all \ + "${{ steps.build_docker_image_step.outputs.full-image-name }}" \ + nvidia-smi + # --- Install the ORT wheel and run CUDA plugin EP tests --- - name: Run CUDA Plugin EP Python Tests run: | @@ -111,6 +120,11 @@ jobs: bash -c " set -ex export PATH=/opt/python/cp312-cp312/bin:\$PATH + # Ensure libcudart.so.13 is findable regardless of host-runner NVIDIA Container Toolkit configuration. + # The CUDA runtime library lives in the container image at /usr/local/cuda-13.0/lib64, but the + # LD_LIBRARY_PATH may not include this path when the runner's NVIDIA toolkit only mounts driver + # libraries at /usr/local/nvidia/lib64. + export LD_LIBRARY_PATH=/usr/local/cuda-13.0/lib64:\${LD_LIBRARY_PATH:-} # Install the ORT wheel python -m pip install /build/Release/Release/dist/onnxruntime*.whl diff --git a/.github/workflows/nightly_webgpu.yml b/.github/workflows/nightly_webgpu.yml new file mode 100644 index 0000000000000..b3da29a2f0bd4 --- /dev/null +++ b/.github/workflows/nightly_webgpu.yml @@ -0,0 +1,77 @@ +name: Nightly ONNX Runtime WebGPU Builds + +on: + schedule: + - cron: '0 9 * * *' # Daily at 09:00 UTC + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + webgpu_shader_key_validation: + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Win2022-GPU-A10", + "JobId=webgpu_shader_validation-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] + timeout-minutes: 90 + env: + ALLOW_RELEASED_ONNX_OPSET_ONLY: "0" + ONNXRUNTIME_TEST_GPU_DEVICE_ID: "0" + steps: + - name: Checkout + uses: actions/checkout@v6 + with: + fetch-depth: 0 + submodules: none + + - name: Setup Python + uses: actions/setup-python@v6 + with: + python-version: "3.12" + architecture: x64 + + - name: Locate vcvarsall and Setup Env + uses: ./.github/actions/locate-vcvarsall-and-setup-env + with: + architecture: x64 + + - name: Install python modules + run: python -m pip install -r tools\ci_build\github\windows\python\requirements.txt + shell: cmd + working-directory: ${{ github.workspace }} + + - name: Setup Node.js + uses: actions/setup-node@v6 + with: + node-version: "24" + + - name: Build and Test + shell: pwsh + run: | + $env:ORT_WEBGPU_EP_SHADER_DUMP_FILE = "${{ github.workspace }}\RelWithDebInfo\RelWithDebInfo\shader_dump.log" + + python.exe ${{ github.workspace }}\tools\ci_build\build.py ` + --config RelWithDebInfo ` + --build_dir ${{ github.workspace }} ` + --use_binskim_compliant_compile_flags ` + --cmake_generator "Visual Studio 17 2022" ` + --build_shared_lib ` + --use_webgpu ` + --wgsl_template static ` + --cmake_extra_defines onnxruntime_BUILD_DAWN_SHARED_LIBRARY=ON ` + --update ` + --build --parallel ` + --test + + - name: Check log file + shell: cmd + run: | + dir ${{ github.workspace }}\RelWithDebInfo\RelWithDebInfo\shader_dump.log + + - name: Validate shader keys + uses: ./.github/actions/webgpu-validate-shader-key + with: + log_file_path: ${{ github.workspace }}\RelWithDebInfo\RelWithDebInfo\shader_dump.log diff --git a/.github/workflows/windows_cuda.yml b/.github/workflows/windows_cuda.yml index 53c7031c3c095..dcc314084e4e2 100644 --- a/.github/workflows/windows_cuda.yml +++ b/.github/workflows/windows_cuda.yml @@ -157,6 +157,7 @@ jobs: runs-on: [ "self-hosted", "1ES.Pool=onnxruntime-github-Win2022-GPU-A10", + "1ES.ImageOverride=onnxruntime-Win-CPU-VS2022-Latest-NVMe-x64-test", "JobId=windows-cuda-test-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" ] steps: @@ -222,6 +223,13 @@ jobs: with: whl-directory: ${{ runner.temp }}\build\RelWithDebInfo\RelWithDebInfo\dist + # Verify the GPU is accessible before running the full test suite. + # If the NVIDIA driver is not available, tests will fail with + # "CUDA failure 100" and waste significant time. + - name: Verify GPU access + shell: pwsh + run: nvidia-smi + - name: Run Tests working-directory: ${{ runner.temp }} run: | diff --git a/.github/workflows/windows_cuda_plugin.yml b/.github/workflows/windows_cuda_plugin.yml index f9acdbd76a12d..6b6b7f7158df3 100644 --- a/.github/workflows/windows_cuda_plugin.yml +++ b/.github/workflows/windows_cuda_plugin.yml @@ -127,6 +127,7 @@ jobs: runs-on: [ "self-hosted", "1ES.Pool=onnxruntime-github-Win2022-GPU-A10", + "1ES.ImageOverride=onnxruntime-Win-CPU-VS2022-Latest-NVMe-x64-test", "JobId=windows-cuda-plugin-test-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" ] steps: @@ -187,6 +188,13 @@ jobs: with: whl-directory: ${{ runner.temp }}\build\Release\Release\dist + # Verify the GPU is accessible before running the full test suite. + # If the NVIDIA driver is not available, tests will fail with + # "CUDA failure 100" and waste significant time. + - name: Verify GPU access + shell: pwsh + run: nvidia-smi + - name: Run CUDA Plugin EP Python Tests working-directory: ${{ github.workspace }}\onnxruntime\test\python\transformers shell: pwsh diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 8c7df780735f1..bc64d394b6062 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -55,6 +55,7 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/qlutgemm.cpp ${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h ${MLAS_SRC_DIR}/flashattn.cpp + ${MLAS_SRC_DIR}/flashattn_qkv.cpp ${MLAS_SRC_DIR}/qkv_quant.cpp ${MLAS_SRC_DIR}/cast.cpp ${MLAS_SRC_DIR}/layernorm.cpp diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index cbd4a38ae18f0..de1d7559a1572 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -242,6 +242,7 @@ if (onnxruntime_USE_CUDA AND NOT WIN32) ) include(cutlass) target_include_directories(onnxruntime_pybind11_state PRIVATE ${cutlass_SOURCE_DIR}/include) + target_link_libraries(onnxruntime_pybind11_state PRIVATE CUDA::cudart) endif() if (onnxruntime_USE_CUDA AND WIN32) target_compile_definitions(onnxruntime_pybind11_state PRIVATE ORT_NO_CUDA_IN_PYBIND) diff --git a/docs/contrib_ops/cpu/gqa.md b/docs/contrib_ops/cpu/gqa.md index 0a144132b5c86..e5a211c9fd11a 100644 --- a/docs/contrib_ops/cpu/gqa.md +++ b/docs/contrib_ops/cpu/gqa.md @@ -17,6 +17,7 @@ Quantized KV-cache GEMM helpers are implemented in MLAS: - `onnxruntime/core/mlas/lib/qkv_quant_kernel_avx2.cpp` - `onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp` - `onnxruntime/core/mlas/lib/qkv_quant_kernel_neon.cpp` +- `onnxruntime/core/mlas/lib/flashattn_qkv.cpp` (flash attention tiled kernel) The operator schema itself is defined in: @@ -47,6 +48,13 @@ At a high level, the CPU kernel executes GroupQueryAttention in these stages: The non-quantized and quantized paths share the surrounding validation, masking, softmax, and output flow. Their main difference is how the K/V cache is stored and read during QK and SV GEMMs. +The quantized path has two execution strategies: + +- **Naive (full materialization)**: Computes the full `[S, T]` attention score matrix, applies masking and softmax, then computes the SV product. Simple but memory-intensive for long sequences. +- **Flash Attention (tiled, online softmax)**: Processes K/V in L2-cache-sized blocks using the online softmax algorithm (Milakov & Gimelshein, 2018). Avoids materializing the full attention matrix, reducing peak memory from O(S×T) to O(S×Bc) per head. Multi-threaded via the MLAS thread pool. + +The flash path is selected by default when conditions are met (see below). Set `ORT_GQA_DISABLE_FLASH_ATTENTION=1` to force the naive path. + ## Supported Cache Modes ### Non-quantized cache @@ -85,7 +93,11 @@ For INT4, two signed 4-bit values are stored in each byte. The packed head dimen During quantized execution, new key/value vectors are quantized on write into the present cache. Existing past-cache data and newly written present-cache data are then consumed by MLAS quantized GEMM helpers. -## QK GEMM +## Naive Path: QK GEMM + Softmax + SV GEMM + +The naive (full materialization) path executes attention as three separate stages: + +### QK GEMM The QK stage computes: @@ -102,7 +114,7 @@ For quantized K cache, the CPU path calls `MlasQKGemm` with: The default MLAS contract is exact with respect to the FP32 query operand: only the K cache is dequantized on the fly. The query row is not quantized by default. -## Softmax and Masking +### Softmax and Masking After QK GEMM, the CPU path applies the same attention-score processing used by the non-quantized path, including supported combinations of: @@ -115,7 +127,7 @@ After QK GEMM, the CPU path applies the same attention-score processing used by The quantized cache mode does not change these score-processing semantics. -## SV GEMM +### SV GEMM The SV stage computes: @@ -132,6 +144,66 @@ For quantized V cache, the CPU path calls `MlasSVGemm` with: As with QK GEMM, the default MLAS contract preserves the FP32 left-hand operand and dequantizes only the cached V values on the fly. +## Flash Attention Path + +The flash attention path (`MlasFlashAttentionQuantizedKV`) processes K/V in blocks with online softmax, fusing QK, masking, softmax, and SV into a single tiled loop. This avoids the O(S×T) memory allocation for the full attention matrix. + +### Algorithm + +For each (batch, head, q_block) tile: + +1. **QK GEMM** — `MlasQKGemm` on a block slice of quantized K cache (Bc rows at a time) +1b. **Attention bias** — Add the corresponding tile of the bias tensor (if present) to QK scores +2. **Causal + local window masking** — Set masked positions to −∞ before softmax +3. **Online softmax** — Track running max `m` and sum `l`, rescale accumulated output with `exp(m_old − m_new)` +4. **Fused SV accumulate** — `MlasSVGemm(..., Beta=1.0)` dequantizes V on the fly and accumulates `softmax(QK_block) × V_block` into the output in a single pass (no intermediate FP32 buffer) +5. **Finalize** — Normalize accumulated output by `1/l` after all KV blocks are processed + +### Activation Conditions + +The flash path is selected when ALL of the following hold: + +- `ORT_GQA_DISABLE_FLASH_ATTENTION` environment variable is not set (or set to `0`) +- `total_sequence_length > 1` +- No softcap +- No smooth softmax +- No head sink +- No output QK capture + +Attention bias is fully supported in the flash path (applied per-tile after QK GEMM). The bias tensor shape `[B|1, N|1, S, T]` supports broadcast along both batch and head dimensions. + +When any condition is not met, the kernel falls back to the naive full-materialization path. + +### Block Size Selection + +Block sizes are chosen based on L2 cache size: + +- `kv_block_size (Bc)`: Sized so that a full KV block's scores + dequantized V fit within L2. Typical values: 128–256. +- `q_block_size (Br)`: Sized for the query tile. Typical value: 64. + +### Threading + +The flash kernel parallelizes across `(batch, head, q_block)` tiles using the ORT intra-op thread pool. Each thread gets a private working buffer containing space for: + +- `l[Br]` and `m[Br]` — running softmax statistics +- `scores[Br × Bc]` — QK scores for current KV block +- `temp_output[Br × H]` — accumulated output + +The V dequantization temp buffer was eliminated by fusing dequantization into `MlasSVGemm` with `Beta=1.0` (accumulate mode). This reduces per-thread buffer size by `Bc × H × 4` bytes (e.g., 64 KB for Bc=128, H=128). + +### Flash Decoding (Decode Optimization) + +For decode steps (`sequence_length == 1`), the standard `(batch, head, q_block)` partitioning yields only `batch × num_heads` tasks, which can underutilize thread pools on machines with many cores (e.g., 96 threads with batch=1, num_heads=32 produces only 32 tasks). + +When `batch × num_heads < thread_count` and `kv_chunk_count > 1`, the kernel switches to a **flash decoding** strategy that also partitions along the KV sequence dimension: + +- **Phase 1** (parallel over `batch × num_heads × kv_chunk_count` tasks): Each thread computes partial attention for one KV chunk, producing per-chunk `(m, l, S_exp × V)` stored in a partials buffer. +- **Phase 2** (parallel over `batch × num_heads` tasks): Merge partials using log-sum-exp rescaling: `output = Σ_c(exp(m_c − m_global) × partial_c) / Σ_c(exp(m_c − m_global) × l_c)`. + +The partials buffer is allocated alongside the per-thread scratch in a single allocation: +- Per-thread scratch: `scores[Bc]` (one float per KV block element) +- Partials: `batch × num_heads × kv_chunks × (2 + H)` floats (m, l, and partial output per chunk) + ## MLAS Dispatch Paths MLAS selects the best available quantized KV-cache GEMM implementation through the platform dispatch table. @@ -168,7 +240,7 @@ CPU GroupQueryAttention coverage is split across operator-level and MLAS-level t - `onnxruntime/test/mlas/unittest/test_qkv_quant.cpp` - MLAS `MlasKVQuantize`, `MlasKVDequantize`, `MlasQKGemm`, and `MlasSVGemm` contract tests. -The MLAS benchmark for quantized KV-cache GEMM is: +The MLAS benchmark for quantized KV-cache GEMM and flash attention is: - `onnxruntime/test/mlas/bench/bench_qkv_quant.cpp` @@ -223,6 +295,23 @@ ORT_MLAS_QKGEMM_S8_APPROX_VNNI=1 ./onnxruntime_mlas_benchmark \ --benchmark_report_aggregates_only=true ``` +Run flash vs naive full-attention benchmark: + +```bash +cd build/cpu_test/Release +./onnxruntime_mlas_benchmark \ + --benchmark_filter='BM_GQA_(Naive|Flash)' \ + --benchmark_min_time=0.5s \ + --benchmark_repetitions=3 \ + --benchmark_report_aggregates_only=true +``` + +To force the naive path at the operator level (for A/B testing during inference): + +```bash +ORT_GQA_DISABLE_FLASH_ATTENTION=1 ./your_inference_app +``` + ### Updated benchmark results The following results were measured on an Intel Xeon Platinum 8480C, 96 CPUs, using the CPU Release benchmark binary. Shape: `M=1`, `N=512`, `K=128`, INT8 per-tensor QKGemm. @@ -236,6 +325,110 @@ The following results were measured on an Intel Xeon Platinum 8480C, 96 CPUs, us For comparison, the earlier PR description reported the approximate AVX512 VNNI path at 1,938 ns for this shape, with scalar at 30,179 ns and AVX2 at 4,219 ns. The default AVX512 path is now the exact FP32 fused-dequant implementation, so it is slower than approximate VNNI but preserves the `MlasQKGemm` FP32-query contract. +### Flash Attention vs Naive benchmark results + +Measured on Intel Xeon Platinum 8480C, 96 CPUs. INT8 quantized KV cache, threads=8. + +Two benchmark levels are reported: +- **Operator-level** (`benchmark_gqa_cpu_flash.py`): Measures the full GQA operator via `InferenceSession`, including KV cache concatenation, quantization of new K/V, and Python/C++ boundary overhead. +- **MLAS kernel-level** (`bench_qkv_quant.cpp`): Measures only the attention kernel (QK+softmax+SV), isolating the algorithmic gain from operator overhead. + +```bash +# Operator-level Python benchmark: +cd /tmp +PYTHONPATH=build/cpu/Release python \ + onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py --warmup 5 --repeats 20 + +# MLAS kernel-level C++ benchmark: +cd build/cpu/Release +./onnxruntime_mlas_benchmark \ + --benchmark_filter='BM_GQA_(Naive|Flash)' \ + --benchmark_min_time=0.5s \ + --benchmark_repetitions=3 \ + --benchmark_report_aggregates_only=true +``` + +#### Latency — Prefill (S = T, prompt phase) + +Shape: B=1, num_heads=16, kv_num_heads=8, head_size=128, INT8 per-tensor. + +| Seq Length | Naive (ms) | Flash (ms) | Speedup | Source | +|---:|---:|---:|---:|:---| +| 512 | 7.7 | 8.9 | 0.9x | operator | +| 1024 | 36.8 | 30.2 | 1.2x | operator | +| 2048 | 157.9 | 110.2 | 1.4x | operator | +| 4096 | 790.6 | 427.1 | 1.9x | operator | +| 512 | 9.9 | 8.1 | 1.2x | MLAS kernel | +| 1024 | 44.4 | 27.0 | 1.6x | MLAS kernel | +| 2048 | 190.9 | 116.9 | 1.6x | MLAS kernel | +| 4096 | 1257.8 | 461.6 | 2.7x | MLAS kernel | + +The operator-level naive path is faster than the MLAS-level naive at small S because the naive path's QK GEMM batches all heads in one call, amortizing thread dispatch. At larger S, the flash kernel's O(S×Bc) tiling wins decisively. + +MLAS kernel-level per-channel results: + +| Seq Length | Naive (ms) | Flash (ms) | Speedup | Source | +|---:|---:|---:|---:|:---| +| 512 | 10.7 | 10.8 | 1.0x | MLAS kernel | +| 1024 | 49.5 | 41.7 | 1.2x | MLAS kernel | +| 2048 | 212.1 | 164.1 | 1.3x | MLAS kernel | +| 4096 | 1223.9 | 607.8 | 2.0x | MLAS kernel | + +#### Latency — Decode (S = 1, token generation) + +Shape: B=1, num_heads=16, kv_num_heads=8, head_size=128, INT8 per-tensor. +Flash decoding is NOT active for this config (batch×heads=16 > threads=8). + +| Total Seqlen | Naive | Flash | Speedup | Source | +|---:|---:|---:|---:|:---| +| 513 | 0.133 ms | 0.149 ms | 0.9x | operator | +| 1025 | 0.258 ms | 0.224 ms | 1.2x | operator | +| 2049 | 0.453 ms | 0.394 ms | 1.2x | operator | +| 4097 | 0.681 ms | 0.679 ms | 1.0x | operator | +| 512 | 32 us | 22 us | 1.4x | MLAS kernel | +| 1024 | 71 us | 47 us | 1.5x | MLAS kernel | +| 2048 | 120 us | 87 us | 1.4x | MLAS kernel | +| 4096 | 210 us | 174 us | 1.2x | MLAS kernel | + +At the MLAS kernel level, the flash path is consistently 1.2–1.5x faster for decode due to fused single-pass KV access (better cache locality). At the operator level, the gain is partially masked by KV cache concatenation overhead (~100us), which dominates at short sequences but becomes less significant at longer ones. + +MLAS kernel-level per-channel decode results: + +| Total Seqlen | Naive (us) | Flash (us) | Speedup | Source | +|---:|---:|---:|---:|:---| +| 512 | 53 | 31 | 1.7x | MLAS kernel | +| 1024 | 86 | 52 | 1.7x | MLAS kernel | +| 2048 | 172 | 97 | 1.8x | MLAS kernel | +| 4096 | 299 | 191 | 1.6x | MLAS kernel | + +#### Latency — Flash Decoding (S = 1, KV partitioned across threads) + +Shape: B=1, num_heads=4, kv_num_heads=4 (MHA), head_size=128, threads=8. +Flash decoding IS active (batch×heads=4 < threads=8, KV partitioned across idle threads). + +| Total Seqlen | Naive (us) | Flash (us) | Speedup | Quant | +|---:|---:|---:|---:|:---| +| 512 | 31 | 25 | 1.2x | per-tensor | +| 1024 | 41 | 25 | 1.6x | per-tensor | +| 2048 | 67 | 34 | 2.0x | per-tensor | +| 4096 | 197 | 54 | 3.7x | per-tensor | +| 512 | 25 | 28 | 0.9x | per-channel | +| 1024 | 72 | 27 | 2.7x | per-channel | +| 2048 | 144 | 37 | 3.9x | per-channel | +| 4096 | 304 | 60 | 5.1x | per-channel | + +(Source: MLAS kernel-level benchmark) + +#### Peak Memory — Prefill (S = T, prompt phase) + +| Seq Length | Naive Peak (MB) | Flash Peak (MB) | Memory Reduction | +|---:|---:|---:|---:| +| 2048 | +294 | +44 | 6.7x | +| 4096 | +1107 | +82 | 13.5x | +| 4096 (N=32) | +2131 | +87 | 24.5x | + +**Summary**: The flash path's primary benefit for prefill is **memory reduction** — avoiding the full O(N×S×T) attention matrix. For S=4096 with 16 heads, the naive path allocates ~1 GB for attention scores while the flash path uses ~80 MB regardless of sequence length. The prefill latency speedup (1.2–2.7x at kernel level, 1.2–1.9x at operator level) comes from improved cache locality. For decode, the tiled kernel provides 1.2–1.8x kernel-level speedup from fused single-pass KV access; at operator level the gain is visible for T≥1024 but masked by KV concat overhead at shorter sequences. When flash decoding is active (batch×heads < threads), KV partitioning across idle threads yields an additional 2–5x speedup for long sequences. + ## Current CPU Limitations The current CPU GroupQueryAttention implementation has a few important limitations: @@ -246,7 +439,8 @@ The current CPU GroupQueryAttention implementation has a few important limitatio - INT4 cache storage uses packed `uint8` bytes and requires consumers to use the packed head dimension. - The default AVX512 quantized KV-cache GEMM path preserves FP32 query and attention-probability operands; the approximate VNNI QK path is opt-in only. - Hardware dispatch affects performance, but should not change default numeric semantics. -- The MLAS quantized GEMM helpers operate on one per-batch/per-head tile at a time; outer parallelism is managed by the GQA kernel. +- The flash attention path does not support softcap, smooth softmax, head sink, or QK output capture. These features fall back to the naive path. +- The MLAS quantized GEMM helpers operate on one per-batch/per-head tile at a time; outer parallelism is managed by the GQA kernel (or by the flash attention kernel internally). ## Future Work @@ -254,7 +448,6 @@ Further optimization opportunities include: - Improve the exact AVX512 INT8 per-tensor QK path without quantizing the FP32 query, for example by processing multiple K-cache rows per query row while keeping FP32 FMA semantics. - Add AVX512-specific exact micro-kernels for common decode shapes such as `M=1`, `N=512/2048`, and `K=64/128`. -- Add dispatch-specific benchmark coverage for prefill shapes (`M > 1`) and longer cache lengths. - Add dedicated accuracy/performance tests for the approximate VNNI opt-in path before enabling it in any production configuration. - Reduce temporary copies in quantized cache concatenation when past and present buffers cannot be shared directly. - Explore prepacking or layout transforms for long-lived quantized KV caches when the cache update pattern makes that worthwhile. @@ -279,7 +472,14 @@ CPU features that are limited or not implemented relative to the broader operato - quantizes new K/V values into the present cache - concatenates past and present cache chunks when needed - calls `MlasQKGemm` and `MlasSVGemm` +- `GroupQueryAttentionBase::ApplyAttentionQuantizedFlash(...)` + - concatenates new K/V into present cache (parallel over batch × kv_heads) + - invokes `MlasFlashAttentionQuantizedKV` with L2-cache-aware block sizes - `MlasQKGemm(...)` - computes FP32 query times quantized K cache transpose - `MlasSVGemm(...)` - - computes FP32 attention probabilities times quantized V cache + - computes `C = Beta*C + A*dequant(B)` where A is FP32 attention probabilities and B is quantized V cache + - `Beta=0` (overwrite) for naive path; `Beta=1.0` (accumulate) for flash path +- `MlasFlashAttentionQuantizedKV(...)` + - flash attention kernel with online softmax, tiled QK/SV over quantized KV cache + - parallelizes across (batch, head, q_block) tiles via thread pool diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index 7a30667befffd..12f61cddea18c 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -4,6 +4,7 @@ #pragma once #include +#include #include #include "contrib_ops/cpu/bert/attention_base.h" #include "contrib_ops/cpu/bert/attention_common.h" @@ -13,6 +14,9 @@ #include "core/common/safeint.h" #include "core/framework/op_kernel.h" #include "core/mlas/inc/mlas_qkv_quant.h" +#include "core/platform/env.h" +#include "core/platform/env_var_utils.h" +#include "core/platform/threadpool.h" #include "core/providers/cpu/mlas_backend_kernel_selector_config_utils.h" namespace onnxruntime { @@ -93,6 +97,8 @@ class GQAAttentionBase { kv_cache_bit_width_ = static_cast(info.GetAttrOrDefault("kv_cache_bit_width", 0)); kv_quant_enabled_ = (k_quant_type_ != KVQuantizationType::NONE); + disable_gqa_flash_ = ParseEnvironmentVariableWithDefault("ORT_GQA_DISABLE_FLASH_ATTENTION", false); + SetupMlasBackendKernelSelectorFromConfigOptions(mlas_backend_kernel_selector_config_, info.GetConfigOptions()); } @@ -111,6 +117,7 @@ class GQAAttentionBase { KVQuantizationType v_quant_type_; int kv_cache_bit_width_; bool kv_quant_enabled_; + bool disable_gqa_flash_; template Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH @@ -536,11 +543,363 @@ class GQAAttentionBase { MlasSVGemm(sequence_length, head_size, total_seqlen, attention_probs + probs_offset, seqlen_present_kv_cache, v_quantized, quant_type, head_v_scale, - output_current, hidden_size, nullptr); + output_current, hidden_size, 0.0f, nullptr); + } + }); + } + + return Status::OK(); + } + + // Flash Attention style tiled computation for quantized KV cache. + // Avoids materializing the full [B, N, S, T] attention probability matrix. + // Uses online softmax with KV block tiling for reduced memory usage. + Status ApplyAttentionQuantizedFlash( + const float* Q, // Q data [B, N, S, H] BNSH + const float* K, // K data [B, N_kv, L, H] or nullptr for packed_qkv + const float* V, // V data [B, N_kv, L, H] or nullptr for packed_qkv + const Tensor* attention_bias, // additive bias [B|1, N|1, S, T] or nullptr + const Tensor* past_key, // past K (uint8_t) + const Tensor* past_value, // past V (uint8_t) + Tensor* output, // output [B, S, N*H] float + Tensor* present_key, // present K (uint8_t) + Tensor* present_value, // present V (uint8_t) + const Tensor* seqlens_k, + const float* k_scale, + const float* v_scale, + MLAS_KV_QUANT_TYPE quant_type, + GroupQueryAttentionParameters& parameters, + AllocatorPtr allocator, + OpKernelContext* context) const { + const bool is_prompt = parameters.is_first_prompt; + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int head_size = parameters.head_size; + const int hidden_size = parameters.hidden_size; + const bool packed_qkv = parameters.is_packed_qkv; + + auto* tp = context->GetOperatorThreadPool(); + const size_t packed_row_bytes = MlasKVQuantPackedRowBytes(quant_type, head_size); + + int seqlen_past_kv_cache = 0; + if (past_key != nullptr && past_value != nullptr) { + seqlen_past_kv_cache = static_cast(past_key->Shape().GetDims()[2]); + } + int seqlen_present_kv_cache = present_key != nullptr + ? static_cast(present_key->Shape().GetDims()[2]) + : parameters.total_sequence_length; + + if (kv_sequence_length == 0) { + ORT_ENFORCE(parameters.total_sequence_length <= seqlen_past_kv_cache, + "total_seqlen (", parameters.total_sequence_length, ") exceeds past buffer size (", + seqlen_past_kv_cache, ") in shared KV mode"); + } + + ORT_RETURN_IF(present_key == nullptr || present_value == nullptr, + "present_key and present_value must be provided for quantized KV cache"); + + // Access cache data as raw bytes + const uint8_t* past_key_data = nullptr; + uint8_t* present_key_data = nullptr; + const uint8_t* past_value_data = nullptr; + uint8_t* present_value_data = nullptr; + if (kv_cache_bit_width_ == 4) { + past_key_data = past_key != nullptr ? past_key->Data() : nullptr; + present_key_data = present_key->MutableData(); + past_value_data = past_value != nullptr ? past_value->Data() : nullptr; + present_value_data = present_value->MutableData(); + } else { + past_key_data = past_key != nullptr ? reinterpret_cast(past_key->Data()) : nullptr; + present_key_data = reinterpret_cast(present_key->MutableData()); + past_value_data = past_value != nullptr ? reinterpret_cast(past_value->Data()) : nullptr; + present_value_data = reinterpret_cast(present_value->MutableData()); + } + + bool past_present_share_buffer = (past_key_data == present_key_data) && + (past_value_data == present_value_data); + + const bool per_channel = (quant_type == MLAS_KV_QUANT_TYPE::S8_PerChannel || + quant_type == MLAS_KV_QUANT_TYPE::S4_PerChannel); + + const int32_t* seqlens_k_data = seqlens_k->Data(); + + // Attention bias setup + const float* attention_bias_data = nullptr; + int attention_bias_seqlen_stride = 0; + bool attention_bias_broadcast_batch = true; + bool attention_bias_broadcast_head = true; + if (attention_bias != nullptr) { + attention_bias_data = attention_bias->Data(); + auto bias_shape = attention_bias->Shape().GetDims(); + attention_bias_seqlen_stride = static_cast(bias_shape[3]); + attention_bias_broadcast_batch = (bias_shape[0] == 1); + attention_bias_broadcast_head = (bias_shape[1] == 1); + } + + // K/V base pointers (FP32, new tokens) + const float* k_base = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K; + const float* v_base = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; + + const ptrdiff_t packed_batch_stride = + packed_qkv ? SafeInt(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size + : SafeInt(0); + const size_t kv_input_chunk_length = kv_sequence_length * head_size; + const size_t past_buff_chunk_bytes = SafeInt(seqlen_past_kv_cache) * packed_row_bytes; + const size_t present_buff_chunk_bytes = SafeInt(seqlen_present_kv_cache) * packed_row_bytes; + + // ---- Phase 1: Concat new K/V into present cache ---- + // We must do this first so the flash attention kernel can read the full present cache. + if (present_key_data && !past_present_share_buffer) { + memset(present_key_data, 0, + SafeInt(batch_size) * kv_num_heads_ * present_buff_chunk_bytes); + memset(present_value_data, 0, + SafeInt(batch_size) * kv_num_heads_ * present_buff_chunk_bytes); + } + + // Concat K and V caches (parallelize over batch * kv_num_heads) + { + const size_t concat_loop_len = batch_size * kv_num_heads_; + TensorOpCost concat_cost; + concat_cost.compute_cycles = static_cast(kv_sequence_length * head_size); + concat_cost.bytes_loaded = static_cast(past_buff_chunk_bytes + kv_sequence_length * head_size * sizeof(float)); + concat_cost.bytes_stored = static_cast(present_buff_chunk_bytes); + + ThreadPool::TryParallelFor(tp, concat_loop_len, concat_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t kv_idx = begin; kv_idx != end; ++kv_idx) { + const size_t batch_index = kv_idx / kv_num_heads_; + const size_t kv_head_index = kv_idx % kv_num_heads_; + const size_t total_seqlen = SafeInt(seqlens_k_data[batch_index]) + 1; + + size_t past_seqlen; + if (past_key == nullptr) { + past_seqlen = 0; + } else if (kv_sequence_length == 0) { + past_seqlen = total_seqlen; + } else if (is_prompt) { + past_seqlen = 0; + } else { + past_seqlen = total_seqlen - sequence_length; + } + const size_t past_chunk_bytes = past_seqlen * packed_row_bytes; + + const float* head_k_scale = per_channel + ? k_scale + kv_head_index * head_size + : k_scale; + const float* head_v_scale = per_channel + ? v_scale + kv_head_index * head_size + : v_scale; + + // Concat K + const float* k_new; + if (packed_qkv) { + k_new = k_base + packed_batch_stride * batch_index + + kv_input_chunk_length * kv_head_index; + } else { + k_new = k_base + kv_input_chunk_length * kv_idx; + } + ConcatQuantStateChunkGQA( + past_key_data, k_new, present_key_data, + present_buff_chunk_bytes, past_buff_chunk_bytes, + past_chunk_bytes, kv_sequence_length, head_size, head_size, + quant_type, head_k_scale, past_present_share_buffer, kv_idx); + + // Concat V + const float* v_new; + if (packed_qkv) { + v_new = v_base + packed_batch_stride * batch_index + + kv_input_chunk_length * kv_head_index; + } else { + v_new = v_base + kv_input_chunk_length * kv_idx; + } + ConcatQuantStateChunkGQA( + past_value_data, v_new, present_value_data, + present_buff_chunk_bytes, past_buff_chunk_bytes, + past_chunk_bytes, kv_sequence_length, head_size, head_size, + quant_type, head_v_scale, past_present_share_buffer, kv_idx); } }); } + // ---- Phase 2: Flash Attention with quantized KV cache ---- + // Compute L2-aware block sizes (same formula as MHA flash attention) + const auto& env = Env::Default(); + int l2_cache_size = env.GetL2CacheSize(); + + // For quantized KV: effective bytes per KV element for cache considerations + // We dequantize V blocks to FP32, so working set per KV row = head_size * sizeof(float) + // K is accessed via MlasQKGemm which internally dequantizes; for block sizing purposes + // treat it as FP32 working set. + // + // Working set in L2 per tile: + // Q slice: [Br, head_size] floats + // Scores: [Br, Bc] floats + // V dequant: [Bc, head_size] floats + // Temp output: [Br, head_size] floats + // Total ~ (2*Br + Bc) * head_size + Br * Bc + // Approximation: use same formula as FP32 flash attention + int kv_block_size = l2_cache_size / (static_cast(sizeof(float)) * 4 * (head_size + head_size)); + kv_block_size = std::max(kv_block_size, 1); + int q_block_size = std::min(kv_block_size, 2 * head_size); + + // The flash kernel uses a single (past_seqlen, total_seqlen) pair for all batch items. + // When batch items have different seqlens_k (ragged), we must fall back to per-batch + // invocation so each batch item gets its own correct causal offset. + int max_total_seqlen = 0; + int min_total_seqlen = std::numeric_limits::max(); + int common_past_seqlen = 0; + for (int b = 0; b < batch_size; ++b) { + int total_sl = seqlens_k_data[b] + 1; + max_total_seqlen = std::max(max_total_seqlen, total_sl); + min_total_seqlen = std::min(min_total_seqlen, total_sl); + } + const bool ragged_seqlens = (max_total_seqlen != min_total_seqlen); + + if (ragged_seqlens) { + // Ragged seqlens: each batch item has its own total_seqlen (and therefore + // past_seqlen). Must use per-batch invocation regardless of past_key/prompt state. + common_past_seqlen = -1; // sentinel: per-batch + } else if (past_key == nullptr || is_prompt) { + common_past_seqlen = 0; + } else if (kv_sequence_length == 0) { + // Shared buffer mode: each batch item has its own past_seqlen. + common_past_seqlen = -1; // sentinel: per-batch + } else { + common_past_seqlen = max_total_seqlen - sequence_length; + } + + // Cap block sizes + kv_block_size = std::min(kv_block_size, max_total_seqlen); + q_block_size = std::min(q_block_size, sequence_length); + + // Allocate per-thread buffers for flash attention + int thread_count = concurrency::ThreadPool::DegreeOfParallelism(tp); + thread_count = std::max(thread_count, 1); + + // Flash decoding: for decode (sequence_length==1), partition KV across threads + // to improve parallelism when batch*heads < thread_count. + const int kv_chunk_count = (max_total_seqlen + kv_block_size - 1) / kv_block_size; + const bool use_flash_decoding = (sequence_length == 1 && + batch_size * num_heads_ < thread_count && + kv_chunk_count > 1); + + size_t buffer_size_per_thread; + size_t partials_buffer_bytes = 0; + if (use_flash_decoding) { + // Flash decoding: per-thread scratch only needs scores[kv_block_size] + buffer_size_per_thread = static_cast(kv_block_size) * sizeof(float); + // Partials: [batch * num_heads * kv_chunk_count * (2 + head_size)] floats + partials_buffer_bytes = static_cast(batch_size) * num_heads_ * + kv_chunk_count * (2 + head_size) * sizeof(float); + } else { + buffer_size_per_thread = + (static_cast(q_block_size) * 2 + // l + m + static_cast(q_block_size) * static_cast(kv_block_size) + // scores + static_cast(q_block_size) * static_cast(head_size)) * // temp_output + sizeof(float); + } + size_t total_buffer_bytes = buffer_size_per_thread * thread_count + partials_buffer_bytes; + auto flash_buffer_alloc = allocator->Alloc(total_buffer_bytes); + BufferUniquePtr flash_buffer(flash_buffer_alloc, BufferDeleter(allocator)); + + // Partials buffer is placed after per-thread scratch + float* partials_ptr = use_flash_decoding + ? reinterpret_cast(reinterpret_cast(flash_buffer_alloc) + + buffer_size_per_thread * thread_count) + : nullptr; + + // If all batch items share the same past_seqlen, use the unified flash kernel. + // Otherwise, fall back to per-batch invocation. + if (common_past_seqlen >= 0) { + MlasFlashAttentionQuantizedKVArgs args; + args.batch_size = batch_size; + args.num_heads = num_heads_; + args.kv_num_heads = kv_num_heads_; + args.sequence_length = sequence_length; + args.total_seqlen = max_total_seqlen; + args.head_size = head_size; + args.past_seqlen = common_past_seqlen; + args.local_window_size = local_window_size_; + args.seqlen_present_kv = seqlen_present_kv_cache; + args.q_block_size = q_block_size; + args.kv_block_size = kv_block_size; + args.scale = scale_ == 0.0f ? 1.0f / sqrt(static_cast(head_size)) : scale_; + args.quant_type = quant_type; + args.per_channel_k = per_channel; + args.per_channel_v = per_channel; + args.thread_count = thread_count; + args.buffer = reinterpret_cast(flash_buffer_alloc); + args.buffer_size_per_thread = buffer_size_per_thread; + args.query = Q; + args.k_cache = present_key_data; + args.v_cache = present_value_data; + args.k_scale = k_scale; + args.v_scale = v_scale; + args.output = output->MutableData(); + args.attention_bias = attention_bias_data; + args.attention_bias_seqlen_stride = attention_bias_seqlen_stride; + args.attention_bias_broadcast_batch = attention_bias_broadcast_batch; + args.attention_bias_broadcast_head = attention_bias_broadcast_head; + args.flash_decoding_partials = partials_ptr; + args.kv_chunk_count = kv_chunk_count; + + MlasFlashAttentionQuantizedKV(&args, tp); + } else { + // Per-batch handling for variable past_seqlen (shared KV buffer mode or ragged seqlens) + for (int b = 0; b < batch_size; ++b) { + int total_sl = seqlens_k_data[b] + 1; + // For prompt/no-past cases, past_seqlen is 0; otherwise derive from total_sl. + int batch_past_seqlen = (past_key == nullptr || is_prompt) + ? 0 + : std::max(0, total_sl - sequence_length); + + MlasFlashAttentionQuantizedKVArgs args; + args.batch_size = 1; + args.num_heads = num_heads_; + args.kv_num_heads = kv_num_heads_; + args.sequence_length = sequence_length; + args.total_seqlen = total_sl; + args.head_size = head_size; + args.past_seqlen = batch_past_seqlen; + args.local_window_size = local_window_size_; + args.seqlen_present_kv = seqlen_present_kv_cache; + args.q_block_size = q_block_size; + args.kv_block_size = std::min(kv_block_size, total_sl); + args.scale = scale_ == 0.0f ? 1.0f / sqrt(static_cast(head_size)) : scale_; + args.quant_type = quant_type; + args.per_channel_k = per_channel; + args.per_channel_v = per_channel; + args.thread_count = thread_count; + args.buffer = reinterpret_cast(flash_buffer_alloc); + args.buffer_size_per_thread = buffer_size_per_thread; + + // Offset Q and output for this batch + args.query = Q + static_cast(b) * num_heads_ * sequence_length * head_size; + args.k_cache = present_key_data + + static_cast(b) * kv_num_heads_ * seqlen_present_kv_cache * packed_row_bytes; + args.v_cache = present_value_data + + static_cast(b) * kv_num_heads_ * seqlen_present_kv_cache * packed_row_bytes; + args.k_scale = k_scale; + args.v_scale = v_scale; + args.output = output->MutableData() + + static_cast(b) * sequence_length * hidden_size; + + // Slice attention bias for this batch (the kernel sees batch_size=1, so batch_idx=0 inside) + const float* batch_bias = attention_bias_data; + if (attention_bias_data != nullptr && !attention_bias_broadcast_batch) { + batch_bias += static_cast(b) * num_heads_ * sequence_length * attention_bias_seqlen_stride; + } + args.attention_bias = batch_bias; + args.attention_bias_seqlen_stride = attention_bias_seqlen_stride; + args.attention_bias_broadcast_batch = true; // batch offset handled above + args.attention_bias_broadcast_head = attention_bias_broadcast_head; + args.flash_decoding_partials = nullptr; // per-batch doesn't use flash decoding + args.kv_chunk_count = 0; + + MlasFlashAttentionQuantizedKV(&args, tp); + } + } + return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index 4df5f6a349599..1b9e4c3a6a5cd 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -294,13 +294,34 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { if constexpr (std::is_same_v) { const float* k_data_q = packed_qkv ? nullptr : k_rotary; const float* v_data_q = packed_qkv ? nullptr : V.Get().Data(); + auto mlas_quant_type = ToMlasKVQuantType(k_quant_type_, kv_cache_bit_width_); + + // Use flash attention path when: + // 1. Total sequence length is long enough to benefit from tiling + // 2. No features that flash path doesn't support (softcap, smooth softmax, output_qk) + const bool use_flash = !disable_gqa_flash_ && + parameters.total_sequence_length > 1 && + softcap_ == 0.0f && + !use_smooth_softmax_ && + head_sink_data == nullptr && + output_qk == nullptr; + + if (use_flash) { + return ApplyAttentionQuantizedFlash( + q_rotary, k_data_q, v_data_q, + attention_bias, + past_key, past_value, + output, present_k, present_v, seqlens_k, + k_scale->Data(), v_scale->Data(), + mlas_quant_type, parameters, allocator, context); + } + return ApplyAttentionQuantized( q_rotary, k_data_q, v_data_q, head_sink_data, attention_bias, past_key, past_value, output, present_k, present_v, output_qk, seqlens_k, k_scale->Data(), v_scale->Data(), - ToMlasKVQuantType(k_quant_type_, kv_cache_bit_width_), - parameters, allocator, context); + mlas_quant_type, parameters, allocator, context); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Quantized KV cache requires float Q dtype"); diff --git a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.cu b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.cu index f9d949012e64c..ab0e2d9e01901 100644 --- a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.cu +++ b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.cu @@ -49,6 +49,7 @@ #include "core/common/common.h" #include "core/common/safeint.h" +#include "core/providers/cuda/cu_inc/cub.cuh" #include "contrib_ops/cuda/llm/common/logger.h" #include "contrib_ops/cuda/llm/common/cuda_runtime_utils.h" #include "contrib_ops/cuda/llm/common/data_type.h" @@ -63,7 +64,6 @@ #include "contrib_ops/cuda/llm/moe_gemm/moe_gemm_activation_kernels.cuh" #include "contrib_ops/cuda/llm/moe_gemm/moe_gemm_utils.cuh" -#include #include #include diff --git a/onnxruntime/contrib_ops/cuda/moe/qmoe_kernels.cu b/onnxruntime/contrib_ops/cuda/moe/qmoe_kernels.cu index 28fd4fb1516fb..61cdf3ab23fca 100644 --- a/onnxruntime/contrib_ops/cuda/moe/qmoe_kernels.cu +++ b/onnxruntime/contrib_ops/cuda/moe/qmoe_kernels.cu @@ -5,9 +5,9 @@ #include "contrib_ops/cuda/moe/qmoe_kernels.h" #include "core/common/narrow.h" #include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cu_inc/cub.cuh" #include "contrib_ops/cuda/llm/moe_gemm/moe_kernels.h" #include -#include #include #include diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 275fa837a7257..e9775fe23fe08 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -1565,6 +1565,23 @@ static Status GetFileContent(const Env& env, const std::filesystem::path& file_p } #endif +// Backstop validation for callers that load external data outside Graph::Resolve (e.g. training +// checkpoints, custom-op initializers). Passes through ORT's in-memory address markers — those are +// validated at higher layers (Graph::ConvertInitializersIntoOrtValues for dense; markers on sparse +// sub-tensors are rejected outright in SparseTensorProtoToDenseTensorProto). For declared file paths, +// defers to ValidateExternalDataPath, which rejects absolute paths and paths that escape the model +// directory. Callers must have already verified the tensor has external data. +static Status ValidateExternalFilePathForTensor(const ONNX_NAMESPACE::TensorProto& tensor_proto, + const std::filesystem::path& model_path) { + if (HasExternalDataInMemory(tensor_proto)) { + return Status::OK(); + } + + std::unique_ptr external_data_info; + ORT_RETURN_IF_ERROR(ExternalDataInfo::Create(tensor_proto.external_data(), external_data_info)); + return utils::ValidateExternalDataPath(model_path, external_data_info->GetRelPath()); +} + Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& model_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, @@ -1572,6 +1589,11 @@ Status GetExtDataFromTensorProto(const Env& env, ORT_ENFORCE(HasExternalData(tensor_proto), "TensorProto for: ", tensor_proto.name(), "Expected to have external data"); + // Defense-in-depth: reject absolute or directory-escaping external data paths even when this + // function is reached outside Graph::Resolve (e.g. training checkpoint load, custom-op init). + // In-memory address markers are passed through; their validity is enforced upstream. + ORT_RETURN_IF_ERROR(ValidateExternalFilePathForTensor(tensor_proto, model_path)); + std::basic_string tensor_proto_dir; if (!model_path.empty()) { ORT_RETURN_IF_ERROR(GetDirNameFromFilePath(model_path, tensor_proto_dir)); @@ -1735,6 +1757,9 @@ Status LoadExtDataToTensorFromTensorProto(const Env& env, const std::filesystem: const IExternalDataLoader& ext_data_loader, Tensor& tensor) { ORT_ENFORCE(HasExternalData(tensor_proto)); + // Defense-in-depth path validation for callers reaching this function outside Graph::Resolve. + // In-memory markers are passed through; rejected explicitly below as unsupported for this path. + ORT_RETURN_IF_ERROR(ValidateExternalFilePathForTensor(tensor_proto, model_path)); std::basic_string tensor_proto_dir; if (!model_path.empty()) { ORT_RETURN_IF_ERROR(GetDirNameFromFilePath(model_path, tensor_proto_dir)); @@ -2098,30 +2123,29 @@ void MakeCpuTensorCopy(const Tensor& src_tensor, Tensor& dst_tensor) { #if !defined(DISABLE_SPARSE_TENSORS) -// Validates that a TensorProto's external data path does not escape the model directory. -// Also validates that the file exists when filesystem access is available (skipped on WASM without a virtual FS). -// Returns Status::OK() (no-op) for tensors that do not use file-based external data. -static Status ValidateExternalDataPathForTensor(const ONNX_NAMESPACE::TensorProto& tensor_proto, - const std::filesystem::path& model_path) { - // Gates on data_location == EXTERNAL directly instead of using HasExternalData()/HasExternalDataInFile(), - // which also require data_type != UNDEFINED. That check is appropriate for data processing (can't unpack - // without a type), but too narrow for security validation: we must validate any declared external path - // regardless of data_type. - if (tensor_proto.data_location() != ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { +// Validates the external data declaration on a sub-tensor of a SparseTensorProto (values or +// indices). Validates that any file path stays within the model directory. +// +// Gates on data_location == EXTERNAL (rather than HasExternalData()) so that path validation +// runs even when data_type is UNDEFINED. A malicious model could set data_location=EXTERNAL with +// data_type=UNDEFINED and an evil file path; downstream loading would also reject it, but we +// validate here for defense-in-depth. +// +// In-memory address markers must never appear on sparse sub-tensors. The trusted .ort loader +// materializes sparse sub-tensors as inline raw_data (see LoadSparseInitializerOrtFormat); the +// untrusted .onnx protobuf path rejects markers at the Graph constructor; and +// SparseTensorProtoToDenseTensorProto re-asserts the invariant before this function is reached. +// The HasExternalDataInMemory early-return below is a paranoid backstop. +static Status ValidateSparseSubTensorExternalDataPath(const ONNX_NAMESPACE::TensorProto& tensor_proto, + const std::filesystem::path& model_path) { + if (tensor_proto.data_location() != ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL || + HasExternalDataInMemory(tensor_proto)) { return Status::OK(); } std::unique_ptr external_data_info; ORT_RETURN_IF_ERROR(ExternalDataInfo::Create(tensor_proto.external_data(), external_data_info)); - const auto& rel_path = external_data_info->GetRelPath(); - - // In-memory external data uses special marker locations — skip file path validation for those. - if (rel_path == kTensorProtoLittleEndianMemoryAddressTag || - rel_path == kTensorProtoNativeEndianMemoryAddressTag) { - return Status::OK(); - } - - return utils::ValidateExternalDataPath(model_path, rel_path); + return utils::ValidateExternalDataPath(model_path, external_data_info->GetRelPath()); } static Status CopySparseData(const std::string& name, @@ -2303,6 +2327,23 @@ common::Status SparseTensorProtoToDenseTensorProto(const ONNX_NAMESPACE::SparseT const auto& sparse_values = sparse.values(); const auto& name = sparse_values.name(); + // In-memory address markers (pointing into mmap'd / heap buffers) are forbidden on sparse + // sub-tensors. The trusted .ort loader is required to materialize sparse sub-tensors as inline + // raw_data (see LoadSparseInitializerOrtFormat) so they never carry markers. Untrusted .onnx + // protobuf input is rejected at the Graph constructor before reaching this function; this is + // the function-level backstop. A marker here would otherwise trigger an arbitrary memory read + // in UnpackInitializerData. + if (HasExternalDataInMemory(sparse_values)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, + "Sparse tensor: ", name, + " values use an in-memory address marker which is not permitted on sparse sub-tensors."); + } + if (HasExternalDataInMemory(sparse.indices())) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, + "Sparse tensor: ", name, + " indices use an in-memory address marker which is not permitted on sparse sub-tensors."); + } + const auto values_rank = sparse_values.dims_size(); if (values_rank != 1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, @@ -2371,8 +2412,8 @@ common::Status SparseTensorProtoToDenseTensorProto(const ONNX_NAMESPACE::SparseT // Validate external data paths before any early returns or allocations. // This ensures malicious paths are rejected even for zero-element tensors, // and prevents large allocations before an invalid path is caught. - ORT_RETURN_IF_ERROR(ValidateExternalDataPathForTensor(sparse_values, model_path)); - ORT_RETURN_IF_ERROR(ValidateExternalDataPathForTensor(indices, model_path)); + ORT_RETURN_IF_ERROR(ValidateSparseSubTensorExternalDataPath(sparse_values, model_path)); + ORT_RETURN_IF_ERROR(ValidateSparseSubTensorExternalDataPath(indices, model_path)); if (dense_elements == 0) { // if there are no elements in the dense tensor, we can return early with an empty tensor proto diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 30df0e23af6ae..24d49f5f3f247 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -1265,6 +1265,21 @@ Graph::Graph(const Model& owning_model, continue; } +#if !defined(DISABLE_SPARSE_TENSORS) + // Reject ORT in-memory address markers on a sparse-tensor Constant attribute before the + // sparse-to-dense conversion runs — those markers are an in-process ORT sentinel and must + // never appear in a deserialized protobuf. See note on the dense initializer loop below. + if (node.attribute_size() > 0 && + node.attribute(0).type() == AttributeProto_AttributeType_SPARSE_TENSOR) { + const auto& s = node.attribute(0).sparse_tensor(); + ORT_ENFORCE(!utils::HasExternalDataInMemory(s.values()) && + !utils::HasExternalDataInMemory(s.indices()), + "Constant node '", node.name(), + "' sparse-tensor attribute references an ORT in-memory address marker, " + "which is not allowed in a model protobuf."); + } +#endif + const gsl::not_null tensor{graph_proto_->add_initializer()}; ORT_THROW_IF_ERROR(utils::ConstantNodeProtoToTensorProto(node, model_path, *tensor)); @@ -1304,6 +1319,16 @@ Graph::Graph(const Model& owning_model, if (graph_proto_->sparse_initializer_size() > 0) { for (const auto& sparse_tensor : graph_proto_->sparse_initializer()) { ORT_ENFORCE(utils::HasName(sparse_tensor), "Sparse initializer must have a name. This model is invalid"); + // Reject ORT's in-memory address markers on sparse sub-tensors arriving via the protobuf + // path. Such markers are an internal ORT optimization set by trusted loaders (e.g. ORT-format + // flatbuffer load) and must never appear in a SparseTensorProto deserialized from an .onnx + // protobuf; if they do, the model is crafted and would cause ORT to dereference an + // attacker-supplied pointer during sparse-to-dense conversion. + for (const auto* sub : {&sparse_tensor.values(), &sparse_tensor.indices()}) { + ORT_ENFORCE(!utils::HasExternalDataInMemory(*sub), + "Sparse initializer '", sparse_tensor.values().name(), + "' references an ORT in-memory address marker, which is not allowed in a model protobuf."); + } const gsl::not_null tensor{graph_proto_->add_initializer()}; auto status = utils::SparseTensorProtoToDenseTensorProto(sparse_tensor, model_path, *tensor); ORT_ENFORCE(status.IsOK(), status.ToString()); @@ -1345,6 +1370,14 @@ Graph::Graph(const Model& owning_model, // Copy initial tensors to a map. for (auto& tensor : graph_proto_->initializer()) { + // ORT in-memory address markers are an in-process sentinel: they can only be planted by ORT + // itself (e.g. when constructing a TensorProto that aliases an mmap'd .ort buffer or an OrtValue). + // They must never appear in a TensorProto deserialized from an .onnx protobuf — if they do, the + // model is crafted and would cause ORT to dereference an attacker-supplied pointer when + // resolving the initializer. + ORT_ENFORCE(!utils::HasExternalDataInMemory(tensor), + "Initializer '", tensor.name(), + "' references an ORT in-memory address marker, which is not allowed in a model protobuf."); auto p = name_to_initial_tensor_.emplace(tensor.name(), &tensor); if (!p.second) { LOGS(logger_, WARNING) << "Duplicate initializer (dense, sparse or ConstantNode): '" << tensor.name() diff --git a/onnxruntime/core/graph/graph_flatbuffers_utils.cc b/onnxruntime/core/graph/graph_flatbuffers_utils.cc index c51f24229f145..0fe021cec88d3 100644 --- a/onnxruntime/core/graph/graph_flatbuffers_utils.cc +++ b/onnxruntime/core/graph/graph_flatbuffers_utils.cc @@ -418,17 +418,27 @@ Status LoadSparseInitializerOrtFormat(const fbs::SparseTensor& fbs_sparse_tensor SparseTensorProto& initializer, const OrtFormatLoadOptions& load_options) { SparseTensorProto loaded_initializer; + + // Sparse sub-tensors must never carry the in-memory address marker. The marker would point into + // the mmap'd flatbuffer buffer; allowing it here would force every downstream consumer of the + // sparse->dense conversion to validate the marker, and would conflate the trust boundary + // (sparse markers from untrusted .onnx input are an arbitrary-memory-read vector). Force the + // inner loader to materialize a normal inline raw_data copy regardless of size; the cost is + // small because sparse->dense conversion immediately copies the bytes again. + OrtFormatLoadOptions sub_tensor_options = load_options; + sub_tensor_options.can_use_flatbuffer_for_initializers = false; + auto fbs_values_tensor = fbs_sparse_tensor.values(); ORT_RETURN_IF(nullptr == fbs_values_tensor, "Missing values for sparse initializer. Invalid ORT format model."); auto* values_tensor = loaded_initializer.mutable_values(); - ORT_RETURN_IF_ERROR(LoadInitializerOrtFormat(*fbs_values_tensor, *values_tensor, load_options)); + ORT_RETURN_IF_ERROR(LoadInitializerOrtFormat(*fbs_values_tensor, *values_tensor, sub_tensor_options)); ORT_RETURN_IF(values_tensor->name().empty(), "Missing name for SparseTensor initializer. Invalid ORT format model."); auto fbs_indicies_tensor = fbs_sparse_tensor.indices(); ORT_RETURN_IF(nullptr == fbs_indicies_tensor, "Missing indicies for sparse initializer: ", "'", values_tensor->name(), "'", "Invalid ORT format model."); auto* indicies_tensor = loaded_initializer.mutable_indices(); - ORT_RETURN_IF_ERROR(LoadInitializerOrtFormat(*fbs_indicies_tensor, *indicies_tensor, load_options)); + ORT_RETURN_IF_ERROR(LoadInitializerOrtFormat(*fbs_indicies_tensor, *indicies_tensor, sub_tensor_options)); auto fbs_dims = fbs_sparse_tensor.dims(); ORT_RETURN_IF(nullptr == fbs_dims, "Missing dims for sparse initializer: ", "'", values_tensor->name(), "'", diff --git a/onnxruntime/core/mlas/inc/mlas_qkv_quant.h b/onnxruntime/core/mlas/inc/mlas_qkv_quant.h index ed1a9bfdbfba6..f6a5a48e6ccc7 100644 --- a/onnxruntime/core/mlas/inc/mlas_qkv_quant.h +++ b/onnxruntime/core/mlas/inc/mlas_qkv_quant.h @@ -199,7 +199,7 @@ MlasQKGemm( /** * @brief Softmax-times-V GEMM with a quantized V cache. * - * C[M, N] = A[M, K] * B[K, N] + * C[M, N] = Beta * C[M, N] + A[M, K] * B[K, N] * * where: * - A is FP32 row-major, shape [M, K] (attention probabilities), stride lda. @@ -207,8 +207,8 @@ MlasQKGemm( * with K = total_sequence_length, N = head_size), packed row-major over * rows. Each row occupies * MlasKVQuantPackedRowBytes(QuantType, N) bytes. - * - C is FP32 row-major, shape [M, N], stride ldc (>= N). The kernel - * overwrites C (no accumulate). + * - C is FP32 row-major, shape [M, N], stride ldc (>= N). + * When Beta == 0, C is overwritten. When Beta != 0, C is accumulated. * - PER_CHANNEL scales are length N and apply along the N (head_size) axis. * * @param M Query token count. @@ -221,6 +221,7 @@ MlasQKGemm( * @param Scales Scale buffer (single scalar or length-N vector). * @param C Output buffer (FP32). * @param ldc Leading dimension of C in elements. + * @param Beta Scalar multiplier for existing C values. 0 = overwrite. * @param ThreadPool Optional thread pool. */ void @@ -236,5 +237,73 @@ MlasSVGemm( const float* Scales, float* C, size_t ldc, + float Beta, + MLAS_THREADPOOL* ThreadPool + ); + +/** + * @brief Arguments for the Flash Attention kernel with quantized KV cache. + * + * This kernel implements the online-softmax tiled Flash Attention algorithm + * operating directly on INT8/INT4 quantized K and V cache buffers. + * It avoids materializing the full [S, T] attention probability matrix. + */ +struct MlasFlashAttentionQuantizedKVArgs { + int batch_size; + int num_heads; // Q heads + int kv_num_heads; // KV heads (for GQA sharing) + int sequence_length; // Q sequence length (new tokens) + int total_seqlen; // Total KV sequence length (past + new) + int head_size; + int past_seqlen; // For computing causal positions + int local_window_size; // -1 = disabled + int seqlen_present_kv; // Buffer dimension for present KV (may be > total_seqlen) + int q_block_size; // Br (query block size) + int kv_block_size; // Bc (KV block size) + float scale; // 1/sqrt(head_size) or user-specified + + MLAS_KV_QUANT_TYPE quant_type; + bool per_channel_k; // Whether K uses per-channel scales + bool per_channel_v; // Whether V uses per-channel scales + + int thread_count; + float* buffer; + size_t buffer_size_per_thread; + + const float* query; // [B, N, S, H] FP32 + const uint8_t* k_cache; // [B, kv_N, seqlen_present, packed_row_bytes] quantized + const uint8_t* v_cache; // [B, kv_N, seqlen_present, packed_row_bytes] quantized + const float* k_scale; // Scalar or per-channel scales for K + const float* v_scale; // Scalar or per-channel scales for V + float* output; // [B, S, N, H] FP32 + + // Attention bias (additive, applied after QK GEMM before masking/softmax). + // Shape: [B|1, N|1, S, T] where dimensions of size 1 are broadcast. + const float* attention_bias; // nullptr if no bias + int attention_bias_seqlen_stride; // stride along the T (total_seqlen) dimension = shape[3] + bool attention_bias_broadcast_batch; // true if shape[0] == 1 + bool attention_bias_broadcast_head; // true if shape[1] == 1 + + // Flash decoding fields (used when sequence_length == 1 and KV is split across threads). + // Partials buffer stores per-(batch, head, kv_chunk) intermediate results: + // [m_partial, l_partial, output_partial[head_size]] for each chunk. + float* flash_decoding_partials; // nullptr to disable flash decoding + int kv_chunk_count; // number of KV chunks = ceil(total_seqlen / kv_block_size) +}; + +/** + * @brief Flash Attention with quantized KV cache. + * + * Implements tiled attention with online softmax, processing KV in blocks + * to avoid materializing the full attention matrix. Supports causal masking + * and local window attention. + * + * @param args Pointer to argument structure. + * @param ThreadPool Optional thread pool for parallelization. + */ +void +MLASCALL +MlasFlashAttentionQuantizedKV( + MlasFlashAttentionQuantizedKVArgs* args, MLAS_THREADPOOL* ThreadPool ); diff --git a/onnxruntime/core/mlas/lib/flashattn_qkv.cpp b/onnxruntime/core/mlas/lib/flashattn_qkv.cpp new file mode 100644 index 0000000000000..364011fe26e26 --- /dev/null +++ b/onnxruntime/core/mlas/lib/flashattn_qkv.cpp @@ -0,0 +1,622 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + flashattn_qkv.cpp + +Abstract: + + Flash Attention kernel for quantized KV cache (INT8/INT4). + + Adapts the online-softmax tiled algorithm from flashattn.cpp to operate + on quantized K/V buffers using MlasQKGemm (for Q×K^T) and + MlasSVGemm with Beta=1.0 (for fused dequant + S×V accumulation). + + Supports causal masking and local window attention. + +--*/ + +#include +#include +#include +#include + +#include "mlasi.h" +#include "mlas_qkv_quant.h" + +void +MlasFlashAttentionQuantizedKVThreaded( + void* argptr, + std::ptrdiff_t thread_id +) +{ + const MlasFlashAttentionQuantizedKVArgs* args = + reinterpret_cast(argptr); + + const ptrdiff_t q_block_size = static_cast(args->q_block_size); + const ptrdiff_t kv_block_size = static_cast(args->kv_block_size); + const ptrdiff_t batch_size = static_cast(args->batch_size); + const ptrdiff_t num_heads = static_cast(args->num_heads); + const ptrdiff_t kv_num_heads = static_cast(args->kv_num_heads); + const ptrdiff_t sequence_length = static_cast(args->sequence_length); + const ptrdiff_t total_seqlen = static_cast(args->total_seqlen); + const ptrdiff_t head_size = static_cast(args->head_size); + const ptrdiff_t past_seqlen = static_cast(args->past_seqlen); + const ptrdiff_t local_window_size = static_cast(args->local_window_size); + const float scale = args->scale; + const MLAS_KV_QUANT_TYPE quant_type = args->quant_type; + + float* buffer = args->buffer; + const ptrdiff_t buffer_size_per_thread = static_cast(args->buffer_size_per_thread); + const ptrdiff_t thread_count = static_cast(args->thread_count); + + const size_t packed_row_bytes = MlasKVQuantPackedRowBytes(quant_type, static_cast(head_size)); + const size_t kv_num_heads_factor = static_cast(num_heads / kv_num_heads); + +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + auto&& mlas_platform = GetMlasPlatform(); +#endif + + // Total tasks: one per (batch, head, q_block) + const ptrdiff_t q_chunk_count = (sequence_length + q_block_size - 1) / q_block_size; + const ptrdiff_t total_task_count = batch_size * num_heads * q_chunk_count; + + ptrdiff_t task_start = 0; + ptrdiff_t task_end = 0; + ptrdiff_t quotient = total_task_count / thread_count; + ptrdiff_t remainder = total_task_count % thread_count; + if (thread_id < remainder) { + task_start = (quotient + 1) * thread_id; + task_end = task_start + quotient + 1; + } else { + task_start = quotient * thread_id + remainder; + task_end = task_start + quotient; + } + + for (ptrdiff_t task_index = task_start; task_index < task_end; ++task_index) { + ptrdiff_t batch_idx = task_index; + ptrdiff_t q_idx = (batch_idx % q_chunk_count) * q_block_size; + batch_idx /= q_chunk_count; + ptrdiff_t head_idx = batch_idx % num_heads; + batch_idx /= num_heads; + + // Per-thread buffer layout: + // l[q_block_size] - running sum for online softmax + // m[q_block_size] - running max for online softmax + // scores[q_block_size * kv_block_size] - QK scores (S) + // temp_output[q_block_size * head_size] - accumulated output + char* buffer_ptr = reinterpret_cast(buffer) + thread_id * buffer_size_per_thread; + float* l = reinterpret_cast(buffer_ptr); + float* m = l + q_block_size; + float* scores = m + q_block_size; + float* temp_output = scores + q_block_size * kv_block_size; + + // Initialize running state + for (ptrdiff_t t = 0; t < q_block_size; ++t) { + m[t] = std::numeric_limits::lowest(); + l[t] = 0.0f; + } + memset(temp_output, 0, static_cast(q_block_size * head_size) * sizeof(float)); + + const size_t row_size_q = static_cast(std::min(q_block_size, sequence_length - q_idx)); + + // Determine KV head index for GQA head sharing + const size_t kv_head_idx = static_cast(head_idx) / kv_num_heads_factor; + + // Pointers into quantized K/V caches + // K cache layout: [batch, kv_num_heads, seqlen_present, packed_head_bytes] + const size_t k_batch_head_offset = + (static_cast(batch_idx) * static_cast(kv_num_heads) + kv_head_idx) * + static_cast(args->seqlen_present_kv) * packed_row_bytes; + const uint8_t* k_cache_head = args->k_cache + k_batch_head_offset; + + const size_t v_batch_head_offset = + (static_cast(batch_idx) * static_cast(kv_num_heads) + kv_head_idx) * + static_cast(args->seqlen_present_kv) * packed_row_bytes; + const uint8_t* v_cache_head = args->v_cache + v_batch_head_offset; + + // K/V scale pointers + const float* head_k_scale = args->per_channel_k + ? args->k_scale + kv_head_idx * static_cast(head_size) + : args->k_scale; + const float* head_v_scale = args->per_channel_v + ? args->v_scale + kv_head_idx * static_cast(head_size) + : args->v_scale; + + // Q pointer: layout [batch, num_heads, seq, head_size] or packed + const float* q_ptr = args->query + + (static_cast(batch_idx) * static_cast(num_heads) + + static_cast(head_idx)) * static_cast(sequence_length) * static_cast(head_size) + + static_cast(q_idx) * static_cast(head_size); + + // Iterate over KV blocks + for (ptrdiff_t ir = 0; ir < total_seqlen; ir += kv_block_size) { + const size_t row_size_kv = static_cast(std::min(kv_block_size, total_seqlen - ir)); + + // Step 1: QK^T GEMM with quantized K block + // K cache at row offset ir: pointer arithmetic on packed rows + const uint8_t* k_block = k_cache_head + static_cast(ir) * packed_row_bytes; + + MlasQKGemm( + row_size_q, // M + row_size_kv, // N + static_cast(head_size), // K + scale, // Alpha + q_ptr, // A (FP32 query) + static_cast(head_size), // lda + k_block, // B (quantized K block) + quant_type, + head_k_scale, + scores, // C (output scores) + row_size_kv, // ldc + nullptr // no thread pool (already threaded) + ); + + // Step 1b: Apply attention bias (additive) if present + if (args->attention_bias != nullptr) { + const ptrdiff_t bias_seqlen_stride = + static_cast(args->attention_bias_seqlen_stride); + const ptrdiff_t bias_matrix_size = + static_cast(sequence_length) * bias_seqlen_stride; + ptrdiff_t bias_offset = 0; + if (!args->attention_bias_broadcast_batch) { + bias_offset += static_cast(batch_idx) * + static_cast(num_heads) * bias_matrix_size; + } + if (!args->attention_bias_broadcast_head) { + bias_offset += static_cast(head_idx) * bias_matrix_size; + } + // Add bias tile: bias[q_idx + irow, ir + jcol] + for (ptrdiff_t irow = 0; irow < static_cast(row_size_q); ++irow) { + const float* bias_row = args->attention_bias + bias_offset + + (q_idx + irow) * bias_seqlen_stride + ir; + float* s_row = scores + irow * static_cast(row_size_kv); + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + s_row[jcol] += bias_row[jcol]; + } + } + } + + // Step 2: Apply causal mask and Step 3: Online softmax update + for (ptrdiff_t irow = 0; irow < static_cast(row_size_q); ++irow) { + float* p = scores + irow * static_cast(row_size_kv); + const ptrdiff_t global_q_pos = past_seqlen + q_idx + irow; + const ptrdiff_t causal_limit = global_q_pos + 1; // can attend to positions [0, causal_limit) + + // Apply causal masking + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + ptrdiff_t kv_pos = ir + jcol; + if (kv_pos >= causal_limit) { + p[jcol] = std::numeric_limits::lowest(); + } + } + + // Apply local window masking if enabled + if (local_window_size >= 0) { + const ptrdiff_t window_start = + (causal_limit > local_window_size) ? (causal_limit - local_window_size) : 0; + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + ptrdiff_t kv_pos = ir + jcol; + if (kv_pos < window_start) { + p[jcol] = std::numeric_limits::lowest(); + } + } + } + + // Online softmax: find row max, update running max +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + float rowmax = mlas_platform.ReduceMaximumF32Kernel(p, row_size_kv); +#else + float rowmax = MlasReduceMaximumF32Kernel(p, row_size_kv); +#endif + + // If the entire row is masked (all scores are -inf), zero the scores + // so SVGemm contributes nothing and skip the softmax state update. + if (rowmax == std::numeric_limits::lowest()) { + memset(p, 0, row_size_kv * sizeof(float)); + continue; + } + + float m_old = m[irow]; + m[irow] = std::max(m[irow], rowmax); + float m_diff = m_old - m[irow]; // <= 0 + + // Compute exp(score - m_new) for each element + float negmax = -m[irow]; +#if defined(MLAS_TARGET_AMD64) + float rowsum = mlas_platform.ComputeSumExpF32Kernel(p, p, row_size_kv, &negmax); +#else + float rowsum = MlasComputeSumExpF32Kernel(p, p, row_size_kv, &negmax); +#endif + + // Rescale previous state + if (ir != 0) { + float exp_diff = std::exp(m_diff); + l[irow] = exp_diff * l[irow] + rowsum; + + // Rescale accumulated output + float* out_row = temp_output + irow * head_size; + for (ptrdiff_t icol = 0; icol < head_size; ++icol) { + out_row[icol] *= exp_diff; + } + } else { + l[irow] = rowsum; + } + } + + // Step 4: Accumulate O += S_exp * V_block using fused dequant+GEMM + const uint8_t* v_block = v_cache_head + static_cast(ir) * packed_row_bytes; + MlasSVGemm( + row_size_q, // M + static_cast(head_size), // N + row_size_kv, // K + scores, // A (exp softmax scores) + row_size_kv, // lda + v_block, // B (quantized V block) + quant_type, + head_v_scale, + temp_output, // C (accumulated output) + static_cast(head_size), // ldc + 1.0f, // Beta (accumulate) + nullptr // no thread pool (already threaded) + ); + } + + // Final: normalize output by l (softmax denominator) + // Output layout: [batch, sequence_length, num_heads, head_size] + float* output_row = args->output + + (static_cast(batch_idx) * static_cast(sequence_length) + + static_cast(q_idx)) * static_cast(num_heads) * static_cast(head_size) + + static_cast(head_idx) * static_cast(head_size); + const ptrdiff_t output_row_stride = num_heads * head_size; + + for (ptrdiff_t irow = 0; irow < static_cast(row_size_q); ++irow) { + float inv_l = (l[irow] > 0.0f) ? (1.0f / l[irow]) : 0.0f; + float* src = temp_output + irow * head_size; + for (ptrdiff_t icol = 0; icol < head_size; ++icol) { + output_row[icol] = src[icol] * inv_l; + } + output_row += output_row_stride; + } + } +} + +// +// Flash Decoding: Phase 1 - parallel partial attention over (batch, head, kv_chunk). +// Each task computes attention for one KV chunk and stores (m, l, partial_output) +// into the partials buffer. +// +void +MlasFlashDecodingQuantizedKVThreaded( + void* argptr, + std::ptrdiff_t thread_id +) +{ + const MlasFlashAttentionQuantizedKVArgs* args = + reinterpret_cast(argptr); + + const ptrdiff_t kv_block_size = static_cast(args->kv_block_size); + const ptrdiff_t batch_size = static_cast(args->batch_size); + const ptrdiff_t num_heads = static_cast(args->num_heads); + const ptrdiff_t kv_num_heads = static_cast(args->kv_num_heads); + const ptrdiff_t total_seqlen = static_cast(args->total_seqlen); + const ptrdiff_t head_size = static_cast(args->head_size); + const ptrdiff_t past_seqlen = static_cast(args->past_seqlen); + const ptrdiff_t local_window_size = static_cast(args->local_window_size); + const float scale = args->scale; + const MLAS_KV_QUANT_TYPE quant_type = args->quant_type; + + float* buffer = args->buffer; + const ptrdiff_t buffer_size_per_thread = static_cast(args->buffer_size_per_thread); + const ptrdiff_t thread_count = static_cast(args->thread_count); + + const size_t packed_row_bytes = MlasKVQuantPackedRowBytes(quant_type, static_cast(head_size)); + const size_t kv_num_heads_factor = static_cast(num_heads / kv_num_heads); + + const ptrdiff_t kv_chunk_count = static_cast(args->kv_chunk_count); + // Partials layout per entry: [m, l, output[head_size]] + const ptrdiff_t partial_stride = 2 + head_size; + +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + auto&& mlas_platform = GetMlasPlatform(); +#endif + + // Total tasks: (batch, head, kv_chunk) + const ptrdiff_t total_task_count = batch_size * num_heads * kv_chunk_count; + + ptrdiff_t task_start = 0; + ptrdiff_t task_end = 0; + ptrdiff_t quotient = total_task_count / thread_count; + ptrdiff_t remainder = total_task_count % thread_count; + if (thread_id < remainder) { + task_start = (quotient + 1) * thread_id; + task_end = task_start + quotient + 1; + } else { + task_start = quotient * thread_id + remainder; + task_end = task_start + quotient; + } + + for (ptrdiff_t task_index = task_start; task_index < task_end; ++task_index) { + // Decompose task_index into (batch_idx, head_idx, kv_chunk_idx) + ptrdiff_t tmp = task_index; + ptrdiff_t kv_chunk_idx = tmp % kv_chunk_count; + tmp /= kv_chunk_count; + ptrdiff_t head_idx = tmp % num_heads; + ptrdiff_t batch_idx = tmp / num_heads; + + // Per-thread scratch buffer: just scores[kv_block_size] + char* buffer_ptr = reinterpret_cast(buffer) + thread_id * buffer_size_per_thread; + float* scores = reinterpret_cast(buffer_ptr); + + // KV block range for this chunk + const ptrdiff_t ir = kv_chunk_idx * kv_block_size; + const size_t row_size_kv = static_cast(std::min(kv_block_size, total_seqlen - ir)); + + // Determine KV head index for GQA head sharing + const size_t kv_head_idx = static_cast(head_idx) / kv_num_heads_factor; + + // K/V cache pointers + const size_t k_batch_head_offset = + (static_cast(batch_idx) * static_cast(kv_num_heads) + kv_head_idx) * + static_cast(args->seqlen_present_kv) * packed_row_bytes; + const uint8_t* k_cache_head = args->k_cache + k_batch_head_offset; + + const size_t v_batch_head_offset = + (static_cast(batch_idx) * static_cast(kv_num_heads) + kv_head_idx) * + static_cast(args->seqlen_present_kv) * packed_row_bytes; + const uint8_t* v_cache_head = args->v_cache + v_batch_head_offset; + + // K/V scale pointers + const float* head_k_scale = args->per_channel_k + ? args->k_scale + kv_head_idx * static_cast(head_size) + : args->k_scale; + const float* head_v_scale = args->per_channel_v + ? args->v_scale + kv_head_idx * static_cast(head_size) + : args->v_scale; + + // Q pointer: layout [batch, num_heads, 1, head_size] (sequence_length=1) + const float* q_ptr = args->query + + (static_cast(batch_idx) * static_cast(num_heads) + + static_cast(head_idx)) * static_cast(head_size); + + // Step 1: QK^T GEMM for this KV chunk + const uint8_t* k_block = k_cache_head + static_cast(ir) * packed_row_bytes; + MlasQKGemm( + 1, // M (single query row) + row_size_kv, // N + static_cast(head_size), // K + scale, // Alpha + q_ptr, // A (FP32 query) + static_cast(head_size), // lda + k_block, // B (quantized K block) + quant_type, + head_k_scale, + scores, // C (output scores) + row_size_kv, // ldc + nullptr + ); + + // Step 1b: Apply attention bias if present + if (args->attention_bias != nullptr) { + const ptrdiff_t bias_seqlen_stride = + static_cast(args->attention_bias_seqlen_stride); + const ptrdiff_t bias_matrix_size = bias_seqlen_stride; // S=1 + ptrdiff_t bias_offset = 0; + if (!args->attention_bias_broadcast_batch) { + bias_offset += static_cast(batch_idx) * + static_cast(num_heads) * bias_matrix_size; + } + if (!args->attention_bias_broadcast_head) { + bias_offset += static_cast(head_idx) * bias_matrix_size; + } + const float* bias_row = args->attention_bias + bias_offset + ir; + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + scores[jcol] += bias_row[jcol]; + } + } + + // Step 2: Apply causal mask + const ptrdiff_t global_q_pos = past_seqlen; // sequence_length=1, q_idx=0 + const ptrdiff_t causal_limit = global_q_pos + 1; + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + ptrdiff_t kv_pos = ir + jcol; + if (kv_pos >= causal_limit) { + scores[jcol] = std::numeric_limits::lowest(); + } + } + + // Apply local window masking if enabled + if (local_window_size >= 0) { + const ptrdiff_t window_start = + (causal_limit > local_window_size) ? (causal_limit - local_window_size) : 0; + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + ptrdiff_t kv_pos = ir + jcol; + if (kv_pos < window_start) { + scores[jcol] = std::numeric_limits::lowest(); + } + } + } + + // Step 3: Compute local softmax statistics (m, l) and exp scores +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + float rowmax = mlas_platform.ReduceMaximumF32Kernel(scores, row_size_kv); +#else + float rowmax = MlasReduceMaximumF32Kernel(scores, row_size_kv); +#endif + + // Pointer to this task's partial in the partials buffer + const ptrdiff_t partial_index = + (batch_idx * num_heads + head_idx) * kv_chunk_count + kv_chunk_idx; + float* partial = args->flash_decoding_partials + partial_index * partial_stride; + float* partial_m = partial; + float* partial_l = partial + 1; + float* partial_output = partial + 2; + + if (rowmax == std::numeric_limits::lowest()) { + // Entire chunk is masked: store sentinel + *partial_m = std::numeric_limits::lowest(); + *partial_l = 0.0f; + memset(partial_output, 0, static_cast(head_size) * sizeof(float)); + continue; + } + + *partial_m = rowmax; + float negmax = -rowmax; +#if defined(MLAS_TARGET_AMD64) + float rowsum = mlas_platform.ComputeSumExpF32Kernel(scores, scores, row_size_kv, &negmax); +#else + float rowsum = MlasComputeSumExpF32Kernel(scores, scores, row_size_kv, &negmax); +#endif + *partial_l = rowsum; + + // Step 4: S_exp * V_block -> partial_output + const uint8_t* v_block = v_cache_head + static_cast(ir) * packed_row_bytes; + memset(partial_output, 0, static_cast(head_size) * sizeof(float)); + MlasSVGemm( + 1, // M + static_cast(head_size), // N + row_size_kv, // K + scores, // A (exp softmax scores) + row_size_kv, // lda + v_block, // B (quantized V block) + quant_type, + head_v_scale, + partial_output, // C (output for this chunk) + static_cast(head_size), // ldc + 0.0f, // Beta=0 (overwrite) + nullptr + ); + } +} + +// +// Flash Decoding: Phase 2 - reduce partials for each (batch, head) into final output. +// +void +MlasFlashDecodingReduceThreaded( + void* argptr, + std::ptrdiff_t thread_id +) +{ + const MlasFlashAttentionQuantizedKVArgs* args = + reinterpret_cast(argptr); + + const ptrdiff_t batch_size = static_cast(args->batch_size); + const ptrdiff_t num_heads = static_cast(args->num_heads); + const ptrdiff_t head_size = static_cast(args->head_size); + const ptrdiff_t kv_chunk_count = static_cast(args->kv_chunk_count); + const ptrdiff_t thread_count = static_cast(args->thread_count); + const ptrdiff_t partial_stride = 2 + head_size; + + // Total reduction tasks: one per (batch, head) + const ptrdiff_t total_task_count = batch_size * num_heads; + + ptrdiff_t task_start = 0; + ptrdiff_t task_end = 0; + ptrdiff_t quotient = total_task_count / thread_count; + ptrdiff_t remainder = total_task_count % thread_count; + if (thread_id < remainder) { + task_start = (quotient + 1) * thread_id; + task_end = task_start + quotient + 1; + } else { + task_start = quotient * thread_id + remainder; + task_end = task_start + quotient; + } + + for (ptrdiff_t task_index = task_start; task_index < task_end; ++task_index) { + ptrdiff_t head_idx = task_index % num_heads; + ptrdiff_t batch_idx = task_index / num_heads; + + // Pointer to this (batch, head)'s partials: kv_chunk_count entries + const float* partials_base = args->flash_decoding_partials + + task_index * kv_chunk_count * partial_stride; + + // Find global max across all chunks + float global_m = std::numeric_limits::lowest(); + for (ptrdiff_t c = 0; c < kv_chunk_count; ++c) { + float chunk_m = partials_base[c * partial_stride]; + global_m = std::max(global_m, chunk_m); + } + + // If all chunks are masked, output zeros + if (global_m == std::numeric_limits::lowest()) { + float* output_ptr = args->output + + static_cast(batch_idx) * static_cast(num_heads) * static_cast(head_size) + + static_cast(head_idx) * static_cast(head_size); + memset(output_ptr, 0, static_cast(head_size) * sizeof(float)); + continue; + } + + // Accumulate rescaled outputs and l values + float global_l = 0.0f; + // Use the output location directly for accumulation + // Output layout: [batch, sequence_length=1, num_heads, head_size] + float* output_ptr = args->output + + static_cast(batch_idx) * static_cast(num_heads) * static_cast(head_size) + + static_cast(head_idx) * static_cast(head_size); + memset(output_ptr, 0, static_cast(head_size) * sizeof(float)); + + for (ptrdiff_t c = 0; c < kv_chunk_count; ++c) { + const float* partial = partials_base + c * partial_stride; + float chunk_m = partial[0]; + float chunk_l = partial[1]; + const float* chunk_output = partial + 2; + + if (chunk_l <= 0.0f) { + continue; // masked chunk contributes nothing + } + + float rescale = std::exp(chunk_m - global_m); + global_l += rescale * chunk_l; + + // partial_output = S_exp * V where sum(S_exp) = l_c (unnormalized). + // Rescale by exp(m_c - global_m) to align all chunks to the same max. + for (ptrdiff_t i = 0; i < head_size; ++i) { + output_ptr[i] += rescale * chunk_output[i]; + } + } + + // output = sum_c(rescale_c * partial_output_c) / global_l + float inv_l = (global_l > 0.0f) ? (1.0f / global_l) : 0.0f; + for (ptrdiff_t i = 0; i < head_size; ++i) { + output_ptr[i] *= inv_l; + } + } +} + +void +MLASCALL +MlasFlashAttentionQuantizedKV( + MlasFlashAttentionQuantizedKVArgs* args, + MLAS_THREADPOOL* ThreadPool +) +{ + if (args->flash_decoding_partials != nullptr && args->sequence_length == 1) { + // Flash decoding: two-phase approach. + // Phase 1: parallel partial computation over (batch, head, kv_chunk). + MlasExecuteThreaded( + MlasFlashDecodingQuantizedKVThreaded, + static_cast(args), + static_cast(args->thread_count), + ThreadPool + ); + // Phase 2: reduce partials into final output (parallel over batch*heads). + MlasExecuteThreaded( + MlasFlashDecodingReduceThreaded, + static_cast(args), + static_cast(args->thread_count), + ThreadPool + ); + } else { + MlasExecuteThreaded( + MlasFlashAttentionQuantizedKVThreaded, + static_cast(args), + static_cast(args->thread_count), + ThreadPool + ); + } +} diff --git a/onnxruntime/core/mlas/lib/qkv_quant.cpp b/onnxruntime/core/mlas/lib/qkv_quant.cpp index 81fba6bc7cec4..c414324a0493f 100644 --- a/onnxruntime/core/mlas/lib/qkv_quant.cpp +++ b/onnxruntime/core/mlas/lib/qkv_quant.cpp @@ -356,14 +356,23 @@ MlasSVGemm( const float* Scales, float* C, size_t ldc, + float Beta, MLAS_THREADPOOL* ThreadPool) { if (M == 0 || N == 0) { return; } if (K == 0) { - for (size_t m = 0; m < M; ++m) { - std::memset(C + m * ldc, 0, N * sizeof(float)); + if (Beta == 0.0f) { + for (size_t m = 0; m < M; ++m) { + std::memset(C + m * ldc, 0, N * sizeof(float)); + } + } else if (Beta != 1.0f) { + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + C[m * ldc + n] *= Beta; + } + } } return; } @@ -373,7 +382,7 @@ MlasSVGemm( // const auto* Dispatch = GetMlasPlatform().KVQuantGemmDispatch; if (Dispatch != nullptr && Dispatch->SVGemm != nullptr) { - Dispatch->SVGemm(M, N, K, A, lda, B, QuantType, Scales, C, ldc); + Dispatch->SVGemm(M, N, K, A, lda, B, QuantType, Scales, C, ldc, Beta); return; } @@ -393,7 +402,13 @@ MlasSVGemm( const size_t m = static_cast(m_idx); const float* a_row = A + m * lda; float* c_row = C + m * ldc; - std::memset(c_row, 0, N * sizeof(float)); + if (Beta == 0.0f) { + std::memset(c_row, 0, N * sizeof(float)); + } else if (Beta != 1.0f) { + for (size_t n = 0; n < N; ++n) { + c_row[n] *= Beta; + } + } // Per-row scratch for one dequantized B row of length N. float b_dequant[1024]; diff --git a/onnxruntime/core/mlas/lib/qkv_quant_kernel.h b/onnxruntime/core/mlas/lib/qkv_quant_kernel.h index 5c4e93bb334c3..ebd990703472d 100644 --- a/onnxruntime/core/mlas/lib/qkv_quant_kernel.h +++ b/onnxruntime/core/mlas/lib/qkv_quant_kernel.h @@ -53,7 +53,7 @@ struct MLAS_KV_QUANT_GEMM_DISPATCH { QKGemm_Fn* QKGemm = nullptr; /** - * S*V GEMM kernel: C[M,N] = A[M,K] * B[K,N] + * S*V GEMM kernel: C[M,N] = Beta * C[M,N] + A[M,K] * B[K,N] * * B is quantized (INT8 or INT4), logically [K, N] in packed row-major. */ @@ -67,7 +67,8 @@ struct MLAS_KV_QUANT_GEMM_DISPATCH { MLAS_KV_QUANT_TYPE QuantType, const float* Scales, float* C, - size_t ldc + size_t ldc, + float Beta ); SVGemm_Fn* SVGemm = nullptr; diff --git a/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx2.cpp index 8bec2d350afa5..d7bb01deec2ed 100644 --- a/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx2.cpp @@ -268,7 +268,7 @@ QKGemm_Avx2( } // -// SVGemm: C[M,N] = A[M,K] * B[K,N] +// SVGemm: C[M,N] = Beta * C[M,N] + A[M,K] * B[K,N] // B is [K,N] packed row-major. // // Fused approach: dequantize each B[k,:] element directly into the FMA with @@ -285,7 +285,8 @@ SVGemm_Avx2( MLAS_KV_QUANT_TYPE QuantType, const float* Scales, float* C, - size_t ldc) + size_t ldc, + float Beta) { const size_t row_bytes = MlasKVQuantPackedRowBytes(QuantType, N); const auto* B_bytes = static_cast(B); @@ -298,13 +299,26 @@ SVGemm_Avx2( float* c_row = C + m * ldc; const float* a_row = A + m * lda; - // Zero output - size_t n = 0; - for (; n < vec_end_n; n += 8) { - _mm256_storeu_ps(c_row + n, _mm256_setzero_ps()); - } - for (; n < N; ++n) { - c_row[n] = 0.0f; + // Initialize output + if (Beta == 0.0f) { + size_t n = 0; + for (; n < vec_end_n; n += 8) { + _mm256_storeu_ps(c_row + n, _mm256_setzero_ps()); + } + for (; n < N; ++n) { + c_row[n] = 0.0f; + } + } else if (Beta != 1.0f) { + __m256 beta_vec = _mm256_broadcast_ss(&Beta); + size_t n = 0; + for (; n < vec_end_n; n += 8) { + __m256 c_vec = _mm256_loadu_ps(c_row + n); + c_vec = _mm256_mul_ps(c_vec, beta_vec); + _mm256_storeu_ps(c_row + n, c_vec); + } + for (; n < N; ++n) { + c_row[n] *= Beta; + } } if (!int4) { @@ -315,7 +329,7 @@ SVGemm_Avx2( const float a_val = a_row[k]; __m256 a_broadcast = _mm256_broadcast_ss(&a_val); - n = 0; + size_t n = 0; for (; n < vec_end_n; n += 8) { __m128i raw = _mm_loadl_epi64(reinterpret_cast(b_row + n)); __m256i i32 = _mm256_cvtepi8_epi32(raw); @@ -331,35 +345,59 @@ SVGemm_Avx2( } } } else { - // Per-tensor: accumulate unscaled dot products, then scale the output row once. - for (size_t k = 0; k < K; ++k) { - const int8_t* b_row = reinterpret_cast(B_bytes + k * row_bytes); - const float a_val = a_row[k]; - __m256 a_broadcast = _mm256_broadcast_ss(&a_val); + // Per-tensor: when Beta==0, accumulate unscaled then scale once at end. + // When Beta!=0, C already has scaled values so fold scale into a_val. + if (Beta == 0.0f) { + for (size_t k = 0; k < K; ++k) { + const int8_t* b_row = reinterpret_cast(B_bytes + k * row_bytes); + const float a_val = a_row[k]; + __m256 a_broadcast = _mm256_broadcast_ss(&a_val); + + size_t n = 0; + for (; n < vec_end_n; n += 8) { + __m128i raw = _mm_loadl_epi64(reinterpret_cast(b_row + n)); + __m256i i32 = _mm256_cvtepi8_epi32(raw); + __m256 bf = _mm256_cvtepi32_ps(i32); + __m256 c_vec = _mm256_loadu_ps(c_row + n); + c_vec = _mm256_fmadd_ps(a_broadcast, bf, c_vec); + _mm256_storeu_ps(c_row + n, c_vec); + } + for (; n < N; ++n) { + c_row[n] += a_val * static_cast(b_row[n]); + } + } - n = 0; + __m256 scale_vec = _mm256_broadcast_ss(Scales); + size_t n = 0; for (; n < vec_end_n; n += 8) { - __m128i raw = _mm_loadl_epi64(reinterpret_cast(b_row + n)); - __m256i i32 = _mm256_cvtepi8_epi32(raw); - __m256 bf = _mm256_cvtepi32_ps(i32); __m256 c_vec = _mm256_loadu_ps(c_row + n); - c_vec = _mm256_fmadd_ps(a_broadcast, bf, c_vec); + c_vec = _mm256_mul_ps(c_vec, scale_vec); _mm256_storeu_ps(c_row + n, c_vec); } for (; n < N; ++n) { - c_row[n] += a_val * static_cast(b_row[n]); + c_row[n] *= Scales[0]; + } + } else { + // Beta!=0: fold scale into a_val to avoid separate pass + const float tensor_scale = Scales[0]; + for (size_t k = 0; k < K; ++k) { + const int8_t* b_row = reinterpret_cast(B_bytes + k * row_bytes); + const float a_val = a_row[k] * tensor_scale; + __m256 a_broadcast = _mm256_broadcast_ss(&a_val); + + size_t n = 0; + for (; n < vec_end_n; n += 8) { + __m128i raw = _mm_loadl_epi64(reinterpret_cast(b_row + n)); + __m256i i32 = _mm256_cvtepi8_epi32(raw); + __m256 bf = _mm256_cvtepi32_ps(i32); + __m256 c_vec = _mm256_loadu_ps(c_row + n); + c_vec = _mm256_fmadd_ps(a_broadcast, bf, c_vec); + _mm256_storeu_ps(c_row + n, c_vec); + } + for (; n < N; ++n) { + c_row[n] += a_val * static_cast(b_row[n]); + } } - } - - __m256 scale_vec = _mm256_broadcast_ss(Scales); - n = 0; - for (; n < vec_end_n; n += 8) { - __m256 c_vec = _mm256_loadu_ps(c_row + n); - c_vec = _mm256_mul_ps(c_vec, scale_vec); - _mm256_storeu_ps(c_row + n, c_vec); - } - for (; n < N; ++n) { - c_row[n] *= Scales[0]; } } } else { @@ -369,7 +407,7 @@ SVGemm_Avx2( const float a_val = a_row[k]; __m256 a_broadcast = _mm256_broadcast_ss(&a_val); - n = 0; + size_t n = 0; for (; n < vec_end_n; n += 8) { __m256 bf = DequantInt4x8(b_row, n, per_channel, Scales); __m256 c_vec = _mm256_loadu_ps(c_row + n); diff --git a/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp index fa5aff0165897..16e82f19c3711 100644 --- a/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp +++ b/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp @@ -512,7 +512,7 @@ QKGemm_Avx512Vnni( } // ============================================================================ -// SVGemm: C[M,N] = A[M,K] * B[K,N] +// SVGemm: C[M,N] = Beta * C[M,N] + A[M,K] * B[K,N] // B is [K,N] packed row-major. // // For SVGemm, A is attention weights (FP32) and B is V-cache (quantized). @@ -532,7 +532,8 @@ SVGemm_Avx512Vnni( MLAS_KV_QUANT_TYPE QuantType, const float* Scales, float* C, - size_t ldc) + size_t ldc, + float Beta) { const size_t row_bytes = MlasKVQuantPackedRowBytes(QuantType, N); const auto* B_bytes = static_cast(B); @@ -545,13 +546,26 @@ SVGemm_Avx512Vnni( float* c_row = C + m * ldc; const float* a_row = A + m * lda; - // Zero output using 512-bit stores - size_t n = 0; - for (; n < vec_end_n; n += 16) { - _mm512_storeu_ps(c_row + n, _mm512_setzero_ps()); - } - for (; n < N; ++n) { - c_row[n] = 0.0f; + // Initialize output + if (Beta == 0.0f) { + size_t n = 0; + for (; n < vec_end_n; n += 16) { + _mm512_storeu_ps(c_row + n, _mm512_setzero_ps()); + } + for (; n < N; ++n) { + c_row[n] = 0.0f; + } + } else if (Beta != 1.0f) { + __m512 beta_vec = _mm512_set1_ps(Beta); + size_t n = 0; + for (; n < vec_end_n; n += 16) { + __m512 c_vec = _mm512_loadu_ps(c_row + n); + c_vec = _mm512_mul_ps(c_vec, beta_vec); + _mm512_storeu_ps(c_row + n, c_vec); + } + for (; n < N; ++n) { + c_row[n] *= Beta; + } } if (!int4) { @@ -562,7 +576,7 @@ SVGemm_Avx512Vnni( const float a_val = a_row[k]; __m512 a_broadcast = _mm512_set1_ps(a_val); - n = 0; + size_t n = 0; for (; n < vec_end_n; n += 16) { __m128i raw = _mm_loadu_si128(reinterpret_cast(b_row + n)); __m512i i32 = _mm512_cvtepi8_epi32(raw); @@ -578,35 +592,59 @@ SVGemm_Avx512Vnni( } } } else { - // Per-tensor: accumulate unscaled dot products, then scale the output row once. - for (size_t k = 0; k < K; ++k) { - const int8_t* b_row = reinterpret_cast(B_bytes + k * row_bytes); - const float a_val = a_row[k]; - __m512 a_broadcast = _mm512_set1_ps(a_val); + // Per-tensor: when Beta==0, accumulate unscaled then scale once at end. + // When Beta!=0, fold scale into a_val. + if (Beta == 0.0f) { + for (size_t k = 0; k < K; ++k) { + const int8_t* b_row = reinterpret_cast(B_bytes + k * row_bytes); + const float a_val = a_row[k]; + __m512 a_broadcast = _mm512_set1_ps(a_val); + + size_t n = 0; + for (; n < vec_end_n; n += 16) { + __m128i raw = _mm_loadu_si128(reinterpret_cast(b_row + n)); + __m512i i32 = _mm512_cvtepi8_epi32(raw); + __m512 bf = _mm512_cvtepi32_ps(i32); + __m512 c_vec = _mm512_loadu_ps(c_row + n); + c_vec = _mm512_fmadd_ps(a_broadcast, bf, c_vec); + _mm512_storeu_ps(c_row + n, c_vec); + } + for (; n < N; ++n) { + c_row[n] += a_val * static_cast(b_row[n]); + } + } - n = 0; + __m512 scale_vec = _mm512_set1_ps(Scales[0]); + size_t n = 0; for (; n < vec_end_n; n += 16) { - __m128i raw = _mm_loadu_si128(reinterpret_cast(b_row + n)); - __m512i i32 = _mm512_cvtepi8_epi32(raw); - __m512 bf = _mm512_cvtepi32_ps(i32); __m512 c_vec = _mm512_loadu_ps(c_row + n); - c_vec = _mm512_fmadd_ps(a_broadcast, bf, c_vec); + c_vec = _mm512_mul_ps(c_vec, scale_vec); _mm512_storeu_ps(c_row + n, c_vec); } for (; n < N; ++n) { - c_row[n] += a_val * static_cast(b_row[n]); + c_row[n] *= Scales[0]; + } + } else { + // Beta!=0: fold scale into a_val + const float tensor_scale = Scales[0]; + for (size_t k = 0; k < K; ++k) { + const int8_t* b_row = reinterpret_cast(B_bytes + k * row_bytes); + const float a_val = a_row[k] * tensor_scale; + __m512 a_broadcast = _mm512_set1_ps(a_val); + + size_t n = 0; + for (; n < vec_end_n; n += 16) { + __m128i raw = _mm_loadu_si128(reinterpret_cast(b_row + n)); + __m512i i32 = _mm512_cvtepi8_epi32(raw); + __m512 bf = _mm512_cvtepi32_ps(i32); + __m512 c_vec = _mm512_loadu_ps(c_row + n); + c_vec = _mm512_fmadd_ps(a_broadcast, bf, c_vec); + _mm512_storeu_ps(c_row + n, c_vec); + } + for (; n < N; ++n) { + c_row[n] += a_val * static_cast(b_row[n]); + } } - } - - __m512 scale_vec = _mm512_set1_ps(Scales[0]); - n = 0; - for (; n < vec_end_n; n += 16) { - __m512 c_vec = _mm512_loadu_ps(c_row + n); - c_vec = _mm512_mul_ps(c_vec, scale_vec); - _mm512_storeu_ps(c_row + n, c_vec); - } - for (; n < N; ++n) { - c_row[n] *= Scales[0]; } } } else { @@ -616,7 +654,7 @@ SVGemm_Avx512Vnni( const float a_val = a_row[k]; __m512 a_broadcast = _mm512_set1_ps(a_val); - n = 0; + size_t n = 0; for (; n < vec_end_n; n += 16) { __m512 bf = DequantInt4x16_Avx512(b_row, n, per_channel, Scales); __m512 c_vec = _mm512_loadu_ps(c_row + n); diff --git a/onnxruntime/core/mlas/lib/qkv_quant_kernel_neon.cpp b/onnxruntime/core/mlas/lib/qkv_quant_kernel_neon.cpp index 1aabbd8ca39cb..070b1243955cd 100644 --- a/onnxruntime/core/mlas/lib/qkv_quant_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/qkv_quant_kernel_neon.cpp @@ -244,7 +244,7 @@ QKGemm_Neon( } // -// SVGemm: C[M,N] = A[M,K] * B[K,N] +// SVGemm: C[M,N] = Beta * C[M,N] + A[M,K] * B[K,N] // void SVGemm_Neon( @@ -257,7 +257,8 @@ SVGemm_Neon( MLAS_KV_QUANT_TYPE QuantType, const float* Scales, float* C, - size_t ldc) + size_t ldc, + float Beta) { const size_t row_bytes = MlasKVQuantPackedRowBytes(QuantType, N); const auto* B_bytes = static_cast(B); @@ -277,23 +278,40 @@ SVGemm_Neon( float* c_row = C + m * ldc; const float* a_row = A + m * lda; - // Zero output - size_t n = 0; - for (; n < vec_end_n; n += 4) { - vst1q_f32(c_row + n, vdupq_n_f32(0.0f)); - } - for (; n < N; ++n) { - c_row[n] = 0.0f; + // Initialize output + if (Beta == 0.0f) { + size_t n = 0; + for (; n < vec_end_n; n += 4) { + vst1q_f32(c_row + n, vdupq_n_f32(0.0f)); + } + for (; n < N; ++n) { + c_row[n] = 0.0f; + } + } else if (Beta != 1.0f) { + float32x4_t beta_vec = vdupq_n_f32(Beta); + size_t n = 0; + for (; n < vec_end_n; n += 4) { + float32x4_t c_vec = vld1q_f32(c_row + n); + c_vec = vmulq_f32(c_vec, beta_vec); + vst1q_f32(c_row + n, c_vec); + } + for (; n < N; ++n) { + c_row[n] *= Beta; + } } + // When Beta != 0 and per-tensor, we must apply scale inline during + // dequantization (can't defer scaling since C already has scaled values). + const bool apply_scale_inline = per_channel || (Beta != 0.0f); + for (size_t k = 0; k < K; ++k) { const uint8_t* b_row_packed = B_bytes + k * row_bytes; - DequantRow_Neon(b_row_packed, b_buf, N, QuantType, Scales, per_channel); + DequantRow_Neon(b_row_packed, b_buf, N, QuantType, Scales, apply_scale_inline); const float a_val = a_row[k]; float32x4_t a_broadcast = vdupq_n_f32(a_val); - n = 0; + size_t n = 0; for (; n < vec_end_n; n += 4) { float32x4_t c_vec = vld1q_f32(c_row + n); float32x4_t b_vec = vld1q_f32(b_buf + n); @@ -305,9 +323,9 @@ SVGemm_Neon( } } - if (!per_channel) { + if (!apply_scale_inline) { const float32x4_t scale_vec = vdupq_n_f32(Scales[0]); - n = 0; + size_t n = 0; for (; n < vec_end_n; n += 4) { float32x4_t c_vec = vld1q_f32(c_row + n); c_vec = vmulq_f32(c_vec, scale_vec); diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index d9b5760848678..6fd53220a0180 100755 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -12,6 +12,7 @@ #include "core/platform/env_var_utils.h" #include "core/providers/cuda/cuda_execution_provider.h" #include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cuda_nhwc_ops.h" #include "core/providers/cuda/cuda_allocator.h" #include "core/providers/cuda/cuda_fwd.h" #include "core/providers/cuda/gpu_data_transfer.h" @@ -383,23 +384,7 @@ std::optional CUDAExecutionProvider::ShouldConvertDataLayoutForOp([[maybe_ return std::nullopt; } - // TODO(mtavenrath) generate list from registered kernels using nhwc domain - static const std::unordered_set cuda_nhwc_onnx_ops{ - "BatchNormalization", - "Conv", - "ConvTranspose", - "GlobalMaxPool", - "MaxPool", - "GlobalAveragePool", - "AveragePool", - "GridSample", - "DepthToSpace", - "SpaceToDepth", - "LRN", - }; - - return (node_domain == kOnnxDomain && cuda_nhwc_onnx_ops.find(node_op_type) != cuda_nhwc_onnx_ops.end()) || - (node_domain == kMSDomain && node_op_type == "GridSample"); + return cuda::IsNhwcEligible(node_domain, node_op_type); #else // defined(ENABLE_CUDA_NHWC_OPS) ORT_UNUSED_PARAMETER(node_domain); diff --git a/onnxruntime/core/providers/cuda/cuda_nhwc_ops.h b/onnxruntime/core/providers/cuda/cuda_nhwc_ops.h new file mode 100644 index 0000000000000..e4fe232e2362e --- /dev/null +++ b/onnxruntime/core/providers/cuda/cuda_nhwc_ops.h @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +namespace onnxruntime { +namespace cuda { + +// Unified allowlist of ops eligible for NHWC layout conversion in both the +// bundled CUDA EP and the CUDA plugin EP. Maintaining a single source of truth +// prevents silent divergence between the two implementations. + +inline bool IsNhwcEligibleOnnxOp(std::string_view op_type) { + // Alphabetical order for easy maintenance. + return op_type == "AveragePool" || + op_type == "BatchNormalization" || + op_type == "Conv" || + op_type == "ConvTranspose" || + op_type == "DepthToSpace" || + op_type == "GlobalAveragePool" || + op_type == "GlobalMaxPool" || + op_type == "GridSample" || + op_type == "LRN" || + op_type == "MaxPool" || + op_type == "SpaceToDepth"; +} + +inline bool IsNhwcEligibleMsOp(std::string_view op_type) { + return op_type == "GridSample"; +} + +// Returns true if the given (domain, op_type) pair is eligible for NHWC +// conversion. |domain| should be kOnnxDomain ("") or kMSDomain +// ("com.microsoft"). +inline bool IsNhwcEligible(std::string_view domain, std::string_view op_type) { + if (domain.empty()) { + return IsNhwcEligibleOnnxOp(op_type); + } + if (domain == "com.microsoft") { + return IsNhwcEligibleMsOp(op_type); + } + return false; +} + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc index 6a1a1b8698b4d..1212f8ed77170 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc @@ -20,6 +20,7 @@ #include #include "core/graph/constants.h" +#include "core/providers/cuda/cuda_nhwc_ops.h" namespace onnxruntime { namespace cuda_plugin { @@ -214,7 +215,7 @@ OrtStatus* ORT_API_CALL CudaEp::GetCapabilityImpl( tentative_nodes.reserve(all_nodes.size()); for (const auto& node : all_nodes) { - std::string ep_name = node.GetEpName(); + const std::string& ep_name = node.GetEpName(); if (!ep_name.empty()) { if (ep_name == ep->name_) { candidate_nodes.push_back(node); @@ -229,6 +230,18 @@ OrtStatus* ORT_API_CALL CudaEp::GetCapabilityImpl( if (kernel_def != nullptr) { candidate_nodes.push_back(node); tentative_nodes.push_back(node); + } else { + // Emit a diagnostic when an NHWC-domain node has no matching kernel. + // This helps identify gaps between the layout conversion allowlist and + // the actually-registered NHWC kernels in the plugin build. + const std::string& node_domain = node.GetDomain(); + if (node_domain == kMSInternalNHWCDomain) { + ORT_CXX_LOGF(Ort::Logger(&ep->logger_), ORT_LOGGING_LEVEL_WARNING, + "NHWC kernel miss: op=%s domain=%s version=%d node=%s - " + "no matching kernel registered in the CUDA plugin EP.", + node.GetOperatorType().c_str(), node_domain.c_str(), + node.GetSinceVersion(), node.GetName().c_str()); + } } } @@ -308,36 +321,11 @@ OrtStatus* ORT_API_CALL CudaEp::ShouldConvertDataLayoutForOpImpl( return nullptr; } - // ONNX domain ops that have NHWC kernel registrations. - static const std::unordered_set cuda_nhwc_onnx_ops{ - "BatchNormalization", - "Conv", - "ConvTranspose", - "GlobalMaxPool", - "MaxPool", - "GlobalAveragePool", - "AveragePool", - "GridSample", - "DepthToSpace", - "SpaceToDepth", - "LRN", - }; - - // Check ONNX domain (empty string) or MS domain (com.microsoft) - bool is_onnx_domain = (safe_domain[0] == '\0'); - bool is_ms_domain = (std::strcmp(safe_domain, "com.microsoft") == 0); - - if (is_onnx_domain && cuda_nhwc_onnx_ops.count(safe_op_type) > 0) { + if (cuda::IsNhwcEligible(safe_domain, safe_op_type)) { *should_convert = 1; // Convert - return nullptr; - } - - if (is_ms_domain && std::strcmp(safe_op_type, "GridSample") == 0) { - *should_convert = 1; // Convert - return nullptr; + } else { + *should_convert = 0; // Explicitly decline conversion for unsupported NHWC ops. } - - *should_convert = 0; // Explicitly decline conversion for unsupported NHWC ops. return nullptr; #endif } diff --git a/onnxruntime/core/providers/partitioning_utils.cc b/onnxruntime/core/providers/partitioning_utils.cc index f368537d655b2..a1f7e90d7c089 100644 --- a/onnxruntime/core/providers/partitioning_utils.cc +++ b/onnxruntime/core/providers/partitioning_utils.cc @@ -307,19 +307,33 @@ std::unique_ptr MakeComputeCapability(const GraphViewer& grap for (const Node* node : group) { sub_graph->nodes.push_back(node->Index()); - for (const auto* input : node->InputDefs()) { - if (!input->Exists()) { - // skip the placeholder inputs - continue; - } - // if the node input was not produced by this subgraph, add it to the subgraph inputs. - if (!Contains(node_outputs, input)) { - if (!Contains(subgraph_inputs, input)) { - subgraph_inputs.insert(input); - ordered_subgraph_inputs.push_back(input); + // Collect boundary inputs from a def container, skipping placeholders and + // values already produced inside the partition; preserves first-seen order. + auto collect_boundary_inputs = [&](const auto& defs) { + for (const auto* input : defs) { + if (!input->Exists()) { + continue; + } + if (!Contains(node_outputs, input)) { + if (!Contains(subgraph_inputs, input)) { + subgraph_inputs.insert(input); + ordered_subgraph_inputs.push_back(input); + } } } - } + }; + + collect_boundary_inputs(node->InputDefs()); + + // Region-bearing ops (Loop/If/Scan) reference outer-scope SSA values via + // ImplicitInputDefs rather than InputDefs. When an EP claims the whole + // control-flow op, those implicit captures must also be in MetaDef::inputs + // so FinalizeFuseSubGraph can rewire the outer-scope edges onto the fused + // node's InputDefs. Without this, plugin EPs that fuse Loop/If/Scan lose + // the captures at the fused-node boundary and cannot resolve them at + // Compute time. Running this after the explicit loop preserves + // explicit-operand index ordering in meta_def->inputs. + collect_boundary_inputs(node->ImplicitInputDefs()); const auto& output_defs = node->OutputDefs(); for (const auto* output_def : output_defs) { diff --git a/onnxruntime/core/providers/webgpu/program_manager.cc b/onnxruntime/core/providers/webgpu/program_manager.cc index e4376476a885d..136e7d503f59f 100644 --- a/onnxruntime/core/providers/webgpu/program_manager.cc +++ b/onnxruntime/core/providers/webgpu/program_manager.cc @@ -2,9 +2,12 @@ // Licensed under the MIT License. #include +#include +#include #include "core/common/common.h" #include "core/common/logging/logging.h" +#include "core/platform/env_var.h" #include "core/providers/webgpu/program_manager.h" #include "core/providers/webgpu/shader_helper.h" @@ -18,6 +21,17 @@ ProgramArtifact::ProgramArtifact(const ProgramBase& program, wgpu::ComputePipeli compute_pipeline{compute_pipeline}, shape_uniform_ranks{shape_uniform_ranks} {} +ProgramManager::ProgramManager(WebGpuContext& webgpu_context) + : webgpu_context_{webgpu_context} { + if (std::string dump_file_path = onnxruntime::detail::GetEnvironmentVar("ORT_WEBGPU_EP_SHADER_DUMP_FILE"); + !dump_file_path.empty()) { + auto dump_file = std::make_shared(dump_file_path.c_str(), std::ios::app); + shader_dump_fn_ = [dump_file = std::move(dump_file)](std::string_view shader_content) { + *dump_file << shader_content << "\n"; + }; + } +} + Status ProgramManager::NormalizeDispatchGroupSize(uint32_t& x, uint32_t& y, uint32_t& z) const { ORT_RETURN_IF(x == 0 || y == 0 || z == 0, "Invalid dispatch group size (", x, ", ", y, ", ", z, ")"); @@ -66,9 +80,7 @@ Status ProgramManager::Build(const ProgramBase& program, const ProgramMetadata& program_metadata, const std::span inputs_segments, const std::span outputs_segments, -#ifndef NDEBUG // if debug build const std::string& program_key, -#endif uint32_t normalized_dispatch_x, uint32_t normalized_dispatch_y, uint32_t normalized_dispatch_z, @@ -100,17 +112,24 @@ Status ProgramManager::Build(const ProgramBase& program, std::string code; ORT_RETURN_IF_ERROR(shader_helper.GenerateSourceCode(code, shape_uniform_ranks)); - LOGS_DEFAULT(VERBOSE) << "\n=== WebGPU Shader code [" << program.Name() -#ifndef NDEBUG // if debug build - << ", Key=\"" << program_key << "\"" -#endif - << "] Start ===\n\n" - << code - << "\n=== WebGPU Shader code [" << program.Name() -#ifndef NDEBUG // if debug build - << ", Key=\"" << program_key << "\"" -#endif - << "] End ===\n"; + // Dump shader code, if requested. It is dumped to `shader_dump_fn_` if set or VERBOSE logging otherwise. + { + const auto shader_content = [&program, &program_key, &code]() { + return MakeString("\n=== WebGPU Shader code [", program.Name(), + ", Key=\"", program_key, "\"", + "] Start ===\n\n", + code, + "\n=== WebGPU Shader code [", program.Name(), + ", Key=\"", program_key, "\"", + "] End ===\n"); + }; + + if (shader_dump_fn_) { + shader_dump_fn_(shader_content()); + } else { + LOGS_DEFAULT(VERBOSE) << shader_content(); + } + } wgpu::ShaderSourceWGSL wgsl_source{}; wgsl_source.code = code.c_str(); diff --git a/onnxruntime/core/providers/webgpu/program_manager.h b/onnxruntime/core/providers/webgpu/program_manager.h index 5c4f76d0b4168..afdffe94ea30a 100644 --- a/onnxruntime/core/providers/webgpu/program_manager.h +++ b/onnxruntime/core/providers/webgpu/program_manager.h @@ -3,8 +3,10 @@ #pragma once +#include #include #include +#include #include #include "core/providers/webgpu/webgpu_external_header.h" @@ -36,7 +38,7 @@ class ProgramArtifact { class ProgramManager { public: - ProgramManager(WebGpuContext& webgpu_context) : webgpu_context_(webgpu_context) {} + ProgramManager(WebGpuContext& webgpu_context); Status NormalizeDispatchGroupSize(uint32_t& x, uint32_t& y, uint32_t& z) const; Status CalculateSegmentsForInputsAndOutputs(const ProgramBase& program, std::vector& inputs_segments, std::vector& outputs_segments) const; @@ -45,9 +47,7 @@ class ProgramManager { const ProgramMetadata& metadata, const std::span inputs_segments, const std::span outputs_segments, -#ifndef NDEBUG // if debug build const std::string& program_key, -#endif uint32_t normalized_dispatch_x, uint32_t normalized_dispatch_y, uint32_t normalized_dispatch_z, @@ -59,6 +59,8 @@ class ProgramManager { private: std::unordered_map programs_; WebGpuContext& webgpu_context_; + + std::function shader_dump_fn_; }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index ada9a2e8ab692..c7750198ceebc 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -303,9 +303,7 @@ Status WebGpuContext::Run(ComputeContextBase& context, const ProgramBase& progra metadata, inputs_segments, outputs_segments, -#ifndef NDEBUG // if debug build key, -#endif x, y, z, diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py index ac8dbfe8f8348..d7c01c2ab8a2d 100644 --- a/onnxruntime/python/tools/quantization/onnx_quantizer.py +++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py @@ -1072,7 +1072,7 @@ def quantize_weight_per_channel( scale_name, zp_name, QuantizedValueType.Initializer, - None, + channel_axis, ) self.quantized_value_map[weight_name] = quantized_value @@ -1097,8 +1097,9 @@ def _dequantize_value(self, value_name): if self.model.model.producer_name != "onnx-quantizer" or ( self.model.model.producer_name == "onnx-quantizer" and scale_init is not None ): - # axis is not specified so scale_init must be a scalar. - assert scale_init is None or onnx.numpy_helper.to_array(scale_init).size == 1 + # Per-tensor (axis=None) requires a scalar scale. + if quantized_value.axis is None: + assert scale_init is None or onnx.numpy_helper.to_array(scale_init).size == 1 dqlinear_name = value_name + "_DequantizeLinear" dqlinear_node = self.model.find_node_by_name(dqlinear_name, self.new_nodes, self.model.graph()) @@ -1109,7 +1110,11 @@ def _dequantize_value(self, value_name): quantized_value.zp_name, ] dequantize_node = onnx.helper.make_node( - "DequantizeLinear", dqlinear_inputs, [value_name], dqlinear_name + "DequantizeLinear", + dqlinear_inputs, + [value_name], + dqlinear_name, + axis=quantized_value.axis, ) return dequantize_node else: diff --git a/onnxruntime/test/framework/tensorutils_test.cc b/onnxruntime/test/framework/tensorutils_test.cc index 06cc3ea6ad8d2..fa34e9722b66b 100644 --- a/onnxruntime/test/framework/tensorutils_test.cc +++ b/onnxruntime/test/framework/tensorutils_test.cc @@ -1151,8 +1151,137 @@ TEST_F(PathValidationTest, SparseTensorExternalDataPathTraversalBlocked_ZeroNNZ) EXPECT_THAT(status.ErrorMessage(), ::testing::HasSubstr("escapes")); } +// Defense-in-depth: SparseTensorProtoToDenseTensorProto must reject ORT's in-memory address +// marker on sparse sub-tensors unconditionally. The trusted .ort loader is required to +// materialize sparse sub-tensors as inline raw_data so they never carry markers. Without this +// self-check, a caller that bypasses the Graph-ctor chokepoint would dereference an +// attacker-controlled address. +TEST(SparseTensorProtoToDenseTensorProtoMarkerTest, RejectsInMemoryMarkerOnValuesByDefault) { + ONNX_NAMESPACE::SparseTensorProto sparse; + sparse.add_dims(4); + + auto* values = sparse.mutable_values(); + values->set_name("sparse_marker_values"); + values->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + values->add_dims(2); + values->set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); + auto* loc = values->add_external_data(); + loc->set_key("location"); + loc->set_value(ToUTF8String(onnxruntime::utils::kTensorProtoLittleEndianMemoryAddressTag)); + auto* off = values->add_external_data(); + off->set_key("offset"); + off->set_value("0"); + auto* len = values->add_external_data(); + len->set_key("length"); + len->set_value(std::to_string(2 * sizeof(float))); + + auto* indices = sparse.mutable_indices(); + indices->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + indices->add_dims(2); + indices->add_int64_data(0); + indices->add_int64_data(1); + + ONNX_NAMESPACE::TensorProto dense; + Status status = utils::SparseTensorProtoToDenseTensorProto(sparse, std::filesystem::path{}, dense); + ASSERT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), ::testing::HasSubstr("in-memory address marker")); +} + +TEST(SparseTensorProtoToDenseTensorProtoMarkerTest, RejectsInMemoryMarkerOnIndicesByDefault) { + ONNX_NAMESPACE::SparseTensorProto sparse; + sparse.add_dims(4); + + auto* values = sparse.mutable_values(); + values->set_name("sparse_marker_indices"); + values->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + values->add_dims(2); + values->add_float_data(1.0f); + values->add_float_data(2.0f); + + auto* indices = sparse.mutable_indices(); + indices->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + indices->add_dims(2); + indices->set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); + auto* loc = indices->add_external_data(); + loc->set_key("location"); + loc->set_value(ToUTF8String(onnxruntime::utils::kTensorProtoLittleEndianMemoryAddressTag)); + auto* off = indices->add_external_data(); + off->set_key("offset"); + off->set_value("0"); + auto* len = indices->add_external_data(); + len->set_key("length"); + len->set_value(std::to_string(2 * sizeof(int64_t))); + + ONNX_NAMESPACE::TensorProto dense; + Status status = utils::SparseTensorProtoToDenseTensorProto(sparse, std::filesystem::path{}, dense); + ASSERT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), ::testing::HasSubstr("in-memory address marker")); +} + #endif // !defined(DISABLE_SPARSE_TENSORS) +// Defense-in-depth: GetExtDataFromTensorProto must reject absolute external paths even when +// called with an empty model_path (e.g. from training checkpoint or custom-op init paths). +// Previously, ValidateExternalDataPath was only invoked from Graph::ConvertInitializersIntoOrtValues, +// so direct callers of GetExtDataFromTensorProto could load arbitrary files. +TEST(GetExtDataFromTensorProtoTest, RejectsAbsoluteExternalPathWithEmptyModelPath) { + ONNX_NAMESPACE::TensorProto tensor_proto; + tensor_proto.set_name("abs_external"); + tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + tensor_proto.add_dims(2); + tensor_proto.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); + + auto* loc = tensor_proto.add_external_data(); + loc->set_key("location"); +#ifdef _WIN32 + loc->set_value("C:\\data.bin"); +#else + loc->set_value("/etc/passwd"); +#endif + + auto* off = tensor_proto.add_external_data(); + off->set_key("offset"); + off->set_value("0"); + + auto* len = tensor_proto.add_external_data(); + len->set_key("length"); + len->set_value(std::to_string(2 * sizeof(float))); + + OrtValue value; + Status status = utils::GetExtDataFromTensorProto(Env::Default(), {}, tensor_proto, value); + ASSERT_FALSE(status.IsOK()) << "Absolute external path must be rejected even with empty model_path."; + EXPECT_THAT(status.ErrorMessage(), ::testing::HasSubstr("Absolute path not allowed")); +} + +// Defense-in-depth: GetExtDataFromTensorProto must reject directory-escaping external paths even +// when the caller passes a non-empty model_path. This guards callers outside Graph::Resolve. +TEST(GetExtDataFromTensorProtoTest, RejectsEscapingExternalPath) { + ONNX_NAMESPACE::TensorProto tensor_proto; + tensor_proto.set_name("escape_external"); + tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + tensor_proto.add_dims(2); + tensor_proto.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); + + auto* loc = tensor_proto.add_external_data(); + loc->set_key("location"); + loc->set_value("../escape.bin"); + + auto* off = tensor_proto.add_external_data(); + off->set_key("offset"); + off->set_value("0"); + + auto* len = tensor_proto.add_external_data(); + len->set_key("length"); + len->set_value(std::to_string(2 * sizeof(float))); + + OrtValue value; + // Pass a synthetic model_path so the validator has a model directory to compare against. + std::filesystem::path model_path = std::filesystem::temp_directory_path() / "sub" / "model.onnx"; + Status status = utils::GetExtDataFromTensorProto(Env::Default(), model_path, tensor_proto, value); + ASSERT_FALSE(status.IsOK()) << "Directory-escaping external path must be rejected."; + EXPECT_THAT(status.ErrorMessage(), ::testing::HasSubstr("escapes")); +} + TEST(TensorProtoUtilsTest, GetNodeProtoLayeringAnnotation) { // Case 1: Annotation exists { diff --git a/onnxruntime/test/ir/graph_test.cc b/onnxruntime/test/ir/graph_test.cc index 1256a39bcd0c7..019f15a46abc5 100644 --- a/onnxruntime/test/ir/graph_test.cc +++ b/onnxruntime/test/ir/graph_test.cc @@ -1362,8 +1362,126 @@ TEST_F(GraphTest, UnusedSparseInitializerIsIgnored) { auto& graph_proto = graph2.ToGraphProto(); ASSERT_TRUE(graph_proto.sparse_initializer().empty()); } + +// Regression test for issue #28617: a SparseTensorProto loaded from a model protobuf must not +// be allowed to carry an ORT in-memory address marker on its values or indices sub-tensors. +// Those markers are an ORT-internal mechanism for trusted in-memory buffers (.ort flatbuffer +// load). Accepting them on a crafted .onnx protobuf would let the model make ORT dereference +// an attacker-supplied pointer during sparse-to-dense conversion. +static void RunRejectInMemoryMarkerOnSparseInitializerTest(bool marker_on_indices, + const onnxruntime::logging::Logger& logger) { + Model model("RejectInMemoryMarkerOnSparseInitializer", false, logger); + auto model_proto = model.ToProto(); + auto* m_graph = model_proto.mutable_graph(); + ConstructASimpleAddGraph(*m_graph, nullptr); + + auto* m_sparse_initializer = m_graph->add_sparse_initializer(); + ConstructSparseTensor("in_memory_marker_sparse", *m_sparse_initializer); + + // Overwrite either values or indices to declare external data pointing at an in-memory marker. + // Allocate a real backing buffer so even an accidental dereference of "offset" stays in-process. + static std::vector backing(64, 0); + auto* sub = marker_on_indices ? m_sparse_initializer->mutable_indices() + : m_sparse_initializer->mutable_values(); + sub->clear_raw_data(); + sub->clear_int64_data(); + sub->clear_float_data(); + sub->set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); + auto* loc = sub->add_external_data(); + loc->set_key("location"); + loc->set_value(ToUTF8String(onnxruntime::utils::kTensorProtoLittleEndianMemoryAddressTag)); + auto* off = sub->add_external_data(); + off->set_key("offset"); + off->set_value(std::to_string(reinterpret_cast(backing.data()))); + auto* len = sub->add_external_data(); + len->set_key("length"); + len->set_value(std::to_string(backing.size())); + + std::string s1; + model_proto.SerializeToString(&s1); + + ModelProto model_proto_1; + ASSERT_TRUE(model_proto_1.ParseFromString(s1)); + + std::shared_ptr p_tmp_model; + // The Graph ctor must reject the marker — Model::Load is expected to return a non-OK status + // (Graph ctor's ORT_THROW is caught at the C++/Status boundary). + ORT_TRY { + auto status = onnxruntime::Model::Load(model_proto_1, p_tmp_model, nullptr, logger); + EXPECT_FALSE(status.IsOK()) << "Loading a model with an in-memory marker on a sparse " + << (marker_on_indices ? "indices" : "values") + << " sub-tensor must fail."; + if (!status.IsOK()) { + EXPECT_THAT(status.ErrorMessage(), + ::testing::HasSubstr("in-memory address marker")); + } + } + ORT_CATCH(const std::exception& ex) { + ORT_HANDLE_EXCEPTION([&]() { + EXPECT_THAT(std::string(ex.what()), + ::testing::HasSubstr("in-memory address marker")); + }); + } +} + +TEST_F(GraphTest, RejectInMemoryMarkerOnSparseInitializerValues) { + RunRejectInMemoryMarkerOnSparseInitializerTest(/*marker_on_indices=*/false, *logger_); +} + +TEST_F(GraphTest, RejectInMemoryMarkerOnSparseInitializerIndices) { + RunRejectInMemoryMarkerOnSparseInitializerTest(/*marker_on_indices=*/true, *logger_); +} #endif // !defined(DISABLE_SPARSE_TENSORS) +// Regression test: ORT in-memory address markers are an in-process sentinel only; they must never +// appear in a dense initializer deserialized from an .onnx protobuf. The Graph ctor must reject +// such a model. +TEST_F(GraphTest, RejectInMemoryMarkerOnDenseInitializer) { + Model model("RejectInMemoryMarkerOnDenseInitializer", false, *logger_); + auto model_proto = model.ToProto(); + auto* m_graph = model_proto.mutable_graph(); + ConstructASimpleAddGraph(*m_graph, nullptr); + + static std::vector backing(64, 0); + + auto* init = m_graph->add_initializer(); + init->set_name("in_memory_marker_dense"); + init->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + init->add_dims(static_cast(backing.size() / sizeof(float))); + init->set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); + auto* loc = init->add_external_data(); + loc->set_key("location"); + loc->set_value(ToUTF8String(onnxruntime::utils::kTensorProtoLittleEndianMemoryAddressTag)); + auto* off = init->add_external_data(); + off->set_key("offset"); + off->set_value(std::to_string(reinterpret_cast(backing.data()))); + auto* len = init->add_external_data(); + len->set_key("length"); + len->set_value(std::to_string(backing.size())); + + std::string s1; + model_proto.SerializeToString(&s1); + + ModelProto model_proto_1; + ASSERT_TRUE(model_proto_1.ParseFromString(s1)); + + std::shared_ptr p_tmp_model; + ORT_TRY { + auto status = onnxruntime::Model::Load(model_proto_1, p_tmp_model, nullptr, *logger_); + EXPECT_FALSE(status.IsOK()) << "Loading a model with an in-memory marker on a dense initializer must fail."; + if (!status.IsOK()) { + EXPECT_THAT(status.ErrorMessage(), + ::testing::HasSubstr("in-memory address marker")); + } + } + ORT_CATCH(const std::exception& ex) { + ORT_HANDLE_EXCEPTION([&]() { + EXPECT_THAT(std::string(ex.what()), + ::testing::HasSubstr("in-memory address marker")); + }); + } +} + TEST_F(GraphTest, GraphConstruction_CheckIsNotAcyclic) { // A cyclic graph // SouceNode diff --git a/onnxruntime/test/mlas/bench/bench_qkv_quant.cpp b/onnxruntime/test/mlas/bench/bench_qkv_quant.cpp index 63b6a3eb212d0..23ca591ba6ed2 100644 --- a/onnxruntime/test/mlas/bench/bench_qkv_quant.cpp +++ b/onnxruntime/test/mlas/bench/bench_qkv_quant.cpp @@ -13,6 +13,7 @@ #include "mlas_qkv_quant.h" #include "core/mlas/lib/mlasi.h" #include "core/mlas/lib/qkv_quant_kernel.h" +#include "core/util/thread_utils.h" #include "benchmark/benchmark.h" #include "bench_util.h" @@ -127,10 +128,10 @@ static void BM_SVGemm(benchmark::State& state) { std::vector C(M * N, 0.0f); // Warmup - MlasSVGemm(M, N, K, A.data(), K, B_quant.data(), qt, scales.data(), C.data(), N, nullptr); + MlasSVGemm(M, N, K, A.data(), K, B_quant.data(), qt, scales.data(), C.data(), N, 0.0f, nullptr); for (auto _ : state) { - MlasSVGemm(M, N, K, A.data(), K, B_quant.data(), qt, scales.data(), C.data(), N, nullptr); + MlasSVGemm(M, N, K, A.data(), K, B_quant.data(), qt, scales.data(), C.data(), N, 0.0f, nullptr); } state.SetItemsProcessed(static_cast(state.iterations()) * M * N * K * 2); @@ -225,10 +226,10 @@ static void BM_SVGemm_Scalar(benchmark::State& state) { auto* saved_dispatch = platform.KVQuantGemmDispatch; platform.KVQuantGemmDispatch = nullptr; - MlasSVGemm(M, N, K, A.data(), K, B_quant.data(), qt, scales.data(), C.data(), N, nullptr); + MlasSVGemm(M, N, K, A.data(), K, B_quant.data(), qt, scales.data(), C.data(), N, 0.0f, nullptr); for (auto _ : state) { - MlasSVGemm(M, N, K, A.data(), K, B_quant.data(), qt, scales.data(), C.data(), N, nullptr); + MlasSVGemm(M, N, K, A.data(), K, B_quant.data(), qt, scales.data(), C.data(), N, 0.0f, nullptr); } platform.KVQuantGemmDispatch = saved_dispatch; @@ -305,10 +306,10 @@ static void BM_SVGemm_Avx2(benchmark::State& state) { auto* saved_dispatch = platform.KVQuantGemmDispatch; platform.KVQuantGemmDispatch = &MlasKVQuantGemmDispatchAvx2; - MlasSVGemm(M, N, K, A.data(), K, B_quant.data(), qt, scales.data(), C.data(), N, nullptr); + MlasSVGemm(M, N, K, A.data(), K, B_quant.data(), qt, scales.data(), C.data(), N, 0.0f, nullptr); for (auto _ : state) { - MlasSVGemm(M, N, K, A.data(), K, B_quant.data(), qt, scales.data(), C.data(), N, nullptr); + MlasSVGemm(M, N, K, A.data(), K, B_quant.data(), qt, scales.data(), C.data(), N, 0.0f, nullptr); } platform.KVQuantGemmDispatch = saved_dispatch; @@ -322,3 +323,284 @@ BENCHMARK(BM_QKGemm_Avx2)->Apply(ScalarArgs)->UseRealTime(); BENCHMARK(BM_SVGemm_Avx2)->Apply(ScalarArgs)->UseRealTime(); #endif // MLAS_TARGET_AMD64 || MLAS_TARGET_IX86 + +// +// Flash Attention vs Naive (full materialization) benchmark. +// Compares MlasFlashAttentionQuantizedKV against the manual +// QKGemm + softmax + SVGemm pipeline for realistic GQA shapes. +// +// Args: batch_size, num_heads, kv_num_heads, seq_len, total_seqlen, head_size, QuantType +// + +static MLAS_THREADPOOL* GetBenchThreadPool() { + static OrtThreadPoolParams tpo; + static bool init = [&]() { + tpo.thread_pool_size = 8; + tpo.auto_set_affinity = true; + return true; + }(); + (void)init; + static std::unique_ptr tp( + onnxruntime::concurrency::CreateThreadPool(&onnxruntime::Env::Default(), + tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP)); + return tp.get(); +} + +// Naive path: QKGemm + row-wise softmax + SVGemm (full attention matrix materialized) +static void BM_GQA_Naive(benchmark::State& state) { + const int batch_size = static_cast(state.range(0)); + const int num_heads = static_cast(state.range(1)); + const int kv_num_heads = static_cast(state.range(2)); + const int seq_len = static_cast(state.range(3)); + const int total_seqlen = static_cast(state.range(4)); + const int head_size = static_cast(state.range(5)); + const auto qt = static_cast(state.range(6)); + + const int groups = num_heads / kv_num_heads; + const float scale = 1.0f / std::sqrt(static_cast(head_size)); + + // Allocate query [B, N, S, H] + auto query = RandomFloats(static_cast(batch_size) * num_heads * seq_len * head_size, 42); + + // Allocate and quantize K cache [B, kv_N, T, H] + auto k_fp = RandomFloats(static_cast(batch_size) * kv_num_heads * total_seqlen * head_size, 123); + auto v_fp = RandomFloats(static_cast(batch_size) * kv_num_heads * total_seqlen * head_size, 456); + + size_t k_row_bytes = MlasKVQuantPackedRowBytes(qt, head_size); + size_t v_row_bytes = MlasKVQuantPackedRowBytes(qt, head_size); + size_t k_cache_size = static_cast(batch_size) * kv_num_heads * total_seqlen * k_row_bytes; + size_t v_cache_size = static_cast(batch_size) * kv_num_heads * total_seqlen * v_row_bytes; + + std::vector k_cache(k_cache_size); + std::vector v_cache(v_cache_size); + + bool per_channel = (qt == MLAS_KV_QUANT_TYPE::S8_PerChannel || qt == MLAS_KV_QUANT_TYPE::S4_PerChannel); + size_t num_scales = per_channel ? static_cast(kv_num_heads * head_size) : 1; + std::vector k_scale(num_scales, 0.01f); + std::vector v_scale(num_scales, 0.01f); + + // Quantize K and V caches per kv-head + for (int b = 0; b < batch_size; ++b) { + for (int h = 0; h < kv_num_heads; ++h) { + size_t offset_fp = (static_cast(b) * kv_num_heads + h) * total_seqlen * head_size; + size_t offset_q = (static_cast(b) * kv_num_heads + h) * total_seqlen * k_row_bytes; + MlasKVQuantize(k_fp.data() + offset_fp, k_cache.data() + offset_q, + total_seqlen, head_size, head_size, qt, + per_channel ? k_scale.data() + h * head_size : k_scale.data(), nullptr); + offset_q = (static_cast(b) * kv_num_heads + h) * total_seqlen * v_row_bytes; + MlasKVQuantize(v_fp.data() + offset_fp, v_cache.data() + offset_q, + total_seqlen, head_size, head_size, qt, + per_channel ? v_scale.data() + h * head_size : v_scale.data(), nullptr); + } + } + + // Allocate working buffers: scores[B*N, S, T] (one per head) + output[B, S, N, H] + std::vector scores(static_cast(batch_size) * num_heads * seq_len * total_seqlen); + std::vector output(static_cast(batch_size) * seq_len * num_heads * head_size, 0.0f); + + auto* tp = GetBenchThreadPool(); + const ptrdiff_t loop_len = batch_size * num_heads; + + for (auto _ : state) { + // Pass 1: QK GEMM + Softmax (matches operator's first TryParallelFor) + onnxruntime::concurrency::ThreadPool::TrySimpleParallelFor( + tp, loop_len, [&](std::ptrdiff_t i) { + const int b = static_cast(i) / num_heads; + const int h = static_cast(i) % num_heads; + const int kv_h = h / groups; + float* my_scores = scores.data() + static_cast(i) * seq_len * total_seqlen; + const float* q_ptr = query.data() + (static_cast(b) * num_heads + h) * seq_len * head_size; + const uint8_t* k_ptr = k_cache.data() + (static_cast(b) * kv_num_heads + kv_h) * total_seqlen * k_row_bytes; + + // QK GEMM: scores[S, T] = scale * Q[S,H] * K[T,H]^T + MlasQKGemm(seq_len, total_seqlen, head_size, scale, + q_ptr, head_size, k_ptr, qt, + per_channel ? k_scale.data() + kv_h * head_size : k_scale.data(), + my_scores, total_seqlen, nullptr); + + // Causal masking + MLAS-optimized softmax (matches operator) + for (int s = 0; s < seq_len; ++s) { + float* row = my_scores + s * total_seqlen; + int valid_len = total_seqlen - seq_len + s + 1; + // Zero out future positions (operator sets them to 0 before softmax) + for (int t = valid_len; t < total_seqlen; ++t) row[t] = 0.f; + // Use MLAS optimized softmax on valid range only + MlasComputeSoftmax(row, row, static_cast(1), + static_cast(valid_len), false, false, 0.0f, nullptr); + } + }); + + // Pass 2: SV GEMM (matches operator's second TryParallelFor) + onnxruntime::concurrency::ThreadPool::TrySimpleParallelFor( + tp, loop_len, [&](std::ptrdiff_t i) { + const int b = static_cast(i) / num_heads; + const int h = static_cast(i) % num_heads; + const int kv_h = h / groups; + float* my_scores = scores.data() + static_cast(i) * seq_len * total_seqlen; + const uint8_t* v_ptr = v_cache.data() + (static_cast(b) * kv_num_heads + kv_h) * total_seqlen * v_row_bytes; + float* out_ptr = output.data() + (static_cast(b) * seq_len * num_heads + h) * head_size; + + // SV GEMM: out[S, H] = scores[S,T] * V[T,H] + MlasSVGemm(seq_len, head_size, total_seqlen, + my_scores, total_seqlen, v_ptr, qt, + per_channel ? v_scale.data() + kv_h * head_size : v_scale.data(), + out_ptr, num_heads * head_size, 0.0f, nullptr); + }); + benchmark::DoNotOptimize(output.data()); + } + + int64_t flops = static_cast(batch_size) * num_heads * seq_len * + (2LL * total_seqlen * head_size + 2LL * total_seqlen * head_size); + state.SetItemsProcessed(static_cast(state.iterations()) * flops); +} + +// Flash path: MlasFlashAttentionQuantizedKV (tiled, online softmax) +static void BM_GQA_Flash(benchmark::State& state) { + const int batch_size = static_cast(state.range(0)); + const int num_heads = static_cast(state.range(1)); + const int kv_num_heads = static_cast(state.range(2)); + const int seq_len = static_cast(state.range(3)); + const int total_seqlen = static_cast(state.range(4)); + const int head_size = static_cast(state.range(5)); + const auto qt = static_cast(state.range(6)); + + const float scale = 1.0f / std::sqrt(static_cast(head_size)); + bool per_channel = (qt == MLAS_KV_QUANT_TYPE::S8_PerChannel || qt == MLAS_KV_QUANT_TYPE::S4_PerChannel); + + // Allocate query [B, N, S, H] in BNSH layout + auto query = RandomFloats(static_cast(batch_size) * num_heads * seq_len * head_size, 42); + + // Allocate and quantize K/V caches + auto k_fp = RandomFloats(static_cast(batch_size) * kv_num_heads * total_seqlen * head_size, 123); + auto v_fp = RandomFloats(static_cast(batch_size) * kv_num_heads * total_seqlen * head_size, 456); + + size_t k_row_bytes = MlasKVQuantPackedRowBytes(qt, head_size); + size_t v_row_bytes = MlasKVQuantPackedRowBytes(qt, head_size); + std::vector k_cache(static_cast(batch_size) * kv_num_heads * total_seqlen * k_row_bytes); + std::vector v_cache(static_cast(batch_size) * kv_num_heads * total_seqlen * v_row_bytes); + + size_t num_scales = per_channel ? static_cast(kv_num_heads * head_size) : 1; + std::vector k_scale(num_scales, 0.01f); + std::vector v_scale(num_scales, 0.01f); + + for (int b = 0; b < batch_size; ++b) { + for (int h = 0; h < kv_num_heads; ++h) { + size_t offset_fp = (static_cast(b) * kv_num_heads + h) * total_seqlen * head_size; + size_t offset_q = (static_cast(b) * kv_num_heads + h) * total_seqlen * k_row_bytes; + MlasKVQuantize(k_fp.data() + offset_fp, k_cache.data() + offset_q, + total_seqlen, head_size, head_size, qt, + per_channel ? k_scale.data() + h * head_size : k_scale.data(), nullptr); + offset_q = (static_cast(b) * kv_num_heads + h) * total_seqlen * v_row_bytes; + MlasKVQuantize(v_fp.data() + offset_fp, v_cache.data() + offset_q, + total_seqlen, head_size, head_size, qt, + per_channel ? v_scale.data() + h * head_size : v_scale.data(), nullptr); + } + } + + // Output [B, S, N, H] + std::vector output(static_cast(batch_size) * seq_len * num_heads * head_size, 0.0f); + + // Fixed block sizes for reproducible benchmarks (operator computes from L2 cache size) + int q_block_size = 64; + int kv_block_size = 256; + + // Thread pool + auto* tp = GetBenchThreadPool(); + int thread_count = 8; + + // Flash decoding: for decode (seq_len=1), partition KV across threads + int kv_chunk_count = (total_seqlen + kv_block_size - 1) / kv_block_size; + bool use_flash_decoding = (seq_len == 1 && + batch_size * num_heads < thread_count && + kv_chunk_count > 1); + + // Working buffer + size_t buffer_size_per_thread; + size_t partials_buffer_bytes = 0; + if (use_flash_decoding) { + buffer_size_per_thread = static_cast(kv_block_size) * sizeof(float); + partials_buffer_bytes = static_cast(batch_size) * num_heads * + kv_chunk_count * (2 + head_size) * sizeof(float); + } else { + buffer_size_per_thread = + (static_cast(q_block_size) * 2 + // l + m + static_cast(q_block_size) * static_cast(kv_block_size) + // scores + static_cast(q_block_size) * static_cast(head_size)) * // temp_output + sizeof(float); + } + size_t total_buffer_floats = (buffer_size_per_thread * thread_count + partials_buffer_bytes) / sizeof(float); + std::vector buffer(total_buffer_floats); + float* partials_ptr = use_flash_decoding + ? buffer.data() + (buffer_size_per_thread * thread_count) / sizeof(float) + : nullptr; + + MlasFlashAttentionQuantizedKVArgs args{}; + args.batch_size = batch_size; + args.num_heads = num_heads; + args.kv_num_heads = kv_num_heads; + args.sequence_length = seq_len; + args.total_seqlen = total_seqlen; + args.head_size = head_size; + args.past_seqlen = total_seqlen - seq_len; + args.local_window_size = -1; + args.seqlen_present_kv = total_seqlen; + args.q_block_size = q_block_size; + args.kv_block_size = kv_block_size; + args.scale = scale; + args.quant_type = qt; + args.per_channel_k = per_channel; + args.per_channel_v = per_channel; + args.thread_count = thread_count; + args.buffer = buffer.data(); + args.buffer_size_per_thread = buffer_size_per_thread; + args.query = query.data(); + args.k_cache = k_cache.data(); + args.v_cache = v_cache.data(); + args.k_scale = k_scale.data(); + args.v_scale = v_scale.data(); + args.output = output.data(); + args.attention_bias = nullptr; + args.attention_bias_seqlen_stride = 0; + args.attention_bias_broadcast_batch = true; + args.attention_bias_broadcast_head = true; + args.flash_decoding_partials = partials_ptr; + args.kv_chunk_count = kv_chunk_count; + + // Warmup + MlasFlashAttentionQuantizedKV(&args, tp); + + for (auto _ : state) { + MlasFlashAttentionQuantizedKV(&args, tp); + benchmark::DoNotOptimize(output.data()); + } + + int64_t flops = static_cast(batch_size) * num_heads * seq_len * + (2LL * total_seqlen * head_size + 2LL * total_seqlen * head_size); + state.SetItemsProcessed(static_cast(state.iterations()) * flops); +} + +// Flash vs Naive benchmark configurations +// Args: batch, num_heads, kv_num_heads, seq_len, total_seqlen, head_size, QuantType +static void FlashGQAArgs(benchmark::internal::Benchmark* b) { + b->ArgNames({"B", "N", "N_kv", "S", "T", "H", "QType"}); + // INT8 per-tensor (qt=0), INT8 per-channel (qt=1) + for (int qt : {0, 1}) { + // Prompt (prefill): seq_len = total_seqlen + for (int T : {512, 1024, 2048, 4096}) { + b->Args({1, 16, 8, T, T, 128, qt}); // B=1, GQA ratio 2 + } + // Decode: seq_len=1, past grows + for (int T : {512, 1024, 2048, 4096}) { + b->Args({1, 16, 8, 1, T, 128, qt}); // B=1, decode + } + // Larger batch decode + b->Args({4, 16, 8, 1, 2048, 128, qt}); + // Flash decoding cases: B*N < thread_count (8), triggers KV partitioning + for (int T : {512, 1024, 2048, 4096}) { + b->Args({1, 4, 4, 1, T, 128, qt}); // B=1, N=4, flash decoding enabled + } + } +} + +BENCHMARK(BM_GQA_Naive)->Apply(FlashGQAArgs)->UseRealTime(); +BENCHMARK(BM_GQA_Flash)->Apply(FlashGQAArgs)->UseRealTime(); diff --git a/onnxruntime/test/mlas/unittest/test_qkv_quant.cpp b/onnxruntime/test/mlas/unittest/test_qkv_quant.cpp index be13d4b489115..5f0b18fa2cac8 100644 --- a/onnxruntime/test/mlas/unittest/test_qkv_quant.cpp +++ b/onnxruntime/test/mlas/unittest/test_qkv_quant.cpp @@ -251,7 +251,7 @@ class MlasKVQuantTest : public MlasTestBase { RefSVGemm(A, BDequant, CRef, M, N, K, K, N); // Quantized: MlasSVGemm - MlasSVGemm(M, N, K, A, K, BQuant, QuantType, scales, C, N, nullptr); + MlasSVGemm(M, N, K, A, K, BQuant, QuantType, scales, C, N, 0.0f, nullptr); float atol = IsInt4(QuantType) ? 0.15f : 0.02f; float rtol = IsInt4(QuantType) ? 0.1f : 0.01f; @@ -322,3 +322,342 @@ static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_exe } return count; }); + +// +// Focused test for MlasFlashAttentionQuantizedKV: +// Validates the tiled online-softmax kernel against a naive reference pipeline +// (MlasQKGemm + softmax + MlasSVGemm) across INT8/INT4, per-tensor/per-channel. +// +class MlasFlashAttentionQuantizedKVTest : public MlasTestBase { + private: + MatrixGuardBuffer BufferQ; + MatrixGuardBuffer BufferOutput; + MatrixGuardBuffer BufferOutputRef; + MatrixGuardBuffer BufferScores; + MatrixGuardBuffer BufferProbs; + MatrixGuardBuffer BufferScalesK; + MatrixGuardBuffer BufferScalesV; + MatrixGuardBuffer BufferKFP32; + MatrixGuardBuffer BufferVFP32; + MatrixGuardBuffer BufferFlash; + MatrixGuardBuffer BufferPartials; + MatrixGuardBuffer BufferKQuant; + MatrixGuardBuffer BufferVQuant; + + void FillRandom(float* buf, size_t n, unsigned seed, float lo = -0.5f, float hi = 0.5f) { + std::default_random_engine gen(seed); + std::uniform_real_distribution dist(lo, hi); + for (size_t i = 0; i < n; i++) { + buf[i] = dist(gen); + } + } + + bool IsInt4(MLAS_KV_QUANT_TYPE qt) { + return qt == MLAS_KV_QUANT_TYPE::S4_PerTensor || qt == MLAS_KV_QUANT_TYPE::S4_PerChannel; + } + + bool IsPerChannel(MLAS_KV_QUANT_TYPE qt) { + return qt == MLAS_KV_QUANT_TYPE::S8_PerChannel || qt == MLAS_KV_QUANT_TYPE::S4_PerChannel; + } + + void ComputeScales(const float* data, size_t rows, size_t cols, MLAS_KV_QUANT_TYPE qt, float* scales) { + float qmax = IsInt4(qt) ? 7.0f : 127.0f; + if (IsPerChannel(qt)) { + for (size_t c = 0; c < cols; c++) { + float amax = 0.0f; + for (size_t r = 0; r < rows; r++) { + amax = std::max(amax, std::fabs(data[r * cols + c])); + } + scales[c] = (amax > 1e-6f) ? (amax / qmax) : 1.0f; + } + } else { + float amax = 0.0f; + for (size_t i = 0; i < rows * cols; i++) { + amax = std::max(amax, std::fabs(data[i])); + } + scales[0] = (amax > 1e-6f) ? (amax / qmax) : 1.0f; + } + } + + // Naive reference: for a single (batch=1, head=1) attention computation + // Q[seq_len, head_size], K[total_seqlen, head_size], V[total_seqlen, head_size] + // -> output[seq_len, head_size] + // Uses quantized K/V via MlasQKGemm + softmax + MlasSVGemm. + void NaiveReference( + const float* Q, size_t seq_len, size_t total_seqlen, size_t head_size, + const uint8_t* k_quant, const uint8_t* v_quant, + MLAS_KV_QUANT_TYPE quant_type, const float* k_scale, const float* v_scale, + float scale, int past_seqlen, float* output) { + float* scores = BufferScores.GetBuffer(seq_len * total_seqlen); + float* probs = BufferProbs.GetBuffer(seq_len * total_seqlen); + + // QK^T + MlasQKGemm(seq_len, total_seqlen, head_size, scale, + Q, head_size, k_quant, quant_type, k_scale, + scores, total_seqlen, nullptr); + + // Causal mask + softmax + for (size_t q_s = 0; q_s < seq_len; q_s++) { + size_t causal_limit = static_cast(past_seqlen) + q_s + 1; + // Apply causal mask + for (size_t kv_s = 0; kv_s < total_seqlen; kv_s++) { + if (kv_s >= causal_limit) { + scores[q_s * total_seqlen + kv_s] = -std::numeric_limits::infinity(); + } + } + // Softmax + float max_val = -std::numeric_limits::infinity(); + for (size_t kv_s = 0; kv_s < total_seqlen; kv_s++) { + max_val = std::max(max_val, scores[q_s * total_seqlen + kv_s]); + } + float sum_exp = 0.0f; + for (size_t kv_s = 0; kv_s < total_seqlen; kv_s++) { + probs[q_s * total_seqlen + kv_s] = std::exp(scores[q_s * total_seqlen + kv_s] - max_val); + sum_exp += probs[q_s * total_seqlen + kv_s]; + } + for (size_t kv_s = 0; kv_s < total_seqlen; kv_s++) { + probs[q_s * total_seqlen + kv_s] /= sum_exp; + } + } + + // SV GEMM + MlasSVGemm(seq_len, head_size, total_seqlen, + probs, total_seqlen, v_quant, quant_type, v_scale, + output, head_size, 0.0f, nullptr); + } + + void TestFlashAttention(size_t seq_len, size_t total_seqlen, size_t head_size, + MLAS_KV_QUANT_TYPE quant_type) { + const size_t k_num_scales = IsPerChannel(quant_type) ? head_size : 1; + const size_t v_num_scales = IsPerChannel(quant_type) ? head_size : 1; + const size_t packed_row_bytes = MlasKVQuantPackedRowBytes(quant_type, head_size); + const float scale = 1.0f / std::sqrt(static_cast(head_size)); + const int past_seqlen = static_cast(total_seqlen - seq_len); + + // Allocate and fill + float* Q = BufferQ.GetBuffer(seq_len * head_size); + float* K_fp32 = BufferKFP32.GetBuffer(total_seqlen * head_size); + float* V_fp32 = BufferVFP32.GetBuffer(total_seqlen * head_size); + float* k_scale = BufferScalesK.GetBuffer(k_num_scales); + float* v_scale = BufferScalesV.GetBuffer(v_num_scales); + float* output_flash = BufferOutput.GetBuffer(seq_len * head_size); + float* output_ref = BufferOutputRef.GetBuffer(seq_len * head_size); + + unsigned seed = static_cast(seq_len * 1000 + total_seqlen * 10 + head_size); + FillRandom(Q, seq_len * head_size, seed); + FillRandom(K_fp32, total_seqlen * head_size, seed + 1); + FillRandom(V_fp32, total_seqlen * head_size, seed + 2); + + ComputeScales(K_fp32, total_seqlen, head_size, quant_type, k_scale); + ComputeScales(V_fp32, total_seqlen, head_size, quant_type, v_scale); + + // Quantize K and V + uint8_t* k_quant = BufferKQuant.GetBuffer(total_seqlen * packed_row_bytes); + uint8_t* v_quant = BufferVQuant.GetBuffer(total_seqlen * packed_row_bytes); + MlasKVQuantize(K_fp32, k_quant, total_seqlen, head_size, head_size, quant_type, k_scale, nullptr); + MlasKVQuantize(V_fp32, v_quant, total_seqlen, head_size, head_size, quant_type, v_scale, nullptr); + + // Naive reference + NaiveReference(Q, seq_len, total_seqlen, head_size, + k_quant, v_quant, quant_type, k_scale, v_scale, + scale, past_seqlen, output_ref); + + // Flash attention + int q_block_size = std::min(static_cast(seq_len), 16); + int kv_block_size = std::min(static_cast(total_seqlen), 32); + + size_t buffer_size_per_thread = + (static_cast(q_block_size) * 2 + + static_cast(q_block_size) * static_cast(kv_block_size) + + static_cast(q_block_size) * static_cast(head_size)) * + sizeof(float); + float* flash_buffer = BufferFlash.GetBuffer(buffer_size_per_thread / sizeof(float)); + + MlasFlashAttentionQuantizedKVArgs args; + args.batch_size = 1; + args.num_heads = 1; + args.kv_num_heads = 1; + args.sequence_length = static_cast(seq_len); + args.total_seqlen = static_cast(total_seqlen); + args.head_size = static_cast(head_size); + args.past_seqlen = past_seqlen; + args.local_window_size = -1; + args.seqlen_present_kv = static_cast(total_seqlen); + args.q_block_size = q_block_size; + args.kv_block_size = kv_block_size; + args.scale = scale; + args.quant_type = quant_type; + args.per_channel_k = IsPerChannel(quant_type); + args.per_channel_v = IsPerChannel(quant_type); + args.thread_count = 1; + args.buffer = flash_buffer; + args.buffer_size_per_thread = buffer_size_per_thread; + args.query = Q; + args.k_cache = k_quant; + args.v_cache = v_quant; + args.k_scale = k_scale; + args.v_scale = v_scale; + args.output = output_flash; + args.attention_bias = nullptr; + args.attention_bias_seqlen_stride = 0; + args.attention_bias_broadcast_batch = true; + args.attention_bias_broadcast_head = true; + args.flash_decoding_partials = nullptr; + args.kv_chunk_count = 0; + + MlasFlashAttentionQuantizedKV(&args, nullptr); + + // Compare: flash uses ComputeSumExpF32Kernel (SIMD polynomial approx) while + // NaiveReference uses std::exp. Tolerance accounts for accumulation order + // differences across platforms/ISAs. + float atol = IsInt4(quant_type) ? 1e-3f : 1e-4f; + for (size_t i = 0; i < seq_len * head_size; i++) { + float diff = std::fabs(output_flash[i] - output_ref[i]); + ASSERT_LE(diff, atol) + << "FlashAttention vs Naive mismatch at [" << i / head_size << ", " << i % head_size + << "], flash=" << output_flash[i] << " ref=" << output_ref[i] + << " seq_len=" << seq_len << " total_seqlen=" << total_seqlen + << " head_size=" << head_size + << " qt=" << static_cast(quant_type); + } + } + + // Test flash decoding path: sequence_length=1 with KV split across chunks + void TestFlashDecoding(size_t total_seqlen, size_t head_size, + MLAS_KV_QUANT_TYPE quant_type) { + const size_t seq_len = 1; + const size_t k_num_scales = IsPerChannel(quant_type) ? head_size : 1; + const size_t v_num_scales = IsPerChannel(quant_type) ? head_size : 1; + const size_t packed_row_bytes = MlasKVQuantPackedRowBytes(quant_type, head_size); + const float scale = 1.0f / std::sqrt(static_cast(head_size)); + const int past_seqlen = static_cast(total_seqlen - 1); + + // Allocate and fill + float* Q = BufferQ.GetBuffer(head_size); + float* K_fp32 = BufferKFP32.GetBuffer(total_seqlen * head_size); + float* V_fp32 = BufferVFP32.GetBuffer(total_seqlen * head_size); + float* k_scale_buf = BufferScalesK.GetBuffer(k_num_scales); + float* v_scale_buf = BufferScalesV.GetBuffer(v_num_scales); + float* output_flash = BufferOutput.GetBuffer(head_size); + float* output_ref = BufferOutputRef.GetBuffer(head_size); + + unsigned seed = static_cast(total_seqlen * 100 + head_size * 7); + FillRandom(Q, head_size, seed); + FillRandom(K_fp32, total_seqlen * head_size, seed + 1); + FillRandom(V_fp32, total_seqlen * head_size, seed + 2); + + ComputeScales(K_fp32, total_seqlen, head_size, quant_type, k_scale_buf); + ComputeScales(V_fp32, total_seqlen, head_size, quant_type, v_scale_buf); + + // Quantize K and V + uint8_t* k_quant = BufferKQuant.GetBuffer(total_seqlen * packed_row_bytes); + uint8_t* v_quant = BufferVQuant.GetBuffer(total_seqlen * packed_row_bytes); + MlasKVQuantize(K_fp32, k_quant, total_seqlen, head_size, head_size, quant_type, k_scale_buf, nullptr); + MlasKVQuantize(V_fp32, v_quant, total_seqlen, head_size, head_size, quant_type, v_scale_buf, nullptr); + + // Naive reference + NaiveReference(Q, seq_len, total_seqlen, head_size, + k_quant, v_quant, quant_type, k_scale_buf, v_scale_buf, + scale, past_seqlen, output_ref); + + // Flash decoding: use small kv_block_size to get multiple chunks + int kv_block_size = std::min(static_cast(total_seqlen), 16); + int kv_chunk_count = (static_cast(total_seqlen) + kv_block_size - 1) / kv_block_size; + + // Per-thread scratch: scores[kv_block_size] + size_t buffer_size_per_thread = static_cast(kv_block_size) * sizeof(float); + float* flash_buffer = BufferFlash.GetBuffer(buffer_size_per_thread / sizeof(float)); + + // Partials buffer: [1 batch * 1 head * kv_chunk_count * (2 + head_size)] + size_t partials_count = static_cast(kv_chunk_count) * (2 + head_size); + float* partials = BufferPartials.GetBuffer(partials_count); + + MlasFlashAttentionQuantizedKVArgs args; + args.batch_size = 1; + args.num_heads = 1; + args.kv_num_heads = 1; + args.sequence_length = 1; + args.total_seqlen = static_cast(total_seqlen); + args.head_size = static_cast(head_size); + args.past_seqlen = past_seqlen; + args.local_window_size = -1; + args.seqlen_present_kv = static_cast(total_seqlen); + args.q_block_size = 1; + args.kv_block_size = kv_block_size; + args.scale = scale; + args.quant_type = quant_type; + args.per_channel_k = IsPerChannel(quant_type); + args.per_channel_v = IsPerChannel(quant_type); + args.thread_count = 1; + args.buffer = flash_buffer; + args.buffer_size_per_thread = buffer_size_per_thread; + args.query = Q; + args.k_cache = k_quant; + args.v_cache = v_quant; + args.k_scale = k_scale_buf; + args.v_scale = v_scale_buf; + args.output = output_flash; + args.attention_bias = nullptr; + args.attention_bias_seqlen_stride = 0; + args.attention_bias_broadcast_batch = true; + args.attention_bias_broadcast_head = true; + args.flash_decoding_partials = partials; + args.kv_chunk_count = kv_chunk_count; + + MlasFlashAttentionQuantizedKV(&args, nullptr); + + // Compare: flash decoding has an extra reduce phase with exp rescaling, + // so tolerance is slightly larger than the single-pass flash attention test. + float atol = IsInt4(quant_type) ? 1e-3f : 1e-4f; + for (size_t i = 0; i < head_size; i++) { + float diff = std::fabs(output_flash[i] - output_ref[i]); + ASSERT_LE(diff, atol) + << "FlashDecoding vs Naive mismatch at [" << i + << "], flash=" << output_flash[i] << " ref=" << output_ref[i] + << " total_seqlen=" << total_seqlen + << " head_size=" << head_size + << " qt=" << static_cast(quant_type); + } + } + + public: + static const char* GetTestSuiteName() { + static const std::string suite_name("FlashAttentionQuantizedKV"); + return suite_name.c_str(); + } + + void ExecuteShort(void) override { + static const MLAS_KV_QUANT_TYPE AllQuantTypes[] = { + MLAS_KV_QUANT_TYPE::S8_PerTensor, + MLAS_KV_QUANT_TYPE::S8_PerChannel, + MLAS_KV_QUANT_TYPE::S4_PerTensor, + MLAS_KV_QUANT_TYPE::S4_PerChannel, + }; + + for (auto qt : AllQuantTypes) { + size_t min_head = size_t{4}; + for (size_t seq_len : {size_t{1}, size_t{4}, size_t{16}}) { + for (size_t total_seqlen : {size_t{4}, size_t{32}, size_t{64}}) { + if (total_seqlen < seq_len) continue; + for (size_t head_size : {min_head, size_t{32}, size_t{64}}) { + TestFlashAttention(seq_len, total_seqlen, head_size, qt); + } + } + } + // Flash decoding tests (sequence_length=1 with KV split into chunks) + for (size_t total_seqlen : {size_t{4}, size_t{32}, size_t{64}, size_t{128}}) { + for (size_t head_size : {min_head, size_t{32}, size_t{64}}) { + TestFlashDecoding(total_seqlen, head_size, qt); + } + } + } + } +}; + +static UNUSED_VARIABLE bool added_flash_to_main = AddTestRegister([](bool is_short_execute) { + size_t count = 0; + if (is_short_execute) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + } + return count; +}); diff --git a/onnxruntime/test/optimizer/initializer_test.cc b/onnxruntime/test/optimizer/initializer_test.cc index 391942acfca35..6f340e9a9b734 100644 --- a/onnxruntime/test/optimizer/initializer_test.cc +++ b/onnxruntime/test/optimizer/initializer_test.cc @@ -96,7 +96,10 @@ TEST(OptimizerInitializerTest, LoadExternalData) { // bad model paths EXPECT_THROW(Initializer i(tensor_proto_base, std::filesystem::path()), OnnxRuntimeException); - EXPECT_THROW(Initializer i(tensor_proto_base, ORT_TSTR("invalid/directory")), std::filesystem::filesystem_error); + // ValidateExternalDataPath in GetExtDataFromTensorProto now rejects this earlier with an + // ORT error ("External data path does not exist") instead of letting a downstream + // std::filesystem call throw filesystem_error. + EXPECT_THROW(Initializer i(tensor_proto_base, ORT_TSTR("invalid/directory")), OnnxRuntimeException); // bad length { diff --git a/onnxruntime/test/providers/partitioning_utils_test.cc b/onnxruntime/test/providers/partitioning_utils_test.cc index 5f435199679be..75601ef55ffe6 100644 --- a/onnxruntime/test/providers/partitioning_utils_test.cc +++ b/onnxruntime/test/providers/partitioning_utils_test.cc @@ -208,6 +208,113 @@ TEST(PartitioningUtilsTest, TestQDQNodeGroupWithRedundantRelu) { CheckAllNodesProcessed(build_model); } +// Regression test for the fix that adds Node::ImplicitInputDefs() to MetaDef::inputs +// in utils::MakeComputeCapability. Builds a graph with a Loop whose body captures +// outer-scope tensor "B"; asserts B appears in the fused subgraph's MetaDef::inputs +// and that explicit Loop operands precede the implicit capture. +TEST(PartitioningUtilsTest, TestLoopBodyImplicitInputsInMetaDef) { + auto& logger = DefaultLoggingManager().DefaultLogger(); + Model model("loop_capture", false, ModelMetaData(), + PathString(), IOnnxRuntimeOpSchemaRegistryList(), + {{kOnnxDomain, 16}}, {}, logger); + Graph& main_graph = model.MainGraph(); + + ONNX_NAMESPACE::TypeProto float_2x2; + float_2x2.mutable_tensor_type()->set_elem_type( + ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + float_2x2.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(2); + float_2x2.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(2); + + ONNX_NAMESPACE::TypeProto int64_scalar; + int64_scalar.mutable_tensor_type()->set_elem_type( + ONNX_NAMESPACE::TensorProto_DataType_INT64); + int64_scalar.mutable_tensor_type()->mutable_shape(); + + ONNX_NAMESPACE::TypeProto bool_scalar; + bool_scalar.mutable_tensor_type()->set_elem_type( + ONNX_NAMESPACE::TensorProto_DataType_BOOL); + bool_scalar.mutable_tensor_type()->mutable_shape(); + + auto build_body = [&]() -> ONNX_NAMESPACE::GraphProto { + Model body_model("loop_body", true, logger); + Graph& body = body_model.MainGraph(); + + auto& iter = body.GetOrCreateNodeArg("iter", &int64_scalar); + auto& cond_in = body.GetOrCreateNodeArg("cond_in", &bool_scalar); + auto& acc_in = body.GetOrCreateNodeArg("acc_in", &float_2x2); + + // Outer-scope capture B used inside the body Add. + ORT_IGNORE_RETURN_VALUE(body.GetOrCreateNodeArg("B", &float_2x2)); + body.AddOuterScopeNodeArg("B"); + auto& B_in_body = *body.GetNodeArg("B"); + + auto& acc_out = body.GetOrCreateNodeArg("acc_out", &float_2x2); + body.AddNode("body_add", "Add", "acc + B", {&acc_in, &B_in_body}, {&acc_out}); + + auto& cond_out = body.GetOrCreateNodeArg("cond_out", &bool_scalar); + body.AddNode("body_cond_id", "Identity", "forward cond", {&cond_in}, {&cond_out}); + + body.SetInputs({&iter, &cond_in, &acc_in}); + body.SetOutputs({&cond_out, &acc_out}); + EXPECT_STATUS_OK(body.Resolve()); + return body.ToGraphProto(); + }; + + auto& M = main_graph.GetOrCreateNodeArg("M", &int64_scalar); + auto& cond_init = main_graph.GetOrCreateNodeArg("cond_init", &bool_scalar); + auto& acc_init = main_graph.GetOrCreateNodeArg("acc_init", &float_2x2); + auto& B = main_graph.GetOrCreateNodeArg("B", &float_2x2); + auto& v_final = main_graph.GetOrCreateNodeArg("v_final", &float_2x2); + + auto& loop_node = main_graph.AddNode( + "loop", "Loop", "Loop with outer-scope capture", + {&M, &cond_init, &acc_init}, {&v_final}); + loop_node.AddAttribute("body", build_body()); + + main_graph.SetInputs({&M, &cond_init, &acc_init, &B}); + main_graph.SetOutputs({&v_final}); + ASSERT_STATUS_OK(main_graph.Resolve()); + + GraphViewer graph_viewer(main_graph); + std::vector> node_unit_holder; + std::unordered_map node_unit_map; + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger); + + const auto is_node_supported = [&](const Node& /*node*/) -> bool { return true; }; + const auto on_group_closed = [&](const std::vector& /*group*/) -> bool { return true; }; + const auto gen_metadef_name = [&]() { + static int id = 0; + return "TestMetaDef_loop_capture_" + std::to_string(id++); + }; + + auto result = utils::CreateSupportedPartitions( + graph_viewer, is_node_supported, on_group_closed, + gen_metadef_name, "TEST", kCpuExecutionProvider, + &node_unit_map, /*drop_constant_initializers=*/true); + + ASSERT_EQ(result.size(), size_t(1)); + const auto* meta_def = result[0]->sub_graph->GetMetaDef(); + ASSERT_NE(meta_def, nullptr); + + const auto& inputs = meta_def->inputs; + + // Explicit Loop operands. + EXPECT_THAT(inputs, ::testing::Contains("M")); + EXPECT_THAT(inputs, ::testing::Contains("cond_init")); + EXPECT_THAT(inputs, ::testing::Contains("acc_init")); + // Outer-scope capture used only via ImplicitInputDefs; before the fix this + // was silently dropped from meta_def->inputs, leaving the fused node's + // InputDefs() unable to resolve B at Compute time. + EXPECT_THAT(inputs, ::testing::Contains("B")); + + const auto last_explicit = std::find(inputs.begin(), inputs.end(), "acc_init"); + const auto first_implicit = std::find(inputs.begin(), inputs.end(), "B"); + ASSERT_NE(last_explicit, inputs.end()); + ASSERT_NE(first_implicit, inputs.end()); + EXPECT_LT(last_explicit, first_implicit) + << "explicit Loop operands must precede implicit captures in meta_def->inputs"; +} + TEST(PartitioningUtilsTest, TestQDQNodeGroupWithRedundantClip) { const auto build_model = [](ModelTestBuilder& builder) { auto* input_0_arg = builder.MakeInput({2, 3, 3, 3}, std::numeric_limits::min(), diff --git a/onnxruntime/test/python/quantization/test_quant_issues.py b/onnxruntime/test/python/quantization/test_quant_issues.py index 91b60f31b1964..dcb4a524a01f4 100644 --- a/onnxruntime/test/python/quantization/test_quant_issues.py +++ b/onnxruntime/test/python/quantization/test_quant_issues.py @@ -119,6 +119,98 @@ def get_next(self): f"Expected quantized model at {output_path!r}", ) + def test_dynamic_quantize_per_channel_emits_axis_attribute(self): + """Per-channel dynamic quantization must emit axis on DequantizeLinear nodes. + + Regression test for https://github.com/microsoft/onnxruntime/issues/19997. + `quantize_dynamic(per_channel=True)` previously constructed QuantizedValue + with axis=None and built DequantizeLinear without an axis attribute, producing + an invalid per-tensor dequantization for per-channel quantized weights. + The quantizer encounters the unsupported `Identity` op consuming `weight` + and dequantizes the (now-quantized) per-channel weight initializer back to + float for it. That `_dequantize_value` call previously triggered an + assertion error (scale not scalar) and would have emitted a + DequantizeLinear lacking the required axis attribute. + """ + try: + import numpy as np # noqa: PLC0415 + import onnx # noqa: PLC0415 + from onnx import TensorProto, helper, numpy_helper # noqa: PLC0415 + + from onnxruntime.quantization import QuantType, quantize_dynamic # noqa: PLC0415 + except ImportError as exc: + raise unittest.SkipTest(f"Required import missing: {exc}") from exc + + # Build a model: input (5, 4) @ weight (4, 8) -> output (5, 8). + # The weight is also fed through Identity (an op the quantizer does not + # support); when the quantizer processes that Identity it dequantizes + # the per-channel-quantized weight initializer via _dequantize_value + # so the Identity input remains float. Exposing the Identity output as + # a graph output keeps the Identity reachable from the optimized graph. + # Weight axis=1 is the output-feature axis (per-channel quantization target). + np.random.seed(42) + weight_data = np.random.normal(0, 0.1, (4, 8)).astype(np.float32) + weight_init = numpy_helper.from_array(weight_data, name="weight") + + input_vi = helper.make_tensor_value_info("input", TensorProto.FLOAT, [5, 4]) + output_vi = helper.make_tensor_value_info("output", TensorProto.FLOAT, [5, 8]) + weight_out_vi = helper.make_tensor_value_info("weight_out", TensorProto.FLOAT, [4, 8]) + + matmul_node = helper.make_node("MatMul", ["input", "weight"], ["output"]) + identity_node = helper.make_node("Identity", ["weight"], ["weight_out"]) + + graph = helper.make_graph( + [matmul_node, identity_node], + "test_graph", + [input_vi], + [output_vi, weight_out_vi], + [weight_init], + ) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) + model.ir_version = 8 + + with tempfile.TemporaryDirectory() as tmp: + model_fp_path = os.path.join(tmp, "model_fp.onnx") + model_q_path = os.path.join(tmp, "model_q.onnx") + onnx.save(model, model_fp_path) + + # This must not raise AssertionError due to per-channel scale not being scalar. + quantize_dynamic( + model_fp_path, + model_q_path, + per_channel=True, + weight_type=QuantType.QInt8, + ) + + q_model = onnx.load(model_q_path) + + # Find the DequantizeLinear node that dequantizes the weight initializer. + init_names = {init.name for init in q_model.graph.initializer} + dq_nodes = [n for n in q_model.graph.node if n.op_type == "DequantizeLinear"] + self.assertGreater(len(dq_nodes), 0, "Expected at least one DequantizeLinear node") + + weight_dq = None + for node in dq_nodes: + if node.input[0] in init_names: + weight_dq = node + break + self.assertIsNotNone(weight_dq, "No DequantizeLinear node found with a weight initializer as input") + + # The axis attribute must be present. + # MatMulInteger passes axis=-1 (last dimension) to quantize_weight_per_channel. + axis_attrs = [attr for attr in weight_dq.attribute if attr.name == "axis"] + self.assertEqual(len(axis_attrs), 1, "DequantizeLinear node is missing the 'axis' attribute") + # MatMulInteger quantizes weight with axis=-1 (default in __quantize_inputs). + self.assertEqual(axis_attrs[0].i, -1, f"Expected axis=-1, got axis={axis_attrs[0].i}") + + # The scale initializer must be 1-D with size > 1 (truly per-channel, not collapsed). + scale_name = weight_dq.input[1] + scale_init = next((i for i in q_model.graph.initializer if i.name == scale_name), None) + self.assertIsNotNone(scale_init, f"Scale initializer '{scale_name}' not found") + scale_array = numpy_helper.to_array(scale_init) + self.assertEqual(scale_array.ndim, 1, f"Expected 1-D scale, got shape {scale_array.shape}") + self.assertGreater(scale_array.size, 1, "Scale has only one element; expected per-channel scale") + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py b/onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py new file mode 100644 index 0000000000000..77ac08cf50d6c --- /dev/null +++ b/onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py @@ -0,0 +1,300 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +""" +Benchmark CPU GroupQueryAttention: Flash Attention vs Naive (full materialization). + +Runs the actual GQA operator via InferenceSession, toggling between flash and +naive paths using the ORT_GQA_DISABLE_FLASH_ATTENTION environment variable. + +Usage: + python benchmark_gqa_cpu_flash.py + python benchmark_gqa_cpu_flash.py --decode_only + python benchmark_gqa_cpu_flash.py --prompt_only +""" + +import argparse +import os +import time + +import numpy as np +from onnx import TensorProto, helper + +from onnxruntime import InferenceSession, SessionOptions + + +def create_quantized_gqa_graph( + batch_size, + seq_len, + num_heads, + kv_num_heads, + head_size, + quant_type, + bit_width, + buffer_seq_len=None, +): + """Create an ONNX graph for GroupQueryAttention with quantized KV cache.""" + if buffer_seq_len is None: + buffer_seq_len = seq_len + + hidden_size = num_heads * head_size + kv_hidden_size = kv_num_heads * head_size + packed_head_size = head_size // 2 if bit_width == 4 else head_size + cache_ort_type = TensorProto.UINT8 if bit_width == 4 else TensorProto.INT8 + + inputs = [ + "query", + "key", + "value", + "past_key", + "past_value", + "seqlens_k", + "total_sequence_length", + "", + "", + "", + "", + "", # cos, sin, position_ids, attention_bias, head_sink + "k_scale", + "v_scale", + ] + while inputs and inputs[-1] == "": + inputs.pop() + + node = helper.make_node( + op_type="GroupQueryAttention", + inputs=inputs, + outputs=["output", "present_key", "present_value"], + name="GroupQueryAttention_0", + num_heads=num_heads, + kv_num_heads=kv_num_heads, + k_quant_type=quant_type, + v_quant_type=quant_type, + kv_cache_bit_width=bit_width, + domain="com.microsoft", + ) + + graph_input = [ + helper.make_tensor_value_info("query", TensorProto.FLOAT, [batch_size, seq_len, hidden_size]), + helper.make_tensor_value_info("key", TensorProto.FLOAT, [batch_size, seq_len, kv_hidden_size]), + helper.make_tensor_value_info("value", TensorProto.FLOAT, [batch_size, seq_len, kv_hidden_size]), + helper.make_tensor_value_info( + "past_key", cache_ort_type, [batch_size, kv_num_heads, buffer_seq_len, packed_head_size] + ), + helper.make_tensor_value_info( + "past_value", cache_ort_type, [batch_size, kv_num_heads, buffer_seq_len, packed_head_size] + ), + helper.make_tensor_value_info("seqlens_k", TensorProto.INT32, [batch_size]), + helper.make_tensor_value_info("total_sequence_length", TensorProto.INT32, [1]), + helper.make_tensor_value_info("k_scale", TensorProto.FLOAT, None), + helper.make_tensor_value_info("v_scale", TensorProto.FLOAT, None), + ] + + graph_output = [ + helper.make_tensor_value_info("output", TensorProto.FLOAT, [batch_size, seq_len, hidden_size]), + helper.make_tensor_value_info( + "present_key", cache_ort_type, [batch_size, kv_num_heads, buffer_seq_len, packed_head_size] + ), + helper.make_tensor_value_info( + "present_value", cache_ort_type, [batch_size, kv_num_heads, buffer_seq_len, packed_head_size] + ), + ] + + graph = helper.make_graph([node], "BenchGQA", graph_input, graph_output) + model = helper.make_model(graph) + return model.SerializeToString() + + +def benchmark_gqa( + batch_size, + seq_len, + num_heads, + kv_num_heads, + head_size, + quant_type, + bit_width, + past_seq_len=0, + warmup=5, + repeats=20, +): + """Benchmark a single GQA configuration. Returns elapsed time in ms.""" + hidden_size = num_heads * head_size + kv_hidden_size = kv_num_heads * head_size + packed_head_size = head_size // 2 if bit_width == 4 else head_size + + total_seqlen = past_seq_len + seq_len + buffer_seq_len = total_seqlen + + onnx_model_str = create_quantized_gqa_graph( + batch_size, + seq_len, + num_heads, + kv_num_heads, + head_size, + quant_type, + bit_width, + buffer_seq_len=buffer_seq_len, + ) + + sess_options = SessionOptions() + sess_options.intra_op_num_threads = 8 + sess = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) + + # Generate inputs + np.random.seed(42) + query = np.random.uniform(-0.5, 0.5, (batch_size, seq_len, hidden_size)).astype(np.float32) + key = np.random.uniform(-0.5, 0.5, (batch_size, seq_len, kv_hidden_size)).astype(np.float32) + value = np.random.uniform(-0.5, 0.5, (batch_size, seq_len, kv_hidden_size)).astype(np.float32) + + cache_dtype = np.uint8 if bit_width == 4 else np.int8 + past_k = np.random.randint( + 0, 255, (batch_size, kv_num_heads, buffer_seq_len, packed_head_size), dtype=np.uint8 + ).view(cache_dtype) + past_v = np.random.randint( + 0, 255, (batch_size, kv_num_heads, buffer_seq_len, packed_head_size), dtype=np.uint8 + ).view(cache_dtype) + + seqlens_k = np.array([total_seqlen - 1] * batch_size, dtype=np.int32) + total_seq = np.array([total_seqlen], dtype=np.int32) + + per_channel = quant_type == "PER_CHANNEL" + scale_size = kv_num_heads * head_size if per_channel else 1 + k_scale = np.full(scale_size, 0.01, dtype=np.float32) + v_scale = np.full(scale_size, 0.01, dtype=np.float32) + + feeds = { + "query": query, + "key": key, + "value": value, + "past_key": past_k, + "past_value": past_v, + "seqlens_k": seqlens_k, + "total_sequence_length": total_seq, + "k_scale": k_scale, + "v_scale": v_scale, + } + + # Warmup + for _ in range(warmup): + sess.run(None, feeds) + + # Benchmark + start = time.perf_counter() + for _ in range(repeats): + sess.run(None, feeds) + elapsed_ms = (time.perf_counter() - start) / repeats * 1000.0 + + return elapsed_ms + + +def run_benchmarks(args): + """Run flash vs naive benchmarks for various configurations.""" + + configs = [] + + if not args.decode_only: + # Prefill configurations: seq_len = total_seqlen (prompt phase) + for total_seqlen in [512, 1024, 2048, 4096]: + configs.append( + { + "label": f"Prefill S={total_seqlen}", + "batch_size": 1, + "seq_len": total_seqlen, + "num_heads": 16, + "kv_num_heads": 8, + "head_size": 128, + "quant_type": "PER_TENSOR", + "bit_width": 8, + "past_seq_len": 0, + } + ) + + if not args.prompt_only: + # Decode configurations: seq_len=1, varying past + for past_seqlen in [512, 1024, 2048, 4096]: + configs.append( + { + "label": f"Decode T={past_seqlen + 1}", + "batch_size": 1, + "seq_len": 1, + "num_heads": 16, + "kv_num_heads": 8, + "head_size": 128, + "quant_type": "PER_TENSOR", + "bit_width": 8, + "past_seq_len": past_seqlen, + } + ) + + if not args.decode_only and not args.prompt_only: + # Batch decode + configs.append( + { + "label": "Decode B=4 T=2049", + "batch_size": 4, + "seq_len": 1, + "num_heads": 16, + "kv_num_heads": 8, + "head_size": 128, + "quant_type": "PER_TENSOR", + "bit_width": 8, + "past_seq_len": 2048, + } + ) + # INT4 prefill + configs.append( + { + "label": "Prefill S=2048 INT4", + "batch_size": 1, + "seq_len": 2048, + "num_heads": 16, + "kv_num_heads": 8, + "head_size": 128, + "quant_type": "PER_TENSOR", + "bit_width": 4, + "past_seq_len": 0, + } + ) + + warmup = args.warmup + repeats = args.repeats + + # Save and restore env var to avoid side effects on callers + saved_env = os.environ.get("ORT_GQA_DISABLE_FLASH_ATTENTION") + + print("\nBenchmark: CPU GroupQueryAttention — Flash vs Naive") + print(f"Threads: {8}, Warmup: {warmup}, Repeats: {repeats}") + print(f"{'Config':<25} {'Naive (ms)':>12} {'Flash (ms)':>12} {'Speedup':>10}") + print("-" * 62) + + for cfg in configs: + label = cfg.pop("label") + + # Flash path (default) + os.environ.pop("ORT_GQA_DISABLE_FLASH_ATTENTION", None) + flash_ms = benchmark_gqa(**cfg, warmup=warmup, repeats=repeats) + + # Naive path (disabled flash) + os.environ["ORT_GQA_DISABLE_FLASH_ATTENTION"] = "1" + naive_ms = benchmark_gqa(**cfg, warmup=warmup, repeats=repeats) + + speedup = naive_ms / flash_ms if flash_ms > 0 else float("inf") + print(f"{label:<25} {naive_ms:>10.3f}ms {flash_ms:>10.3f}ms {speedup:>8.2f}x") + + # Restore original env state + if saved_env is not None: + os.environ["ORT_GQA_DISABLE_FLASH_ATTENTION"] = saved_env + else: + os.environ.pop("ORT_GQA_DISABLE_FLASH_ATTENTION", None) + print() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Benchmark GQA flash vs naive on CPU") + parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations") + parser.add_argument("--repeats", type=int, default=20, help="Measurement iterations") + parser.add_argument("--decode_only", action="store_true", help="Only run decode benchmarks") + parser.add_argument("--prompt_only", action="store_true", help="Only run prompt benchmarks") + args = parser.parse_args() + run_benchmarks(args) diff --git a/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py b/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py index c03545fc31435..99e669f73eb72 100644 --- a/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py +++ b/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py @@ -310,7 +310,13 @@ def make_bias_dropout_model(): def run_operator_test( - target_device, model_creator, inputs, expected_fn, ep_name=CUDA_PLUGIN_EP_NAME, session_config=None + target_device, + model_creator, + inputs, + expected_fn, + ep_name=CUDA_PLUGIN_EP_NAME, + session_config=None, + nhwc_ops=None, ): with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as tmp: model_path = tmp.name @@ -329,6 +335,10 @@ def run_operator_test( ) return False + # Structural assertion: verify NHWC domain assignment when requested + if nhwc_ops: + _assert_nhwc_domain_assigned(sess, ep_name, nhwc_ops) + print( f"(Session created with {active_providers}; assigned nodes: " f"{', '.join(_format_assigned_node(node) for node in assigned_nodes)})", @@ -407,6 +417,101 @@ def _expected_conv(inputs): _NHWC_CONFIG = {"ep.cuda.prefer_nhwc_layout": "1"} +def _assert_nhwc_domain_assigned(session, ep_name, expected_ops): + """Assert that NHWC layout transformation occurred for the expected ops. + + The framework's NHWC layout transformer rewrites eligible ops to the internal NHWC domain + and wraps them with Transpose nodes. We verify NHWC transformation by checking: + 1. If the assignment API surfaces NHWC-domain nodes, verify expected ops are present. + 2. Otherwise, fall back to checking that Transpose nodes were assigned (their presence + indicates the layout transformer ran and the NHWC kernel was found). + + Args: + session: An InferenceSession with graph assignment info enabled. + ep_name: Name of the EP to check (e.g., CUDA_PLUGIN_EP_NAME). + expected_ops: Set or list of op_type strings expected to have NHWC transformation. + + Returns: + True if evidence of NHWC transformation is found. Raises AssertionError otherwise. + """ + assigned_nodes, _ = _get_assigned_nodes(session, ep_name) + + # Check for NHWC-domain nodes directly (preferred when the API surfaces them). + nhwc_domain = "com.ms.internal.nhwc" + nhwc_ops_found = {n.op_type for n in assigned_nodes if n.domain == nhwc_domain} + if nhwc_ops_found: + missing = set(expected_ops) - nhwc_ops_found + if missing: + raise AssertionError( + f"Expected NHWC-domain nodes for {sorted(missing)} but only found " + f"{sorted(nhwc_ops_found)} in {ep_name} NHWC assignments." + ) + return True + + # Fallback: the NHWC transformation inserts Transpose nodes around the target op. + transpose_count = sum(1 for n in assigned_nodes if n.op_type == "Transpose") + if transpose_count == 0: + all_ops = [f"{n.domain or 'ai.onnx'}::{n.op_type}" for n in assigned_nodes] + raise AssertionError( + f"Expected NHWC layout transformation for {sorted(expected_ops)} but no Transpose " + f"nodes were found in {ep_name} assignments. Assigned ops: {all_ops}. " + f"This indicates the NHWC kernel was not found for the target op(s)." + ) + return True + + +def _run_nhwc_model_test(target_device, op_name, model, feed_dict, expected_fn, nhwc_ops=None, rtol=1e-3, atol=1e-3): + """Run an NHWC test: verify domain assignment and numerical correctness. + + Args: + target_device: EP device to test on. + op_name: Op type name (for display and default NHWC assertion). + model: ONNX model proto. + feed_dict: Input feed dictionary. + expected_fn: Function(feed_dict) -> expected output(s). + nhwc_ops: Set of op_types expected in NHWC domain (defaults to {op_name}). + rtol: Relative tolerance for output comparison. + atol: Absolute tolerance for output comparison. + + Returns: + TEST_PASS or TEST_FAIL string. + """ + if nhwc_ops is None: + nhwc_ops = {op_name} + with tempfile.NamedTemporaryFile(suffix=f"_{op_name}_nhwc.onnx", delete=False) as tmp: + model_path = tmp.name + try: + save(model, model_path) + sess_options = _create_session_options(_NHWC_CONFIG) + sess_options.add_provider_for_devices([target_device], {}) + sess = onnxrt.InferenceSession(model_path, sess_options=sess_options) + assigned_nodes, assignment_info = _get_assigned_nodes(sess, CUDA_PLUGIN_EP_NAME) + if not assigned_nodes: + print( + f"{TEST_FAIL} ({CUDA_PLUGIN_EP_NAME} was assigned no nodes; " + f"assignments={_format_assignment_summary(assignment_info)})" + ) + return TEST_FAIL + + # Structural assertion: verify NHWC domain assignment + _assert_nhwc_domain_assigned(sess, CUDA_PLUGIN_EP_NAME, nhwc_ops) + + res = sess.run(None, feed_dict) + expected = expected_fn(feed_dict) + if isinstance(expected, (list, tuple)): + for r, e in zip(res, expected, strict=True): + np.testing.assert_allclose(r, e, rtol=rtol, atol=atol) + else: + np.testing.assert_allclose(res[0], expected, rtol=rtol, atol=atol) + return TEST_PASS + except Exception as e: + print(f"{TEST_FAIL} ({e})") + return TEST_FAIL + finally: + if os.path.exists(model_path): + os.remove(model_path) + + def _expected_batchnorm(inputs): return inputs["X"] / np.sqrt(1.0 + 1e-5) @@ -589,7 +694,12 @@ def test_nhwc_conv(self): "W": np.random.rand(3, 2, 3, 3).astype(np.float32), } result = run_operator_test( - target_device, create_conv_model, inputs, _expected_conv, session_config=_NHWC_CONFIG + target_device, + create_conv_model, + inputs, + _expected_conv, + session_config=_NHWC_CONFIG, + nhwc_ops={"Conv"}, ) self.assertTrue(result, "Conv (NHWC) plugin test failed") @@ -597,7 +707,12 @@ def test_nhwc_batch_normalization(self): target_device = get_cuda_plugin_device() inputs = {"X": np.random.rand(1, 3, 4, 4).astype(np.float32)} result = run_operator_test( - target_device, create_batch_norm_model, inputs, _expected_batchnorm, session_config=_NHWC_CONFIG + target_device, + create_batch_norm_model, + inputs, + _expected_batchnorm, + session_config=_NHWC_CONFIG, + nhwc_ops={"BatchNormalization"}, ) self.assertTrue(result, "BatchNormalization (NHWC) plugin test failed") @@ -610,6 +725,7 @@ def test_nhwc_maxpool(self): inputs, lambda feed: F.max_pool2d(torch.from_numpy(feed["X"]), kernel_size=2, stride=2).numpy(), session_config=_NHWC_CONFIG, + nhwc_ops={"MaxPool"}, ) self.assertTrue(result, "MaxPool (NHWC) plugin test failed") @@ -622,9 +738,178 @@ def test_nhwc_avgpool(self): inputs, lambda feed: F.avg_pool2d(torch.from_numpy(feed["X"]), kernel_size=2, stride=2).numpy(), session_config=_NHWC_CONFIG, + nhwc_ops={"AveragePool"}, ) self.assertTrue(result, "AveragePool (NHWC) plugin test failed") + def test_nhwc_conv_transpose(self): + target_device = get_cuda_plugin_device() + # ConvTranspose: input [1,2,4,4], weight [2,3,3,3] -> output [1,3,6,6] with stride=2, padding=1, output_padding=1 + f_dtype = TensorProto.FLOAT + node = helper.make_node( + "ConvTranspose", + ["X", "W"], + ["Y"], + strides=[2, 2], + pads=[1, 1, 1, 1], + output_padding=[1, 1], + group=1, + ) + graph = helper.make_graph( + [node], + "test-ConvTranspose", + [ + helper.make_tensor_value_info("X", f_dtype, [1, 2, 4, 4]), + helper.make_tensor_value_info("W", f_dtype, [2, 3, 3, 3]), + ], + [helper.make_tensor_value_info("Y", f_dtype, [1, 3, 6, 6])], + ) + opset = OperatorSetIdProto() + opset.version = 11 + model = helper.make_model(graph, opset_imports=[opset]) + x = np.random.rand(1, 2, 4, 4).astype(np.float32) + w = np.random.rand(2, 3, 3, 3).astype(np.float32) + + def expected_fn(feed): + return F.conv_transpose2d( + torch.from_numpy(feed["X"]), + torch.from_numpy(feed["W"]), + stride=2, + padding=1, + output_padding=1, + ).numpy() + + result = _run_nhwc_model_test(target_device, "ConvTranspose", model, {"X": x, "W": w}, expected_fn) + self.assertEqual(result, TEST_PASS, "ConvTranspose (NHWC) plugin test failed") + + def test_nhwc_global_max_pool(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + model = _make_simple_model( + "GlobalMaxPool", + [("X", f_dtype, [1, 3, 4, 4])], + [("Y", f_dtype, [1, 3, 1, 1])], + opset=12, + ) + x = np.random.rand(1, 3, 4, 4).astype(np.float32) + + def expected_fn(feed): + t = torch.from_numpy(feed["X"]) + return F.adaptive_max_pool2d(t, output_size=1).numpy() + + result = _run_nhwc_model_test(target_device, "GlobalMaxPool", model, {"X": x}, expected_fn) + self.assertEqual(result, TEST_PASS, "GlobalMaxPool (NHWC) plugin test failed") + + def test_nhwc_global_average_pool(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + model = _make_simple_model( + "GlobalAveragePool", + [("X", f_dtype, [1, 3, 4, 4])], + [("Y", f_dtype, [1, 3, 1, 1])], + opset=12, + ) + x = np.random.rand(1, 3, 4, 4).astype(np.float32) + + def expected_fn(feed): + t = torch.from_numpy(feed["X"]) + return F.adaptive_avg_pool2d(t, output_size=1).numpy() + + result = _run_nhwc_model_test(target_device, "GlobalAveragePool", model, {"X": x}, expected_fn) + self.assertEqual(result, TEST_PASS, "GlobalAveragePool (NHWC) plugin test failed") + + def test_nhwc_depth_to_space(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + # DepthToSpace: [1,4,2,2] -> [1,1,4,4] with blocksize=2 + model = _make_simple_model( + "DepthToSpace", + [("X", f_dtype, [1, 4, 2, 2])], + [("Y", f_dtype, [1, 1, 4, 4])], + attrs={"blocksize": 2, "mode": "DCR"}, + opset=13, + ) + x = np.random.rand(1, 4, 2, 2).astype(np.float32) + + def expected_fn(feed): + # DCR mode: depth, column, row + t = feed["X"] # [1, 4, 2, 2] + b = 2 + n, c, h, w = t.shape + t = t.reshape(n, b, b, c // (b * b), h, w) + t = t.transpose(0, 3, 4, 1, 5, 2) # [n, c/b^2, h, b, w, b] + return t.reshape(n, c // (b * b), h * b, w * b) + + result = _run_nhwc_model_test(target_device, "DepthToSpace", model, {"X": x}, expected_fn) + self.assertEqual(result, TEST_PASS, "DepthToSpace (NHWC) plugin test failed") + + def test_nhwc_space_to_depth(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + # SpaceToDepth: [1,1,4,4] -> [1,4,2,2] with blocksize=2 + model = _make_simple_model( + "SpaceToDepth", + [("X", f_dtype, [1, 1, 4, 4])], + [("Y", f_dtype, [1, 4, 2, 2])], + attrs={"blocksize": 2}, + opset=13, + ) + x = np.random.rand(1, 1, 4, 4).astype(np.float32) + + def expected_fn(feed): + t = feed["X"] # [1, 1, 4, 4] + b = 2 + n, c, h, w = t.shape + t = t.reshape(n, c, h // b, b, w // b, b) + t = t.transpose(0, 3, 5, 1, 2, 4) # [n, b, b, c, h/b, w/b] + return t.reshape(n, c * b * b, h // b, w // b) + + result = _run_nhwc_model_test(target_device, "SpaceToDepth", model, {"X": x}, expected_fn) + self.assertEqual(result, TEST_PASS, "SpaceToDepth (NHWC) plugin test failed") + + def test_nhwc_lrn(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + # LRN: [1,3,4,4] with size=3, alpha=0.0001, beta=0.75, bias=1.0 + model = _make_simple_model( + "LRN", + [("X", f_dtype, [1, 3, 4, 4])], + [("Y", f_dtype, [1, 3, 4, 4])], + attrs={"size": 3, "alpha": 0.0001, "beta": 0.75, "bias": 1.0}, + opset=13, + ) + x = np.random.rand(1, 3, 4, 4).astype(np.float32) + + def expected_fn(feed): + t = torch.from_numpy(feed["X"]) + return F.local_response_norm(t, size=3, alpha=0.0001, beta=0.75, k=1.0).numpy() + + result = _run_nhwc_model_test(target_device, "LRN", model, {"X": x}, expected_fn) + self.assertEqual(result, TEST_PASS, "LRN (NHWC) plugin test failed") + + def test_nhwc_grid_sample(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + # GridSample: X [1,1,4,4], grid [1,3,3,2] -> Y [1,1,3,3] + model = _make_simple_model( + "GridSample", + [("X", f_dtype, [1, 1, 4, 4]), ("grid", f_dtype, [1, 3, 3, 2])], + [("Y", f_dtype, [1, 1, 3, 3])], + attrs={"mode": "linear", "padding_mode": "zeros", "align_corners": 0}, + opset=20, + ) + x = np.random.rand(1, 1, 4, 4).astype(np.float32) + # Grid values in [-1, 1] + grid = np.random.rand(1, 3, 3, 2).astype(np.float32) * 2 - 1 + + def expected_fn(feed): + t = torch.from_numpy(feed["X"]) + g = torch.from_numpy(feed["grid"]) + return F.grid_sample(t, g, mode="bilinear", padding_mode="zeros", align_corners=False).numpy() + + result = _run_nhwc_model_test(target_device, "GridSample", model, {"X": x, "grid": grid}, expected_fn) + self.assertEqual(result, TEST_PASS, "GridSample (NHWC) plugin test failed") + # ---- Standard op tests ---- def test_op_reshape(self): diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu_quantized.py b/onnxruntime/test/python/transformers/test_gqa_cpu_quantized.py index 3224a07451534..4a4d3e6ff43e8 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu_quantized.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu_quantized.py @@ -108,8 +108,10 @@ def dequantize_int4_per_channel(packed_uint8, scale, kv_num_heads, head_size): # ---- Reference attention ---- -def reference_gqa(q_input, k_input, v_input, num_heads, kv_num_heads, head_size, causal=True): - """Reference FP32 GQA: q[B,S,num_heads*H], k[B,N,S_kv,H], v[B,N,S_kv,H] -> out[B,S,num_heads*H].""" +def reference_gqa(q_input, k_input, v_input, num_heads, kv_num_heads, head_size, causal=True, attention_bias=None): + """Reference FP32 GQA: q[B,S,num_heads*H], k[B,N,S_kv,H], v[B,N,S_kv,H] -> out[B,S,num_heads*H]. + attention_bias: [B|1, num_heads|1, S, S_kv] or None. + """ batch, seq, _ = q_input.shape s_kv = k_input.shape[2] groups = num_heads // kv_num_heads @@ -128,6 +130,11 @@ def reference_gqa(q_input, k_input, v_input, num_heads, kv_num_heads, head_size, logits = np.zeros(s_kv, dtype=np.float32) for k_s in range(s_kv): logits[k_s] = np.dot(q_bnsh[b, h, q_s], k_input[b, kv_h, k_s]) * scale + # Attention bias + if attention_bias is not None: + bias_b = 0 if attention_bias.shape[0] == 1 else b + bias_h = 0 if attention_bias.shape[1] == 1 else h + logits[:s_kv] += attention_bias[bias_b, bias_h, q_s, :s_kv] # Causal mask if causal: for k_s in range(q_s + 1, s_kv): @@ -244,6 +251,103 @@ def create_quantized_gqa_graph( return model.SerializeToString() +def create_quantized_gqa_graph_with_bias( + batch_size, + seq_len, + num_heads, + kv_num_heads, + head_size, + quant_type, + bit_width, + bias_batch_size, + bias_num_heads, + total_seqlen, + buffer_seq_len=None, +): + """Create an ONNX graph for GroupQueryAttention with quantized KV cache and attention bias.""" + if buffer_seq_len is None: + buffer_seq_len = seq_len + + hidden_size = num_heads * head_size + kv_hidden_size = kv_num_heads * head_size + packed_head_size = head_size // 2 if bit_width == 4 else head_size + + cache_ort_type = TensorProto.UINT8 if bit_width == 4 else TensorProto.INT8 + + past_kv_seqlen = buffer_seq_len + present_kv_seqlen = buffer_seq_len + + # Inputs (attention_bias at index 10) + inputs = [ + "query", + "key", + "value", + "past_key", + "past_value", + "seqlens_k", + "total_sequence_length", + "", # cos_cache + "", # sin_cache + "", # position_ids + "attention_bias", + "", # head_sink + "k_scale", + "v_scale", + ] + + # Remove trailing empty strings + while inputs and inputs[-1] == "": + inputs.pop() + + node = helper.make_node( + op_type="GroupQueryAttention", + inputs=inputs, + outputs=["output", "present_key", "present_value"], + name="GroupQueryAttention_0", + num_heads=num_heads, + kv_num_heads=kv_num_heads, + k_quant_type=quant_type, + v_quant_type=quant_type, + kv_cache_bit_width=bit_width, + domain="com.microsoft", + ) + + # Graph inputs + graph_input = [ + helper.make_tensor_value_info("query", TensorProto.FLOAT, [batch_size, seq_len, hidden_size]), + helper.make_tensor_value_info("key", TensorProto.FLOAT, [batch_size, seq_len, kv_hidden_size]), + helper.make_tensor_value_info("value", TensorProto.FLOAT, [batch_size, seq_len, kv_hidden_size]), + helper.make_tensor_value_info( + "past_key", cache_ort_type, [batch_size, kv_num_heads, past_kv_seqlen, packed_head_size] + ), + helper.make_tensor_value_info( + "past_value", cache_ort_type, [batch_size, kv_num_heads, past_kv_seqlen, packed_head_size] + ), + helper.make_tensor_value_info("seqlens_k", TensorProto.INT32, [batch_size]), + helper.make_tensor_value_info("total_sequence_length", TensorProto.INT32, [1]), + helper.make_tensor_value_info( + "attention_bias", TensorProto.FLOAT, [bias_batch_size, bias_num_heads, seq_len, total_seqlen] + ), + helper.make_tensor_value_info("k_scale", TensorProto.FLOAT, None), + helper.make_tensor_value_info("v_scale", TensorProto.FLOAT, None), + ] + + # Graph outputs + graph_output = [ + helper.make_tensor_value_info("output", TensorProto.FLOAT, [batch_size, seq_len, hidden_size]), + helper.make_tensor_value_info( + "present_key", cache_ort_type, [batch_size, kv_num_heads, present_kv_seqlen, packed_head_size] + ), + helper.make_tensor_value_info( + "present_value", cache_ort_type, [batch_size, kv_num_heads, present_kv_seqlen, packed_head_size] + ), + ] + + graph = helper.make_graph([node], "QuantizedGQA_Bias_Graph", graph_input, graph_output) + model = helper.make_model(graph) + return model.SerializeToString() + + # ---- Test runner ---- @@ -517,5 +621,253 @@ def test_int4_long_sequence(self): ) +def run_quantized_gqa_bias_test( + batch_size, + seq_len, + num_heads, + kv_num_heads, + head_size, + quant_type, + bit_width, + bias_broadcast_batch=False, + bias_broadcast_head=False, + atol=None, +): + """Run a quantized GQA test with attention bias and compare against reference.""" + np.random.seed(123) + + hidden_size = num_heads * head_size + kv_hidden_size = kv_num_heads * head_size + + query = np.random.uniform(-0.5, 0.5, (batch_size, seq_len, hidden_size)).astype(np.float32) + key_input = np.random.uniform(-0.5, 0.5, (batch_size, seq_len, kv_hidden_size)).astype(np.float32) + value_input = np.random.uniform(-0.5, 0.5, (batch_size, seq_len, kv_hidden_size)).astype(np.float32) + + # Reshape K/V to BNSH for quantization + k_bnsh = key_input.reshape(batch_size, seq_len, kv_num_heads, head_size).transpose(0, 2, 1, 3) + v_bnsh = value_input.reshape(batch_size, seq_len, kv_num_heads, head_size).transpose(0, 2, 1, 3) + + # Compute scales + if bit_width == 8: + if quant_type == "PER_TENSOR": + _, k_scale = quantize_int8_per_tensor(k_bnsh) + _, v_scale = quantize_int8_per_tensor(v_bnsh) + else: + _, k_scale = quantize_int8_per_channel(k_bnsh) + _, v_scale = quantize_int8_per_channel(v_bnsh) + else: + if quant_type == "PER_TENSOR": + _, k_scale = quantize_int4_per_tensor(k_bnsh) + _, v_scale = quantize_int4_per_tensor(v_bnsh) + else: + _, k_scale = quantize_int4_per_channel(k_bnsh) + _, v_scale = quantize_int4_per_channel(v_bnsh) + + # Empty past (prompt) + packed_head_size = head_size // 2 if bit_width == 4 else head_size + if bit_width == 4: + past_k = np.zeros((batch_size, kv_num_heads, seq_len, packed_head_size), dtype=np.uint8) + past_v = np.zeros((batch_size, kv_num_heads, seq_len, packed_head_size), dtype=np.uint8) + else: + past_k = np.zeros((batch_size, kv_num_heads, seq_len, packed_head_size), dtype=np.int8) + past_v = np.zeros((batch_size, kv_num_heads, seq_len, packed_head_size), dtype=np.int8) + + seqlens_k = np.array([seq_len - 1] * batch_size, dtype=np.int32) + total_seq = np.array([seq_len], dtype=np.int32) + + # Generate attention bias + bias_batch = 1 if bias_broadcast_batch else batch_size + bias_heads = 1 if bias_broadcast_head else num_heads + attention_bias = np.random.uniform(-1.0, 1.0, (bias_batch, bias_heads, seq_len, seq_len)).astype(np.float32) + + # Build and run ONNX model + onnx_model_str = create_quantized_gqa_graph_with_bias( + batch_size, + seq_len, + num_heads, + kv_num_heads, + head_size, + quant_type, + bit_width, + bias_batch_size=bias_batch, + bias_num_heads=bias_heads, + total_seqlen=seq_len, + ) + sess_options = SessionOptions() + sess = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) + + feeds = { + "query": query, + "key": key_input, + "value": value_input, + "past_key": past_k, + "past_value": past_v, + "seqlens_k": seqlens_k, + "total_sequence_length": total_seq, + "attention_bias": attention_bias, + "k_scale": k_scale, + "v_scale": v_scale, + } + + outputs = sess.run(None, feeds) + out_ort = outputs[0] + + # Compute reference with quantized K/V + if bit_width == 8 and quant_type == "PER_TENSOR": + k_q = np.clip(np.round(k_bnsh / k_scale[0]), -128, 127).astype(np.int8) + v_q = np.clip(np.round(v_bnsh / v_scale[0]), -128, 127).astype(np.int8) + k_deq = dequantize_int8_per_tensor(k_q, k_scale[0]) + v_deq = dequantize_int8_per_tensor(v_q, v_scale[0]) + elif bit_width == 8 and quant_type == "PER_CHANNEL": + k_q = np.clip(np.round(k_bnsh / k_scale.reshape(1, kv_num_heads, 1, head_size)), -128, 127).astype(np.int8) + v_q = np.clip(np.round(v_bnsh / v_scale.reshape(1, kv_num_heads, 1, head_size)), -128, 127).astype(np.int8) + k_deq = dequantize_int8_per_channel(k_q, k_scale, kv_num_heads, head_size) + v_deq = dequantize_int8_per_channel(v_q, v_scale, kv_num_heads, head_size) + elif bit_width == 4 and quant_type == "PER_TENSOR": + k_q = np.clip(np.round(k_bnsh / k_scale[0]), -8, 7).astype(np.int8) + v_q = np.clip(np.round(v_bnsh / v_scale[0]), -8, 7).astype(np.int8) + k_deq = k_q.astype(np.float32) * k_scale[0] + v_deq = v_q.astype(np.float32) * v_scale[0] + elif bit_width == 4 and quant_type == "PER_CHANNEL": + k_q = np.clip(np.round(k_bnsh / k_scale.reshape(1, kv_num_heads, 1, head_size)), -8, 7).astype(np.int8) + v_q = np.clip(np.round(v_bnsh / v_scale.reshape(1, kv_num_heads, 1, head_size)), -8, 7).astype(np.int8) + k_deq = k_q.astype(np.float32) * k_scale.reshape(1, kv_num_heads, 1, head_size) + v_deq = v_q.astype(np.float32) * v_scale.reshape(1, kv_num_heads, 1, head_size) + else: + raise ValueError(f"Unsupported config: bit_width={bit_width}, quant_type={quant_type}") + + out_ref = reference_gqa( + query, k_deq, v_deq, num_heads, kv_num_heads, head_size, causal=True, attention_bias=attention_bias + ) + + if atol is None: + atol = 0.15 if bit_width == 4 else 0.05 + + if np.any(np.isnan(out_ort)): + raise AssertionError(f"NaN in output (quant={quant_type}, bit={bit_width}, bias test)") + if np.allclose(out_ort, 0.0): + raise AssertionError(f"Output is all zeros (quant={quant_type}, bit={bit_width}, bias test)") + + np.testing.assert_allclose( + out_ort, + out_ref, + atol=atol, + rtol=0.1, + err_msg=f"Quantized GQA + bias mismatch (quant={quant_type}, bit={bit_width})", + ) + + +class TestGQACPUQuantizedKVWithBias(unittest.TestCase): + """Test CPU GroupQueryAttention with quantized KV cache and attention bias.""" + + def test_int8_per_tensor_bias(self): + run_quantized_gqa_bias_test( + batch_size=1, + seq_len=8, + num_heads=2, + kv_num_heads=1, + head_size=16, + quant_type="PER_TENSOR", + bit_width=8, + ) + + def test_int8_per_channel_bias(self): + run_quantized_gqa_bias_test( + batch_size=1, + seq_len=8, + num_heads=2, + kv_num_heads=1, + head_size=16, + quant_type="PER_CHANNEL", + bit_width=8, + ) + + def test_int4_per_tensor_bias(self): + run_quantized_gqa_bias_test( + batch_size=1, + seq_len=8, + num_heads=2, + kv_num_heads=1, + head_size=16, + quant_type="PER_TENSOR", + bit_width=4, + ) + + def test_int4_per_channel_bias(self): + run_quantized_gqa_bias_test( + batch_size=1, + seq_len=8, + num_heads=2, + kv_num_heads=1, + head_size=16, + quant_type="PER_CHANNEL", + bit_width=4, + ) + + def test_int8_bias_broadcast_batch(self): + """Bias shape [1, N, S, T] with batch_size > 1.""" + run_quantized_gqa_bias_test( + batch_size=2, + seq_len=8, + num_heads=4, + kv_num_heads=2, + head_size=16, + quant_type="PER_TENSOR", + bit_width=8, + bias_broadcast_batch=True, + ) + + def test_int8_bias_broadcast_head(self): + """Bias shape [B, 1, S, T] with num_heads > 1.""" + run_quantized_gqa_bias_test( + batch_size=1, + seq_len=8, + num_heads=4, + kv_num_heads=2, + head_size=16, + quant_type="PER_TENSOR", + bit_width=8, + bias_broadcast_head=True, + ) + + def test_int8_bias_broadcast_both(self): + """Bias shape [1, 1, S, T] with batch_size > 1 and num_heads > 1.""" + run_quantized_gqa_bias_test( + batch_size=2, + seq_len=8, + num_heads=4, + kv_num_heads=2, + head_size=16, + quant_type="PER_TENSOR", + bit_width=8, + bias_broadcast_batch=True, + bias_broadcast_head=True, + ) + + def test_int8_bias_large(self): + """Larger test to exercise flash attention path with bias.""" + run_quantized_gqa_bias_test( + batch_size=2, + seq_len=32, + num_heads=4, + kv_num_heads=2, + head_size=64, + quant_type="PER_TENSOR", + bit_width=8, + ) + + def test_int4_bias_large(self): + """Larger test with INT4 to exercise flash attention path with bias.""" + run_quantized_gqa_bias_test( + batch_size=2, + seq_len=32, + num_heads=4, + kv_num_heads=2, + head_size=64, + quant_type="PER_CHANNEL", + bit_width=4, + ) + + if __name__ == "__main__": unittest.main() diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml index 2d7895510afeb..bd5290e8f792c 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml @@ -183,23 +183,23 @@ stages: displayName: 'Shell Script' inputs: scriptPath: 'onnxruntime/tools/ci_build/github/linux/extract_and_bundle_gpu_package.sh' - args: '-a $(Build.BinariesDirectory)/tgz-artifacts' + args: '-a $(Build.BinariesDirectory)/tgz-artifacts -c $(CUDA_VERSION_MAJOR)' workingDirectory: '$(Build.BinariesDirectory)/tgz-artifacts' - task: ArchiveFiles@2 inputs: - rootFolderOrFile: '$(Build.BinariesDirectory)/tgz-artifacts/onnxruntime-linux-x64-gpu' + rootFolderOrFile: '$(Build.BinariesDirectory)/tgz-artifacts/onnxruntime-linux-x64-gpu_cuda$(CUDA_VERSION_MAJOR)' includeRootFolder: false archiveType: 'tar' # Options: zip, 7z, tar, wim tarCompression: 'gz' - archiveFile: '$(Build.ArtifactStagingDirectory)/onnxruntime-linux-x64-gpu-$(OnnxRuntimeVersion).tgz' + archiveFile: '$(Build.ArtifactStagingDirectory)/onnxruntime-linux-x64-gpu_cuda$(CUDA_VERSION_MAJOR)-$(OnnxRuntimeVersion).tgz' replaceExistingArchive: true - template: ../templates/validate-package.yml parameters: PackageType: 'tarball' PackagePath: '$(Build.ArtifactStagingDirectory)' - PackageName: 'onnxruntime-linux-x64-gpu-$(OnnxRuntimeVersion).tgz' + PackageName: 'onnxruntime-linux-x64-gpu_cuda$(CUDA_VERSION_MAJOR)-$(OnnxRuntimeVersion).tgz' ScriptPath: '$(Build.SourcesDirectory)/onnxruntime/tools/nuget/validate_package.py' PlatformsSupported: 'linux-x64' VerifyNugetSigning: false @@ -214,10 +214,12 @@ stages: script: | docker run -e SYSTEM_COLLECTIONURI --gpus all -e CFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" -e CXXFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" -e NVIDIA_VISIBLE_DEVICES=all --rm --volume /data/models:/data/models --volume $(Build.SourcesDirectory):/src_dir \ --volume $(Build.ArtifactStagingDirectory):/artifact_src -e NIGHTLY_BUILD onnxruntimecuda${{ variables.CUDA_VERSION_MAJOR }}xtrt86build \ - /src_dir/onnxruntime-inference-examples/c_cxx/squeezenet/run_capi_application.sh -o /src_dir/onnxruntime -p /artifact_src/onnxruntime-linux-x64-gpu-$(OnnxRuntimeVersion).tgz -w /src_dir/onnxruntime-inference-examples/c_cxx/squeezenet + /src_dir/onnxruntime-inference-examples/c_cxx/squeezenet/run_capi_application.sh -o /src_dir/onnxruntime -p /artifact_src/onnxruntime-linux-x64-gpu_cuda${CUDA_VERSION_MAJOR}-$(OnnxRuntimeVersion).tgz -w /src_dir/onnxruntime-inference-examples/c_cxx/squeezenet workingDirectory: '$(Build.ArtifactStagingDirectory)' + env: + CUDA_VERSION_MAJOR: $(CUDA_VERSION_MAJOR) - task: 1ES.PublishPipelineArtifact@1 inputs: targetPath: '$(Build.ArtifactStagingDirectory)' - artifactName: 'onnxruntime-linux-x64-gpu' + artifactName: 'onnxruntime-linux-x64-gpu_cuda$(CUDA_VERSION_MAJOR)' diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml index b072e22818eec..0e73dff34aa6a 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml @@ -122,6 +122,10 @@ stages: variables: CUDA_MODULE_LOADINGL: 'LAZY' GRADLE_OPTS: '-Dorg.gradle.daemon=false' + ${{ if eq(parameters.CudaVersion, '13.0') }}: + CUDA_VERSION_MAJOR: '13' + ${{ if eq(parameters.CudaVersion, '12.8') }}: + CUDA_VERSION_MAJOR: '12' steps: - checkout: self # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/onnxruntime - checkout: onnxruntime-inference-examples # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/onnxruntime-inference-examples @@ -181,14 +185,14 @@ stages: displayName: 'Copy zip file to: $(Build.ArtifactStagingDirectory)' inputs: SourceFolder: '$(Build.BinariesDirectory)\zip-artifacts' - Contents: 'onnxruntime-win-x64-gpu-*.zip' + Contents: 'onnxruntime-win-x64-gpu_cuda*-*.zip' TargetFolder: '$(Build.ArtifactStagingDirectory)' - template: ../templates/validate-package.yml parameters: PackageType: 'zip' PackagePath: '$(Build.ArtifactStagingDirectory)' - PackageName: 'onnxruntime-win-x64-gpu-$(OnnxRuntimeVersion).zip' + PackageName: 'onnxruntime-win-x64-gpu_cuda$(CUDA_VERSION_MAJOR)-$(OnnxRuntimeVersion).zip' ScriptPath: '$(Build.SourcesDirectory)\onnxruntime\tools\nuget\validate_package.py' PlatformsSupported: 'win-x64' VerifyNugetSigning: false @@ -200,11 +204,11 @@ stages: condition: and(succeeded(), ne(${{parameters.CudaVersion}}, '13.0')) inputs: filename: $(Build.SourcesDirectory)\onnxruntime-inference-examples\c_cxx\squeezenet\run_capi_application.bat - arguments: $(Build.SourcesDirectory)\onnxruntime $(Build.ArtifactStagingDirectory)\onnxruntime-win-x64-gpu-$(OnnxRuntimeVersion).zip $(Build.SourcesDirectory)\onnxruntime-inference-examples\c_cxx\squeezenet + arguments: $(Build.SourcesDirectory)\onnxruntime $(Build.ArtifactStagingDirectory)\onnxruntime-win-x64-gpu_cuda$(CUDA_VERSION_MAJOR)-$(OnnxRuntimeVersion).zip $(Build.SourcesDirectory)\onnxruntime-inference-examples\c_cxx\squeezenet workingFolder: '$(Build.ArtifactStagingDirectory)' - task: 1ES.PublishPipelineArtifact@1 displayName: 'Publish Pipeline Combined GPU Package Artifact' inputs: - artifactName: 'onnxruntime-win-x64-gpu' + artifactName: 'onnxruntime-win-x64-gpu_cuda$(CUDA_VERSION_MAJOR)' targetPath: '$(Build.ArtifactStagingDirectory)' diff --git a/tools/ci_build/github/linux/extract_and_bundle_gpu_package.sh b/tools/ci_build/github/linux/extract_and_bundle_gpu_package.sh index 04ac0e35a6d78..998e71c20539c 100755 --- a/tools/ci_build/github/linux/extract_and_bundle_gpu_package.sh +++ b/tools/ci_build/github/linux/extract_and_bundle_gpu_package.sh @@ -1,30 +1,36 @@ #!/bin/bash set -e -o -x -while getopts a: parameter_Option +while getopts a:c: parameter_Option do case "${parameter_Option}" in a) ARTIFACT_DIR=${OPTARG};; +c) CUDA_MAJOR=${OPTARG};; +*) echo "Unknown option"; exit 1;; esac done -EXIT_CODE=1 +if [ -z "$CUDA_MAJOR" ]; then + echo "Error: CUDA major version (-c) is required" + exit 1 +fi uname -a -cd $ARTIFACT_DIR +cd "$ARTIFACT_DIR" -mkdir -p $ARTIFACT_DIR/onnxruntime-linux-x64-tensorrt -tar zxvf $ARTIFACT_DIR/onnxruntime-linux-x64-tensorrt-*.tgz -C onnxruntime-linux-x64-tensorrt -rm $ARTIFACT_DIR/onnxruntime-linux-x64-tensorrt-*.tgz +mkdir -p "$ARTIFACT_DIR"/onnxruntime-linux-x64-tensorrt +tar zxvf "$ARTIFACT_DIR"/onnxruntime-linux-x64-tensorrt-*.tgz -C onnxruntime-linux-x64-tensorrt +rm "$ARTIFACT_DIR"/onnxruntime-linux-x64-tensorrt-*.tgz -# Rename cuda directory to gpu directory -mkdir -p $ARTIFACT_DIR/onnxruntime-linux-x64-gpu -tar zxvf $ARTIFACT_DIR/onnxruntime-linux-x64-cuda-*.tgz -C onnxruntime-linux-x64-gpu -VERSION=`ls $ARTIFACT_DIR/onnxruntime-linux-x64-gpu | sed 's/onnxruntime-linux-x64-cuda-//'` -mv $ARTIFACT_DIR/onnxruntime-linux-x64-gpu/* $ARTIFACT_DIR/onnxruntime-linux-x64-gpu/onnxruntime-linux-x64-gpu-$VERSION -rm $ARTIFACT_DIR/onnxruntime-linux-x64-cuda-*.tgz +# Rename cuda directory to gpu_cuda{MAJOR} directory +GPU_DIR_NAME="onnxruntime-linux-x64-gpu_cuda${CUDA_MAJOR}" +mkdir -p "$ARTIFACT_DIR"/"$GPU_DIR_NAME" +tar zxvf "$ARTIFACT_DIR"/onnxruntime-linux-x64-cuda-*.tgz -C "$GPU_DIR_NAME" +VERSION=$(find "$ARTIFACT_DIR"/"$GPU_DIR_NAME" -maxdepth 1 -mindepth 1 -printf '%f\n' | sed 's/onnxruntime-linux-x64-cuda-//') +mv "$ARTIFACT_DIR"/"$GPU_DIR_NAME"/* "$ARTIFACT_DIR"/"$GPU_DIR_NAME"/"${GPU_DIR_NAME}-${VERSION}" +rm "$ARTIFACT_DIR"/onnxruntime-linux-x64-cuda-*.tgz -cp onnxruntime-linux-x64-tensorrt/*/lib/libonnxruntime.so* onnxruntime-linux-x64-gpu/*/lib -cp onnxruntime-linux-x64-tensorrt/*/lib/libonnxruntime_providers_tensorrt.so onnxruntime-linux-x64-gpu/*/lib -cp onnxruntime-linux-x64-tensorrt/*/lib/libonnxruntime_providers_shared.so onnxruntime-linux-x64-gpu/*/lib +cp onnxruntime-linux-x64-tensorrt/*/lib/libonnxruntime.so* "$GPU_DIR_NAME"/*/lib +cp onnxruntime-linux-x64-tensorrt/*/lib/libonnxruntime_providers_tensorrt.so "$GPU_DIR_NAME"/*/lib +cp onnxruntime-linux-x64-tensorrt/*/lib/libonnxruntime_providers_shared.so "$GPU_DIR_NAME"/*/lib diff --git a/tools/ci_build/github/windows/extract_zip_files_gpu.ps1 b/tools/ci_build/github/windows/extract_zip_files_gpu.ps1 index 6671fecfbe072..0e082bbdde531 100644 --- a/tools/ci_build/github/windows/extract_zip_files_gpu.ps1 +++ b/tools/ci_build/github/windows/extract_zip_files_gpu.ps1 @@ -1,8 +1,14 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +$CudaMajor = $Env:CUDA_VERSION_MAJOR +if (-not $CudaMajor) { + Write-Error "CUDA_VERSION_MAJOR environment variable is required" + exit 1 +} + # extract *-cuda-*.zip and *-tensorrt-*.zip -Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\zip-artifacts -Filter *.zip | +Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\zip-artifacts -Filter *.zip | Foreach-Object { $cmd = "7z.exe x $($_.FullName) -y -o$Env:BUILD_BINARIESDIRECTORY\zip-artifacts" Write-Output $cmd @@ -13,13 +19,14 @@ Foreach-Object { Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\zip-artifacts | Where-Object { $_.Name -match 'onnxruntime-win-x64-tensorrt-\d{1,}\.\d{1,}\.\d{1,}$' } | Rename-Item -NewName $Env:BUILD_BINARIESDIRECTORY\zip-artifacts\onnxruntime-win-x64-tensorrt Remove-Item $Env:BUILD_BINARIESDIRECTORY\zip-artifacts\*.zip -# Rename cuda directory to gpu directory and re-compress it for later use in bundle_dlls_gpu.bat +# Rename cuda directory to gpu_cuda{MAJOR} directory and re-compress it for later use in bundle_dlls_gpu.bat Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\zip-artifacts -Filter *cuda* | Foreach-Object { $($_.FullName) -match '.*onnxruntime-win-x64-cuda-(.*)' $version=$matches[1] - Rename-Item -Path $($_.FullName) -NewName onnxruntime-win-x64-gpu-$version - $cmd = "7z.exe a $Env:BUILD_BINARIESDIRECTORY\zip-artifacts\onnxruntime-win-x64-gpu-$version.zip $Env:BUILD_BINARIESDIRECTORY\zip-artifacts\onnxruntime-win-x64-gpu-$version" + $gpuName = "onnxruntime-win-x64-gpu_cuda${CudaMajor}-$version" + Rename-Item -Path $($_.FullName) -NewName $gpuName + $cmd = "7z.exe a $Env:BUILD_BINARIESDIRECTORY\zip-artifacts\${gpuName}.zip $Env:BUILD_BINARIESDIRECTORY\zip-artifacts\${gpuName}" Write-Output $cmd Invoke-Expression -Command $cmd } diff --git a/tools/nuget/validate_package.py b/tools/nuget/validate_package.py index 59e88ea15e7c6..44951c9c3194f 100644 --- a/tools/nuget/validate_package.py +++ b/tools/nuget/validate_package.py @@ -232,10 +232,7 @@ def validate_tarball(args): raise Exception("No packages / more than one packages found in the given path.") package_name = args.package_name - if "-gpu-" in package_name.lower(): - is_gpu_package = True - else: - is_gpu_package = False + is_gpu_package = "-gpu_cuda" in package_name.lower() package_folder = re.search("(.*)[.].*", package_name).group(1) @@ -266,10 +263,7 @@ def validate_zip(args): raise Exception("No packages / more than one packages found in the given path.") package_name = args.package_name - if "-gpu-" in package_name.lower(): - is_gpu_package = True - else: - is_gpu_package = False + is_gpu_package = "-gpu_cuda" in package_name.lower() package_folder = re.search("(.*)[.].*", package_name).group(1) diff --git a/tools/python/compile_contributors.py b/tools/python/compile_contributors.py index bb02c2807d08c..92ba59747493e 100644 --- a/tools/python/compile_contributors.py +++ b/tools/python/compile_contributors.py @@ -11,10 +11,16 @@ Usage: python compile_contributors.py [--base ] [--target ] [--dir ] + [--paths [ ...]] Example: python compile_contributors.py --base origin/rel-1.23.2 --target origin/rel-1.24.1 --dir rel-1.24.1_report + # Limit to commits that touch selected areas (replace with your paths): + # Using git pathspec syntax, ":(top)" anchors each path at repository root. + python compile_contributors.py --base origin/main~500 --target origin/main \ + --paths ":(top)path/to/component_a" ":(top)path/to/component_b" + Outputs: - detail.csv: Detailed breakdown of PRs, authors, and commit links. - logs.txt: Processing logs and summary (professional humans-only contributor list for release notes). @@ -314,6 +320,19 @@ def main(): parser.add_argument("--target", default="origin/rel-1.24.1", help="Target branch/commit to compare to") parser.add_argument("--dir", default="contributors", help="Output directory for reports and logs") parser.add_argument("--scan-depth", type=int, default=200, help="Depth to scan base/meta-PRs for deduplication") + parser.add_argument( + "--paths", + nargs="+", + default=None, + metavar="PATH", + help=( + "Optional list of paths (git pathspec) to limit history to. " + "Only commits that touch one of these paths are considered. " + "Note: when a 'Cherry-pick round' meta-PR is included because at " + "least one of its cherry-picks touched these paths, all its " + "sub-PRs are still expanded regardless of paths." + ), + ) args = parser.parse_args() # Early validation @@ -324,6 +343,9 @@ def main(): branch_target = args.target output_dir = args.dir scan_depth = args.scan_depth + # Build a pathspec suffix (e.g. ["--", "onnxruntime/core/providers/webgpu", ...]) once, + # so it can be appended to each `git log` invocation below. + paths_args = (["--", *args.paths]) if args.paths else [] if not os.path.exists(output_dir): os.makedirs(output_dir) @@ -331,10 +353,12 @@ def main(): logs_path = os.path.join(output_dir, "logs.txt") with open(logs_path, "w", encoding="utf-8") as log_file: log_event(f"Starting comparison: {branch_base} -> {branch_target}", log_file) + if args.paths: + log_event(f"Limiting history to paths: {args.paths}", log_file) # 1. Fetch base branch PRs (scan depth controlled by scan_depth) log_event(f"Fetching base branch history for {branch_base} (last {scan_depth})...", log_file) - log_base = run_command(["git", "log", branch_base, "-n", str(scan_depth), "--oneline"]) + log_base = run_command(["git", "log", branch_base, "-n", str(scan_depth), "--oneline", *paths_args]) if log_base is None: log_event( f"Error: Could not fetch history for base ref '{branch_base}'. Please check if the ref exists.", @@ -348,7 +372,7 @@ def main(): # 2. Fetch target branch PRs (only those not in base) log_event(f"Fetching target branch history: {branch_base}..{branch_target}...", log_file) # Using A..B syntax for git log - log_target = run_command(["git", "log", f"{branch_base}..{branch_target}", "--oneline"]) + log_target = run_command(["git", "log", f"{branch_base}..{branch_target}", "--oneline", *paths_args]) if log_target is None: log_event( f"Error: Could not fetch history for range '{branch_base}..{branch_target}'. Please check if the refs exist.",