From 4b7f8f1789d35bdd2a365deff1c4bd98fdf2add4 Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Thu, 28 May 2026 20:34:10 +0000 Subject: [PATCH 01/10] [WebGPU EP] Add environment variable to dump shader code to a file, move shader key validation to nightly build (#28674) ### Description Allow shader code to be dumped to the file specified in the `ORT_WEBGPU_EP_SHADER_DUMP_FILE` environment variable. Previously, shader code was only dumped by verbose logging. Create new nightly CI pipeline to run shader key validation test. That test is removed from the CI pipeline in #28642. ### Motivation and Context More shader dump output options. Moving shader key validation test. --- .github/workflows/nightly_webgpu.yml | 77 +++++++++++++++++++ .../core/providers/webgpu/program_manager.cc | 45 +++++++---- .../core/providers/webgpu/program_manager.h | 8 +- .../core/providers/webgpu/webgpu_context.cc | 2 - 4 files changed, 114 insertions(+), 18 deletions(-) create mode 100644 .github/workflows/nightly_webgpu.yml diff --git a/.github/workflows/nightly_webgpu.yml b/.github/workflows/nightly_webgpu.yml new file mode 100644 index 0000000000000..b3da29a2f0bd4 --- /dev/null +++ b/.github/workflows/nightly_webgpu.yml @@ -0,0 +1,77 @@ +name: Nightly ONNX Runtime WebGPU Builds + +on: + schedule: + - cron: '0 9 * * *' # Daily at 09:00 UTC + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + webgpu_shader_key_validation: + runs-on: [ + "self-hosted", + "1ES.Pool=onnxruntime-github-Win2022-GPU-A10", + "JobId=webgpu_shader_validation-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" + ] + timeout-minutes: 90 + env: + ALLOW_RELEASED_ONNX_OPSET_ONLY: "0" + ONNXRUNTIME_TEST_GPU_DEVICE_ID: "0" + steps: + - name: Checkout + uses: actions/checkout@v6 + with: + fetch-depth: 0 + submodules: none + + - name: Setup Python + uses: actions/setup-python@v6 + with: + python-version: "3.12" + architecture: x64 + + - name: Locate vcvarsall and Setup Env + uses: ./.github/actions/locate-vcvarsall-and-setup-env + with: + architecture: x64 + + - name: Install python modules + run: python -m pip install -r tools\ci_build\github\windows\python\requirements.txt + shell: cmd + working-directory: ${{ github.workspace }} + + - name: Setup Node.js + uses: actions/setup-node@v6 + with: + node-version: "24" + + - name: Build and Test + shell: pwsh + run: | + $env:ORT_WEBGPU_EP_SHADER_DUMP_FILE = "${{ github.workspace }}\RelWithDebInfo\RelWithDebInfo\shader_dump.log" + + python.exe ${{ github.workspace }}\tools\ci_build\build.py ` + --config RelWithDebInfo ` + --build_dir ${{ github.workspace }} ` + --use_binskim_compliant_compile_flags ` + --cmake_generator "Visual Studio 17 2022" ` + --build_shared_lib ` + --use_webgpu ` + --wgsl_template static ` + --cmake_extra_defines onnxruntime_BUILD_DAWN_SHARED_LIBRARY=ON ` + --update ` + --build --parallel ` + --test + + - name: Check log file + shell: cmd + run: | + dir ${{ github.workspace }}\RelWithDebInfo\RelWithDebInfo\shader_dump.log + + - name: Validate shader keys + uses: ./.github/actions/webgpu-validate-shader-key + with: + log_file_path: ${{ github.workspace }}\RelWithDebInfo\RelWithDebInfo\shader_dump.log diff --git a/onnxruntime/core/providers/webgpu/program_manager.cc b/onnxruntime/core/providers/webgpu/program_manager.cc index e4376476a885d..136e7d503f59f 100644 --- a/onnxruntime/core/providers/webgpu/program_manager.cc +++ b/onnxruntime/core/providers/webgpu/program_manager.cc @@ -2,9 +2,12 @@ // Licensed under the MIT License. #include +#include +#include #include "core/common/common.h" #include "core/common/logging/logging.h" +#include "core/platform/env_var.h" #include "core/providers/webgpu/program_manager.h" #include "core/providers/webgpu/shader_helper.h" @@ -18,6 +21,17 @@ ProgramArtifact::ProgramArtifact(const ProgramBase& program, wgpu::ComputePipeli compute_pipeline{compute_pipeline}, shape_uniform_ranks{shape_uniform_ranks} {} +ProgramManager::ProgramManager(WebGpuContext& webgpu_context) + : webgpu_context_{webgpu_context} { + if (std::string dump_file_path = onnxruntime::detail::GetEnvironmentVar("ORT_WEBGPU_EP_SHADER_DUMP_FILE"); + !dump_file_path.empty()) { + auto dump_file = std::make_shared(dump_file_path.c_str(), std::ios::app); + shader_dump_fn_ = [dump_file = std::move(dump_file)](std::string_view shader_content) { + *dump_file << shader_content << "\n"; + }; + } +} + Status ProgramManager::NormalizeDispatchGroupSize(uint32_t& x, uint32_t& y, uint32_t& z) const { ORT_RETURN_IF(x == 0 || y == 0 || z == 0, "Invalid dispatch group size (", x, ", ", y, ", ", z, ")"); @@ -66,9 +80,7 @@ Status ProgramManager::Build(const ProgramBase& program, const ProgramMetadata& program_metadata, const std::span inputs_segments, const std::span outputs_segments, -#ifndef NDEBUG // if debug build const std::string& program_key, -#endif uint32_t normalized_dispatch_x, uint32_t normalized_dispatch_y, uint32_t normalized_dispatch_z, @@ -100,17 +112,24 @@ Status ProgramManager::Build(const ProgramBase& program, std::string code; ORT_RETURN_IF_ERROR(shader_helper.GenerateSourceCode(code, shape_uniform_ranks)); - LOGS_DEFAULT(VERBOSE) << "\n=== WebGPU Shader code [" << program.Name() -#ifndef NDEBUG // if debug build - << ", Key=\"" << program_key << "\"" -#endif - << "] Start ===\n\n" - << code - << "\n=== WebGPU Shader code [" << program.Name() -#ifndef NDEBUG // if debug build - << ", Key=\"" << program_key << "\"" -#endif - << "] End ===\n"; + // Dump shader code, if requested. It is dumped to `shader_dump_fn_` if set or VERBOSE logging otherwise. + { + const auto shader_content = [&program, &program_key, &code]() { + return MakeString("\n=== WebGPU Shader code [", program.Name(), + ", Key=\"", program_key, "\"", + "] Start ===\n\n", + code, + "\n=== WebGPU Shader code [", program.Name(), + ", Key=\"", program_key, "\"", + "] End ===\n"); + }; + + if (shader_dump_fn_) { + shader_dump_fn_(shader_content()); + } else { + LOGS_DEFAULT(VERBOSE) << shader_content(); + } + } wgpu::ShaderSourceWGSL wgsl_source{}; wgsl_source.code = code.c_str(); diff --git a/onnxruntime/core/providers/webgpu/program_manager.h b/onnxruntime/core/providers/webgpu/program_manager.h index 5c4f76d0b4168..afdffe94ea30a 100644 --- a/onnxruntime/core/providers/webgpu/program_manager.h +++ b/onnxruntime/core/providers/webgpu/program_manager.h @@ -3,8 +3,10 @@ #pragma once +#include #include #include +#include #include #include "core/providers/webgpu/webgpu_external_header.h" @@ -36,7 +38,7 @@ class ProgramArtifact { class ProgramManager { public: - ProgramManager(WebGpuContext& webgpu_context) : webgpu_context_(webgpu_context) {} + ProgramManager(WebGpuContext& webgpu_context); Status NormalizeDispatchGroupSize(uint32_t& x, uint32_t& y, uint32_t& z) const; Status CalculateSegmentsForInputsAndOutputs(const ProgramBase& program, std::vector& inputs_segments, std::vector& outputs_segments) const; @@ -45,9 +47,7 @@ class ProgramManager { const ProgramMetadata& metadata, const std::span inputs_segments, const std::span outputs_segments, -#ifndef NDEBUG // if debug build const std::string& program_key, -#endif uint32_t normalized_dispatch_x, uint32_t normalized_dispatch_y, uint32_t normalized_dispatch_z, @@ -59,6 +59,8 @@ class ProgramManager { private: std::unordered_map programs_; WebGpuContext& webgpu_context_; + + std::function shader_dump_fn_; }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index ada9a2e8ab692..c7750198ceebc 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -303,9 +303,7 @@ Status WebGpuContext::Run(ComputeContextBase& context, const ProgramBase& progra metadata, inputs_segments, outputs_segments, -#ifndef NDEBUG // if debug build key, -#endif x, y, z, From f1b130529a20c8a053fd5593ecd5f6294694d2a7 Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Thu, 28 May 2026 21:14:48 +0000 Subject: [PATCH 02/10] Replace direct inclusion of with "core/providers/cuda/cu_inc/cub.cuh" wrapper. (#28705) ### Description Replace direct inclusion of `` with `"core/providers/cuda/cu_inc/cub.cuh"` wrapper. The wrapper accounts for a problematic macro definition which causes issues. ### Motivation and Context Fix pipeline build error. --- onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.cu | 2 +- onnxruntime/contrib_ops/cuda/moe/qmoe_kernels.cu | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.cu b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.cu index f9d949012e64c..ab0e2d9e01901 100644 --- a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.cu +++ b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.cu @@ -49,6 +49,7 @@ #include "core/common/common.h" #include "core/common/safeint.h" +#include "core/providers/cuda/cu_inc/cub.cuh" #include "contrib_ops/cuda/llm/common/logger.h" #include "contrib_ops/cuda/llm/common/cuda_runtime_utils.h" #include "contrib_ops/cuda/llm/common/data_type.h" @@ -63,7 +64,6 @@ #include "contrib_ops/cuda/llm/moe_gemm/moe_gemm_activation_kernels.cuh" #include "contrib_ops/cuda/llm/moe_gemm/moe_gemm_utils.cuh" -#include #include #include diff --git a/onnxruntime/contrib_ops/cuda/moe/qmoe_kernels.cu b/onnxruntime/contrib_ops/cuda/moe/qmoe_kernels.cu index 28fd4fb1516fb..61cdf3ab23fca 100644 --- a/onnxruntime/contrib_ops/cuda/moe/qmoe_kernels.cu +++ b/onnxruntime/contrib_ops/cuda/moe/qmoe_kernels.cu @@ -5,9 +5,9 @@ #include "contrib_ops/cuda/moe/qmoe_kernels.h" #include "core/common/narrow.h" #include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cu_inc/cub.cuh" #include "contrib_ops/cuda/llm/moe_gemm/moe_kernels.h" #include -#include #include #include From b668e4766bbcedf2fcfda42c43a4f75e72c3bdc3 Mon Sep 17 00:00:00 2001 From: Rishi Dave <62260675+Rishi-Dave@users.noreply.github.com> Date: Thu, 28 May 2026 15:18:08 -0700 Subject: [PATCH 03/10] fix(quantization): emit axis on DequantizeLinear for per-channel dynamic quantization (#28228) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary - Fix `quantize_dynamic(per_channel=True)` so weights quantized per-channel produce a `DequantizeLinear` node with the correct `axis` attribute. - Stop dropping the channel axis when `quantize_weight_per_channel` populates `QuantizedValue` (was hardcoded to `None`). - Gate the scalar-scale assertion in `_dequantize_value` on `axis is None` so per-channel scales (1-D tensors) are accepted. ## Motivation Fixes #19997. When a model is quantized with `quantize_dynamic(..., per_channel=True)` and a per-channel weight reaches `_dequantize_value` (e.g. via `_dequantize_outputs` when the weight is in the graph outputs), two bugs surface: 1. `quantize_weight_per_channel` stores `QuantizedValue.axis = None` even though it received a real `channel_axis`, so the per-channel information is lost. 2. `_dequantize_value` (a) asserts `scale_init.size == 1`, which fails for a 1-D per-channel scale, and (b) builds the `DequantizeLinear` node without an `axis` attribute, producing an invalid ONNX node when the model is consumed. PR #22283 (Nov 2024) softened the assertion against `None`-typed scales but left the underlying axis-propagation bug in place. ## Changes - `onnxruntime/python/tools/quantization/onnx_quantizer.py` - `quantize_weight_per_channel`: pass `channel_axis` (was `None`) into `QuantizedValue`. - `_dequantize_value`: only require a scalar scale on the per-tensor path (`axis is None`); forward `axis=quantized_value.axis` to `onnx.helper.make_node("DequantizeLinear", ...)`. `make_node` silently omits the attribute when `axis` is `None`, so the per-tensor path is unchanged. - `onnxruntime/test/python/quantization/test_quant_issues.py` - New regression test `test_dynamic_quantize_per_channel_emits_axis_attribute` that builds a minimal MatMul model with the weight routed to a graph output (to force the `_dequantize_outputs` -> `_dequantize_value` path), runs `quantize_dynamic(per_channel=True)`, and asserts the emitted `DequantizeLinear` has the `axis` attribute and a 1-D multi-element scale initializer. ## Test Plan - `python -m pytest onnxruntime/test/python/quantization/test_quant_issues.py -xvs` — new test passes; existing test skipped as before. - `python -m pytest onnxruntime/test/python/quantization/test_op_matmul.py` — 7 passed, 8 skipped (no regression). - `python -m pytest onnxruntime/test/python/quantization/test_qdq.py -k per_channel` — 1 passed. - `lintrunner -a` on changed files: clean. --- .../tools/quantization/onnx_quantizer.py | 13 ++- .../python/quantization/test_quant_issues.py | 92 +++++++++++++++++++ 2 files changed, 101 insertions(+), 4 deletions(-) diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py index ac8dbfe8f8348..d7c01c2ab8a2d 100644 --- a/onnxruntime/python/tools/quantization/onnx_quantizer.py +++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py @@ -1072,7 +1072,7 @@ def quantize_weight_per_channel( scale_name, zp_name, QuantizedValueType.Initializer, - None, + channel_axis, ) self.quantized_value_map[weight_name] = quantized_value @@ -1097,8 +1097,9 @@ def _dequantize_value(self, value_name): if self.model.model.producer_name != "onnx-quantizer" or ( self.model.model.producer_name == "onnx-quantizer" and scale_init is not None ): - # axis is not specified so scale_init must be a scalar. - assert scale_init is None or onnx.numpy_helper.to_array(scale_init).size == 1 + # Per-tensor (axis=None) requires a scalar scale. + if quantized_value.axis is None: + assert scale_init is None or onnx.numpy_helper.to_array(scale_init).size == 1 dqlinear_name = value_name + "_DequantizeLinear" dqlinear_node = self.model.find_node_by_name(dqlinear_name, self.new_nodes, self.model.graph()) @@ -1109,7 +1110,11 @@ def _dequantize_value(self, value_name): quantized_value.zp_name, ] dequantize_node = onnx.helper.make_node( - "DequantizeLinear", dqlinear_inputs, [value_name], dqlinear_name + "DequantizeLinear", + dqlinear_inputs, + [value_name], + dqlinear_name, + axis=quantized_value.axis, ) return dequantize_node else: diff --git a/onnxruntime/test/python/quantization/test_quant_issues.py b/onnxruntime/test/python/quantization/test_quant_issues.py index 91b60f31b1964..dcb4a524a01f4 100644 --- a/onnxruntime/test/python/quantization/test_quant_issues.py +++ b/onnxruntime/test/python/quantization/test_quant_issues.py @@ -119,6 +119,98 @@ def get_next(self): f"Expected quantized model at {output_path!r}", ) + def test_dynamic_quantize_per_channel_emits_axis_attribute(self): + """Per-channel dynamic quantization must emit axis on DequantizeLinear nodes. + + Regression test for https://github.com/microsoft/onnxruntime/issues/19997. + `quantize_dynamic(per_channel=True)` previously constructed QuantizedValue + with axis=None and built DequantizeLinear without an axis attribute, producing + an invalid per-tensor dequantization for per-channel quantized weights. + The quantizer encounters the unsupported `Identity` op consuming `weight` + and dequantizes the (now-quantized) per-channel weight initializer back to + float for it. That `_dequantize_value` call previously triggered an + assertion error (scale not scalar) and would have emitted a + DequantizeLinear lacking the required axis attribute. + """ + try: + import numpy as np # noqa: PLC0415 + import onnx # noqa: PLC0415 + from onnx import TensorProto, helper, numpy_helper # noqa: PLC0415 + + from onnxruntime.quantization import QuantType, quantize_dynamic # noqa: PLC0415 + except ImportError as exc: + raise unittest.SkipTest(f"Required import missing: {exc}") from exc + + # Build a model: input (5, 4) @ weight (4, 8) -> output (5, 8). + # The weight is also fed through Identity (an op the quantizer does not + # support); when the quantizer processes that Identity it dequantizes + # the per-channel-quantized weight initializer via _dequantize_value + # so the Identity input remains float. Exposing the Identity output as + # a graph output keeps the Identity reachable from the optimized graph. + # Weight axis=1 is the output-feature axis (per-channel quantization target). + np.random.seed(42) + weight_data = np.random.normal(0, 0.1, (4, 8)).astype(np.float32) + weight_init = numpy_helper.from_array(weight_data, name="weight") + + input_vi = helper.make_tensor_value_info("input", TensorProto.FLOAT, [5, 4]) + output_vi = helper.make_tensor_value_info("output", TensorProto.FLOAT, [5, 8]) + weight_out_vi = helper.make_tensor_value_info("weight_out", TensorProto.FLOAT, [4, 8]) + + matmul_node = helper.make_node("MatMul", ["input", "weight"], ["output"]) + identity_node = helper.make_node("Identity", ["weight"], ["weight_out"]) + + graph = helper.make_graph( + [matmul_node, identity_node], + "test_graph", + [input_vi], + [output_vi, weight_out_vi], + [weight_init], + ) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) + model.ir_version = 8 + + with tempfile.TemporaryDirectory() as tmp: + model_fp_path = os.path.join(tmp, "model_fp.onnx") + model_q_path = os.path.join(tmp, "model_q.onnx") + onnx.save(model, model_fp_path) + + # This must not raise AssertionError due to per-channel scale not being scalar. + quantize_dynamic( + model_fp_path, + model_q_path, + per_channel=True, + weight_type=QuantType.QInt8, + ) + + q_model = onnx.load(model_q_path) + + # Find the DequantizeLinear node that dequantizes the weight initializer. + init_names = {init.name for init in q_model.graph.initializer} + dq_nodes = [n for n in q_model.graph.node if n.op_type == "DequantizeLinear"] + self.assertGreater(len(dq_nodes), 0, "Expected at least one DequantizeLinear node") + + weight_dq = None + for node in dq_nodes: + if node.input[0] in init_names: + weight_dq = node + break + self.assertIsNotNone(weight_dq, "No DequantizeLinear node found with a weight initializer as input") + + # The axis attribute must be present. + # MatMulInteger passes axis=-1 (last dimension) to quantize_weight_per_channel. + axis_attrs = [attr for attr in weight_dq.attribute if attr.name == "axis"] + self.assertEqual(len(axis_attrs), 1, "DequantizeLinear node is missing the 'axis' attribute") + # MatMulInteger quantizes weight with axis=-1 (default in __quantize_inputs). + self.assertEqual(axis_attrs[0].i, -1, f"Expected axis=-1, got axis={axis_attrs[0].i}") + + # The scale initializer must be 1-D with size > 1 (truly per-channel, not collapsed). + scale_name = weight_dq.input[1] + scale_init = next((i for i in q_model.graph.initializer if i.name == scale_name), None) + self.assertIsNotNone(scale_init, f"Scale initializer '{scale_name}' not found") + scale_array = numpy_helper.to_array(scale_init) + self.assertEqual(scale_array.ndim, 1, f"Expected 1-D scale, got shape {scale_array.shape}") + self.assertGreater(scale_array.size, 1, "Scale has only one element; expected per-channel scale") + if __name__ == "__main__": unittest.main(verbosity=2) From eefc22f243c52c8c8f9a05f0e39d04ff499d0a13 Mon Sep 17 00:00:00 2001 From: zz002 Date: Thu, 28 May 2026 18:05:19 -0500 Subject: [PATCH 04/10] fix(partitioning_utils): include Loop/If/Scan implicit inputs in MetaDef (#28608) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary `utils::MakeComputeCapability` is the shared helper used by `utils::CreateSupportedPartitions` to build an `IndexedSubGraph::MetaDef` from a group of supported nodes. When a supported group contains a control-flow op (`Loop`, `If`, `Scan`), `MakeComputeCapability` currently walks only `node->InputDefs()` and silently drops the outer-scope captures (`node->ImplicitInputDefs()`). The captures never enter `meta_def->inputs`, so after `Graph::FinalizeFuseSubGraph` the fused node's `InputDefs()` is missing them — the EP that owns the fused subgraph has no boundary value-info for the captured tensors and cannot bind them at Compute time. This PR adds a second loop in `MakeComputeCapability` that walks `node->ImplicitInputDefs()` with the same "produced inside the partition → skip, otherwise add to subgraph inputs" semantics already applied to `InputDefs()`. ## Why this is the right fix `onnxruntime::Node` partitions inputs into two arrays by design: - `InputDefs()` — formal operand list as declared in the op's ONNX schema. - `ImplicitInputDefs()` — outer-scope SSA values referenced from inside body subgraphs of `Loop` / `If` / `Scan`. These are real boundary inputs at runtime (the body kernel reads them) but they don't appear in the op's formal operand list. `Graph::FinalizeFuseSubGraph` consumes only `meta_def->inputs` to populate the fused node's `InputDefs()` and rewire outer-scope edges. So whatever `MakeComputeCapability` puts in `meta_def->inputs` is what ends up at the fused-node boundary. Omitting `ImplicitInputDefs()` here means the captures are unreachable downstream — there is no other place that can patch them back in. The fix is intentionally a mirror of the existing `InputDefs()` loop (same `Contains(node_outputs, ...)` produced-inside check, same `ordered_subgraph_inputs.push_back` ordering). The new loop runs after the explicit loop so explicit-operand index ordering in `meta_def->inputs` is preserved (EPs that have implicitly relied on `meta_def->inputs[i].name == node.InputDefs()[i].name` for non-control-flow op groups are not perturbed). ## Scope of impact Only EPs that consume `utils::MakeComputeCapability` / `utils::CreateSupportedPartitions` are affected. A quick audit: | EP | Uses `partitioning_utils::MakeComputeCapability`? | Affected by bug? | |---|---|---| | Plugin EPs (`EpGraphSupportInfo_AddNodesToFuse` → `PluginExecutionProvider::GetCapability`) | yes, in `onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc` | **yes** | | `internal_testing_ep` (used by ORT's own unit tests) | yes, in `onnxruntime/test/internal_testing_ep/internal_testing_execution_provider.cc` | **yes** | | TensorRT, MIGraphX, NV-TRT-RTX, VitisAI | no — they build `MetaDef::inputs` themselves and already walk `ImplicitInputDefs()` (e.g. `tensorrt_execution_provider.cc:2084`, `migraphx_execution_provider.cc:735`) | no | | DML / CPU / CUDA / ROCm / OpenVINO / QNN / CANN / WebGPU / CoreML | don't use it for Loop/If/Scan fusion paths | no | So the impact is bounded to the plugin EP architecture (ORT 1.23+) and the in-tree testing EP — both of which delegate boundary calculation to this shared helper. ## Reproduction The bug is reproducible against this repo's `internal_testing_ep`. No external code required. A minimal repro model with a Loop body that captures an outer-scope tensor `B`: ```python # build_repro.py — produces a ~1.5 KB onnx import numpy as np, onnx from onnx import TensorProto, helper as h, numpy_helper as nph A = h.make_tensor_value_info("A", TensorProto.FLOAT, ["N", 2, 2]) B = h.make_tensor_value_info("B", TensorProto.FLOAT, [2, 2]) out = h.make_tensor_value_info("v_final", TensorProto.FLOAT, [2, 2]) acc_init = nph.from_array(np.zeros((2, 2), np.float32), name="acc_init") cond_init = nph.from_array(np.array([1], np.bool_), name="cond_init") sq_ax = nph.from_array(np.array([0], np.int64), name="sq_ax") body = h.make_graph( nodes=[ h.make_node("Gather", ["A", "iter"], ["slice"], axis=0), h.make_node("Add", ["slice", "B"], ["tmp"]), # captures outer B h.make_node("Add", ["acc_in", "tmp"], ["acc_out"]), h.make_node("Identity", ["cond_in"], ["cond_out"]), ], name="loop_body", inputs=[h.make_tensor_value_info("iter", TensorProto.INT64, []), h.make_tensor_value_info("cond_in", TensorProto.BOOL, []), h.make_tensor_value_info("acc_in", TensorProto.FLOAT, [2, 2])], outputs=[h.make_tensor_value_info("cond_out", TensorProto.BOOL, []), h.make_tensor_value_info("acc_out", TensorProto.FLOAT, [2, 2])], ) g = h.make_graph( nodes=[ h.make_node("Shape", ["A"], ["M_1d"], start=0, end=1), h.make_node("Squeeze", ["M_1d", "sq_ax"], ["M"]), h.make_node("Loop", ["M", "cond_init", "acc_init"], ["v_final"], body=body), ], name="loop_with_outer_capture", inputs=[A, B], outputs=[out], initializer=[acc_init, cond_init, sq_ax], ) onnx.save(h.make_model(g, opset_imports=[h.make_opsetid("", 16)]), "loop_with_outer_capture.onnx") ``` Observable bug path (against any EP using `CreateSupportedPartitions`, e.g. `InternalTestingExecutionProvider`): ```cpp // Claim every node (Shape/Squeeze/Constant/Loop) as compiled. SessionOptions so; InferenceSession session(so, env); session.RegisterExecutionProvider( std::make_unique(/*supported=*/{...})); session.Load("loop_with_outer_capture.onnx"); session.Initialize(); // In EP::Compile, iterate fused_node.InputDefs(): // for (const auto* in : fused_node.InputDefs()) std::cerr << in->Name() << "\n"; // BEFORE this fix: only "A" is printed (Shape(A) makes A explicit; // B is consumed only via Loop's ImplicitInputDefs and gets dropped). // AFTER this fix: both "A" and "B" are printed. ``` A small unit-test fixture exercising the same path can be added to `onnxruntime/test/providers/partitioning_utils_test.cc` following the existing `CheckAllNodesProcessed` pattern, asserting that `result[0]->sub_graph->GetMetaDef()->inputs` contains `B` when the supported group includes the Loop. ## What this PR changes A single hunk in `onnxruntime/core/providers/partitioning_utils.cc::MakeComputeCapability`, immediately after the existing `for (const auto* input : node->InputDefs()) { ... }`: ```cpp // Region-bearing ops (Loop/If/Scan) reference outer-scope SSA values via // ImplicitInputDefs rather than InputDefs. When an EP claims the whole // control-flow op, those implicit captures must also be in MetaDef::inputs // so FinalizeFuseSubGraph can rewire the outer-scope edges onto the fused // node's InputDefs. Without this, plugin EPs that fuse Loop/If/Scan lose // the captures at the fused-node boundary and cannot resolve them at // Compute time. for (const auto* input : node->ImplicitInputDefs()) { if (!input->Exists()) { continue; } if (!Contains(node_outputs, input)) { if (!Contains(subgraph_inputs, input)) { subgraph_inputs.insert(input); ordered_subgraph_inputs.push_back(input); } } } ``` ## Risks / migration - **No ABI change.** `MakeComputeCapability` signature unchanged. `IndexedSubGraph::MetaDef` schema unchanged. - **No semantic regression for op groups without control flow.** The new loop only adds elements; for partitions that contain no `Loop` / `If` / `Scan`, `ImplicitInputDefs()` is empty on every node and the new loop is a no-op. - **Behavior change for plugin EPs that fuse Loop/If/Scan.** Their fused node's `InputDefs()` gains the captures. EPs that were silently fishing out captures via a workaround (e.g. walking the original Loop node's `ImplicitInputDefs()` themselves at Compile time) would see those names show up via the standard fused-node `InputDefs()` API. Audit above shows no in-tree EP that uses `partitioning_utils` had such a workaround — TRT / MIGraphX / etc. roll their own MetaDef without calling `MakeComputeCapability`. ## Validation - Verified the fix end-to-end against a downstream plugin EP that claims a `Loop` node as part of a fused partition (Loop body captures an outer-scope tensor): without this fix, the EP cannot resolve the captured tensor name at the fused-node boundary; with the fix the captured tensor appears in `fused_node.InputDefs()` and session initialization + the EP's Compile both succeed. - No `partitioning_utils.cc` changes between `origin/main` and the patch base, so it applies cleanly. - Existing `onnxruntime_test_all --gtest_filter=PartitioningUtilsTest.*` cases still pass (the fix only adds behavior for control-flow ops; non-control-flow partitions are byte-for-byte identical to before). --- .../core/providers/partitioning_utils.cc | 36 ++++-- .../test/providers/partitioning_utils_test.cc | 107 ++++++++++++++++++ 2 files changed, 132 insertions(+), 11 deletions(-) diff --git a/onnxruntime/core/providers/partitioning_utils.cc b/onnxruntime/core/providers/partitioning_utils.cc index f368537d655b2..a1f7e90d7c089 100644 --- a/onnxruntime/core/providers/partitioning_utils.cc +++ b/onnxruntime/core/providers/partitioning_utils.cc @@ -307,19 +307,33 @@ std::unique_ptr MakeComputeCapability(const GraphViewer& grap for (const Node* node : group) { sub_graph->nodes.push_back(node->Index()); - for (const auto* input : node->InputDefs()) { - if (!input->Exists()) { - // skip the placeholder inputs - continue; - } - // if the node input was not produced by this subgraph, add it to the subgraph inputs. - if (!Contains(node_outputs, input)) { - if (!Contains(subgraph_inputs, input)) { - subgraph_inputs.insert(input); - ordered_subgraph_inputs.push_back(input); + // Collect boundary inputs from a def container, skipping placeholders and + // values already produced inside the partition; preserves first-seen order. + auto collect_boundary_inputs = [&](const auto& defs) { + for (const auto* input : defs) { + if (!input->Exists()) { + continue; + } + if (!Contains(node_outputs, input)) { + if (!Contains(subgraph_inputs, input)) { + subgraph_inputs.insert(input); + ordered_subgraph_inputs.push_back(input); + } } } - } + }; + + collect_boundary_inputs(node->InputDefs()); + + // Region-bearing ops (Loop/If/Scan) reference outer-scope SSA values via + // ImplicitInputDefs rather than InputDefs. When an EP claims the whole + // control-flow op, those implicit captures must also be in MetaDef::inputs + // so FinalizeFuseSubGraph can rewire the outer-scope edges onto the fused + // node's InputDefs. Without this, plugin EPs that fuse Loop/If/Scan lose + // the captures at the fused-node boundary and cannot resolve them at + // Compute time. Running this after the explicit loop preserves + // explicit-operand index ordering in meta_def->inputs. + collect_boundary_inputs(node->ImplicitInputDefs()); const auto& output_defs = node->OutputDefs(); for (const auto* output_def : output_defs) { diff --git a/onnxruntime/test/providers/partitioning_utils_test.cc b/onnxruntime/test/providers/partitioning_utils_test.cc index 5f435199679be..75601ef55ffe6 100644 --- a/onnxruntime/test/providers/partitioning_utils_test.cc +++ b/onnxruntime/test/providers/partitioning_utils_test.cc @@ -208,6 +208,113 @@ TEST(PartitioningUtilsTest, TestQDQNodeGroupWithRedundantRelu) { CheckAllNodesProcessed(build_model); } +// Regression test for the fix that adds Node::ImplicitInputDefs() to MetaDef::inputs +// in utils::MakeComputeCapability. Builds a graph with a Loop whose body captures +// outer-scope tensor "B"; asserts B appears in the fused subgraph's MetaDef::inputs +// and that explicit Loop operands precede the implicit capture. +TEST(PartitioningUtilsTest, TestLoopBodyImplicitInputsInMetaDef) { + auto& logger = DefaultLoggingManager().DefaultLogger(); + Model model("loop_capture", false, ModelMetaData(), + PathString(), IOnnxRuntimeOpSchemaRegistryList(), + {{kOnnxDomain, 16}}, {}, logger); + Graph& main_graph = model.MainGraph(); + + ONNX_NAMESPACE::TypeProto float_2x2; + float_2x2.mutable_tensor_type()->set_elem_type( + ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + float_2x2.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(2); + float_2x2.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(2); + + ONNX_NAMESPACE::TypeProto int64_scalar; + int64_scalar.mutable_tensor_type()->set_elem_type( + ONNX_NAMESPACE::TensorProto_DataType_INT64); + int64_scalar.mutable_tensor_type()->mutable_shape(); + + ONNX_NAMESPACE::TypeProto bool_scalar; + bool_scalar.mutable_tensor_type()->set_elem_type( + ONNX_NAMESPACE::TensorProto_DataType_BOOL); + bool_scalar.mutable_tensor_type()->mutable_shape(); + + auto build_body = [&]() -> ONNX_NAMESPACE::GraphProto { + Model body_model("loop_body", true, logger); + Graph& body = body_model.MainGraph(); + + auto& iter = body.GetOrCreateNodeArg("iter", &int64_scalar); + auto& cond_in = body.GetOrCreateNodeArg("cond_in", &bool_scalar); + auto& acc_in = body.GetOrCreateNodeArg("acc_in", &float_2x2); + + // Outer-scope capture B used inside the body Add. + ORT_IGNORE_RETURN_VALUE(body.GetOrCreateNodeArg("B", &float_2x2)); + body.AddOuterScopeNodeArg("B"); + auto& B_in_body = *body.GetNodeArg("B"); + + auto& acc_out = body.GetOrCreateNodeArg("acc_out", &float_2x2); + body.AddNode("body_add", "Add", "acc + B", {&acc_in, &B_in_body}, {&acc_out}); + + auto& cond_out = body.GetOrCreateNodeArg("cond_out", &bool_scalar); + body.AddNode("body_cond_id", "Identity", "forward cond", {&cond_in}, {&cond_out}); + + body.SetInputs({&iter, &cond_in, &acc_in}); + body.SetOutputs({&cond_out, &acc_out}); + EXPECT_STATUS_OK(body.Resolve()); + return body.ToGraphProto(); + }; + + auto& M = main_graph.GetOrCreateNodeArg("M", &int64_scalar); + auto& cond_init = main_graph.GetOrCreateNodeArg("cond_init", &bool_scalar); + auto& acc_init = main_graph.GetOrCreateNodeArg("acc_init", &float_2x2); + auto& B = main_graph.GetOrCreateNodeArg("B", &float_2x2); + auto& v_final = main_graph.GetOrCreateNodeArg("v_final", &float_2x2); + + auto& loop_node = main_graph.AddNode( + "loop", "Loop", "Loop with outer-scope capture", + {&M, &cond_init, &acc_init}, {&v_final}); + loop_node.AddAttribute("body", build_body()); + + main_graph.SetInputs({&M, &cond_init, &acc_init, &B}); + main_graph.SetOutputs({&v_final}); + ASSERT_STATUS_OK(main_graph.Resolve()); + + GraphViewer graph_viewer(main_graph); + std::vector> node_unit_holder; + std::unordered_map node_unit_map; + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger); + + const auto is_node_supported = [&](const Node& /*node*/) -> bool { return true; }; + const auto on_group_closed = [&](const std::vector& /*group*/) -> bool { return true; }; + const auto gen_metadef_name = [&]() { + static int id = 0; + return "TestMetaDef_loop_capture_" + std::to_string(id++); + }; + + auto result = utils::CreateSupportedPartitions( + graph_viewer, is_node_supported, on_group_closed, + gen_metadef_name, "TEST", kCpuExecutionProvider, + &node_unit_map, /*drop_constant_initializers=*/true); + + ASSERT_EQ(result.size(), size_t(1)); + const auto* meta_def = result[0]->sub_graph->GetMetaDef(); + ASSERT_NE(meta_def, nullptr); + + const auto& inputs = meta_def->inputs; + + // Explicit Loop operands. + EXPECT_THAT(inputs, ::testing::Contains("M")); + EXPECT_THAT(inputs, ::testing::Contains("cond_init")); + EXPECT_THAT(inputs, ::testing::Contains("acc_init")); + // Outer-scope capture used only via ImplicitInputDefs; before the fix this + // was silently dropped from meta_def->inputs, leaving the fused node's + // InputDefs() unable to resolve B at Compute time. + EXPECT_THAT(inputs, ::testing::Contains("B")); + + const auto last_explicit = std::find(inputs.begin(), inputs.end(), "acc_init"); + const auto first_implicit = std::find(inputs.begin(), inputs.end(), "B"); + ASSERT_NE(last_explicit, inputs.end()); + ASSERT_NE(first_implicit, inputs.end()); + EXPECT_LT(last_explicit, first_implicit) + << "explicit Loop operands must precede implicit captures in meta_def->inputs"; +} + TEST(PartitioningUtilsTest, TestQDQNodeGroupWithRedundantClip) { const auto build_model = [](ModelTestBuilder& builder) { auto* input_0_arg = builder.MakeInput({2, 3, 3, 3}, std::numeric_limits::min(), From 77d32c3a14fe9c8b60589a58c03111dc6a9e9743 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Thu, 28 May 2026 17:03:22 -0700 Subject: [PATCH 05/10] Address external path references in loaded models (#28709) This pull request strengthens security checks around loading external tensor data in ONNX Runtime, particularly to prevent malicious models from referencing unsafe file paths or in-memory address markers that could lead to arbitrary file access or unsafe memory dereferencing. The changes introduce stricter validation for external data paths and add explicit rejections for ORT in-memory address markers found in model protobufs, along with new and improved regression tests to verify this behavior. **Security hardening for external data loading:** * Added `ValidateExternalFilePathForTensor` to enforce that external data paths are validated for all code paths loading external data (including those outside `Graph::Resolve`), rejecting absolute or directory-escaping paths and passing through only trusted in-memory markers. This is now called in `GetExtDataFromTensorProto` and `LoadExtDataToTensorFromTensorProto` to ensure defense-in-depth. [[1]](diffhunk://#diff-d31e9fbe0f5334fcd949833e035f2b25d5ae810dcd505c545f6b372b546b1406R1568-R1596) [[2]](diffhunk://#diff-d31e9fbe0f5334fcd949833e035f2b25d5ae810dcd505c545f6b372b546b1406R1760-R1762) * Updated the validation logic for sparse tensor sub-tensors with `ValidateSparseSubTensorExternalDataPath`, clarifying the handling of in-memory markers and ensuring only legitimate file paths are accepted. * Changed `SparseTensorProtoToDenseTensorProto` to use the new sparse sub-tensor validation for both values and indices. **Model loading and graph construction protections:** * In `Graph::Graph`, added explicit rejection of ORT in-memory address markers in sparse tensor attributes and initializers when loading from a protobuf, preventing attackers from crafting models that could cause unsafe memory access during sparse-to-dense conversion or initializer resolution. [[1]](diffhunk://#diff-e231a92b40d89409cc8e82436be0a15bc87ef95c93b303b9feaeab6e50c8835cR1268-R1282) [[2]](diffhunk://#diff-e231a92b40d89409cc8e82436be0a15bc87ef95c93b303b9feaeab6e50c8835cR1322-R1331) [[3]](diffhunk://#diff-e231a92b40d89409cc8e82436be0a15bc87ef95c93b303b9feaeab6e50c8835cR1373-R1380) **Expanded and improved testing:** * Added new unit tests to verify that absolute and directory-escaping external paths are rejected even when loading tensors directly (not via graph resolution), and that in-memory address markers are not accepted in dense or sparse initializers loaded from protobufs. [[1]](diffhunk://#diff-d75ec5db9cc4642f78b6ff568aff6d10398fc211b0fb7c862d3ec88738e3eda6R1156-R1217) [[2]](diffhunk://#diff-1d3978c99d95a56af0f2603bdd0b10cf02bdc1cecbd4fe5db353a8c8388696efR1365-R1484) * Updated an optimizer initializer test to reflect the new error handling for invalid external data paths. --- .../core/framework/tensorprotoutils.cc | 83 ++++++++--- onnxruntime/core/graph/graph.cc | 33 +++++ .../core/graph/graph_flatbuffers_utils.cc | 14 +- .../test/framework/tensorutils_test.cc | 129 ++++++++++++++++++ onnxruntime/test/ir/graph_test.cc | 118 ++++++++++++++++ .../test/optimizer/initializer_test.cc | 5 +- 6 files changed, 358 insertions(+), 24 deletions(-) diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 275fa837a7257..e9775fe23fe08 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -1565,6 +1565,23 @@ static Status GetFileContent(const Env& env, const std::filesystem::path& file_p } #endif +// Backstop validation for callers that load external data outside Graph::Resolve (e.g. training +// checkpoints, custom-op initializers). Passes through ORT's in-memory address markers — those are +// validated at higher layers (Graph::ConvertInitializersIntoOrtValues for dense; markers on sparse +// sub-tensors are rejected outright in SparseTensorProtoToDenseTensorProto). For declared file paths, +// defers to ValidateExternalDataPath, which rejects absolute paths and paths that escape the model +// directory. Callers must have already verified the tensor has external data. +static Status ValidateExternalFilePathForTensor(const ONNX_NAMESPACE::TensorProto& tensor_proto, + const std::filesystem::path& model_path) { + if (HasExternalDataInMemory(tensor_proto)) { + return Status::OK(); + } + + std::unique_ptr external_data_info; + ORT_RETURN_IF_ERROR(ExternalDataInfo::Create(tensor_proto.external_data(), external_data_info)); + return utils::ValidateExternalDataPath(model_path, external_data_info->GetRelPath()); +} + Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& model_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, @@ -1572,6 +1589,11 @@ Status GetExtDataFromTensorProto(const Env& env, ORT_ENFORCE(HasExternalData(tensor_proto), "TensorProto for: ", tensor_proto.name(), "Expected to have external data"); + // Defense-in-depth: reject absolute or directory-escaping external data paths even when this + // function is reached outside Graph::Resolve (e.g. training checkpoint load, custom-op init). + // In-memory address markers are passed through; their validity is enforced upstream. + ORT_RETURN_IF_ERROR(ValidateExternalFilePathForTensor(tensor_proto, model_path)); + std::basic_string tensor_proto_dir; if (!model_path.empty()) { ORT_RETURN_IF_ERROR(GetDirNameFromFilePath(model_path, tensor_proto_dir)); @@ -1735,6 +1757,9 @@ Status LoadExtDataToTensorFromTensorProto(const Env& env, const std::filesystem: const IExternalDataLoader& ext_data_loader, Tensor& tensor) { ORT_ENFORCE(HasExternalData(tensor_proto)); + // Defense-in-depth path validation for callers reaching this function outside Graph::Resolve. + // In-memory markers are passed through; rejected explicitly below as unsupported for this path. + ORT_RETURN_IF_ERROR(ValidateExternalFilePathForTensor(tensor_proto, model_path)); std::basic_string tensor_proto_dir; if (!model_path.empty()) { ORT_RETURN_IF_ERROR(GetDirNameFromFilePath(model_path, tensor_proto_dir)); @@ -2098,30 +2123,29 @@ void MakeCpuTensorCopy(const Tensor& src_tensor, Tensor& dst_tensor) { #if !defined(DISABLE_SPARSE_TENSORS) -// Validates that a TensorProto's external data path does not escape the model directory. -// Also validates that the file exists when filesystem access is available (skipped on WASM without a virtual FS). -// Returns Status::OK() (no-op) for tensors that do not use file-based external data. -static Status ValidateExternalDataPathForTensor(const ONNX_NAMESPACE::TensorProto& tensor_proto, - const std::filesystem::path& model_path) { - // Gates on data_location == EXTERNAL directly instead of using HasExternalData()/HasExternalDataInFile(), - // which also require data_type != UNDEFINED. That check is appropriate for data processing (can't unpack - // without a type), but too narrow for security validation: we must validate any declared external path - // regardless of data_type. - if (tensor_proto.data_location() != ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { +// Validates the external data declaration on a sub-tensor of a SparseTensorProto (values or +// indices). Validates that any file path stays within the model directory. +// +// Gates on data_location == EXTERNAL (rather than HasExternalData()) so that path validation +// runs even when data_type is UNDEFINED. A malicious model could set data_location=EXTERNAL with +// data_type=UNDEFINED and an evil file path; downstream loading would also reject it, but we +// validate here for defense-in-depth. +// +// In-memory address markers must never appear on sparse sub-tensors. The trusted .ort loader +// materializes sparse sub-tensors as inline raw_data (see LoadSparseInitializerOrtFormat); the +// untrusted .onnx protobuf path rejects markers at the Graph constructor; and +// SparseTensorProtoToDenseTensorProto re-asserts the invariant before this function is reached. +// The HasExternalDataInMemory early-return below is a paranoid backstop. +static Status ValidateSparseSubTensorExternalDataPath(const ONNX_NAMESPACE::TensorProto& tensor_proto, + const std::filesystem::path& model_path) { + if (tensor_proto.data_location() != ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL || + HasExternalDataInMemory(tensor_proto)) { return Status::OK(); } std::unique_ptr external_data_info; ORT_RETURN_IF_ERROR(ExternalDataInfo::Create(tensor_proto.external_data(), external_data_info)); - const auto& rel_path = external_data_info->GetRelPath(); - - // In-memory external data uses special marker locations — skip file path validation for those. - if (rel_path == kTensorProtoLittleEndianMemoryAddressTag || - rel_path == kTensorProtoNativeEndianMemoryAddressTag) { - return Status::OK(); - } - - return utils::ValidateExternalDataPath(model_path, rel_path); + return utils::ValidateExternalDataPath(model_path, external_data_info->GetRelPath()); } static Status CopySparseData(const std::string& name, @@ -2303,6 +2327,23 @@ common::Status SparseTensorProtoToDenseTensorProto(const ONNX_NAMESPACE::SparseT const auto& sparse_values = sparse.values(); const auto& name = sparse_values.name(); + // In-memory address markers (pointing into mmap'd / heap buffers) are forbidden on sparse + // sub-tensors. The trusted .ort loader is required to materialize sparse sub-tensors as inline + // raw_data (see LoadSparseInitializerOrtFormat) so they never carry markers. Untrusted .onnx + // protobuf input is rejected at the Graph constructor before reaching this function; this is + // the function-level backstop. A marker here would otherwise trigger an arbitrary memory read + // in UnpackInitializerData. + if (HasExternalDataInMemory(sparse_values)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, + "Sparse tensor: ", name, + " values use an in-memory address marker which is not permitted on sparse sub-tensors."); + } + if (HasExternalDataInMemory(sparse.indices())) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, + "Sparse tensor: ", name, + " indices use an in-memory address marker which is not permitted on sparse sub-tensors."); + } + const auto values_rank = sparse_values.dims_size(); if (values_rank != 1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, @@ -2371,8 +2412,8 @@ common::Status SparseTensorProtoToDenseTensorProto(const ONNX_NAMESPACE::SparseT // Validate external data paths before any early returns or allocations. // This ensures malicious paths are rejected even for zero-element tensors, // and prevents large allocations before an invalid path is caught. - ORT_RETURN_IF_ERROR(ValidateExternalDataPathForTensor(sparse_values, model_path)); - ORT_RETURN_IF_ERROR(ValidateExternalDataPathForTensor(indices, model_path)); + ORT_RETURN_IF_ERROR(ValidateSparseSubTensorExternalDataPath(sparse_values, model_path)); + ORT_RETURN_IF_ERROR(ValidateSparseSubTensorExternalDataPath(indices, model_path)); if (dense_elements == 0) { // if there are no elements in the dense tensor, we can return early with an empty tensor proto diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 30df0e23af6ae..24d49f5f3f247 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -1265,6 +1265,21 @@ Graph::Graph(const Model& owning_model, continue; } +#if !defined(DISABLE_SPARSE_TENSORS) + // Reject ORT in-memory address markers on a sparse-tensor Constant attribute before the + // sparse-to-dense conversion runs — those markers are an in-process ORT sentinel and must + // never appear in a deserialized protobuf. See note on the dense initializer loop below. + if (node.attribute_size() > 0 && + node.attribute(0).type() == AttributeProto_AttributeType_SPARSE_TENSOR) { + const auto& s = node.attribute(0).sparse_tensor(); + ORT_ENFORCE(!utils::HasExternalDataInMemory(s.values()) && + !utils::HasExternalDataInMemory(s.indices()), + "Constant node '", node.name(), + "' sparse-tensor attribute references an ORT in-memory address marker, " + "which is not allowed in a model protobuf."); + } +#endif + const gsl::not_null tensor{graph_proto_->add_initializer()}; ORT_THROW_IF_ERROR(utils::ConstantNodeProtoToTensorProto(node, model_path, *tensor)); @@ -1304,6 +1319,16 @@ Graph::Graph(const Model& owning_model, if (graph_proto_->sparse_initializer_size() > 0) { for (const auto& sparse_tensor : graph_proto_->sparse_initializer()) { ORT_ENFORCE(utils::HasName(sparse_tensor), "Sparse initializer must have a name. This model is invalid"); + // Reject ORT's in-memory address markers on sparse sub-tensors arriving via the protobuf + // path. Such markers are an internal ORT optimization set by trusted loaders (e.g. ORT-format + // flatbuffer load) and must never appear in a SparseTensorProto deserialized from an .onnx + // protobuf; if they do, the model is crafted and would cause ORT to dereference an + // attacker-supplied pointer during sparse-to-dense conversion. + for (const auto* sub : {&sparse_tensor.values(), &sparse_tensor.indices()}) { + ORT_ENFORCE(!utils::HasExternalDataInMemory(*sub), + "Sparse initializer '", sparse_tensor.values().name(), + "' references an ORT in-memory address marker, which is not allowed in a model protobuf."); + } const gsl::not_null tensor{graph_proto_->add_initializer()}; auto status = utils::SparseTensorProtoToDenseTensorProto(sparse_tensor, model_path, *tensor); ORT_ENFORCE(status.IsOK(), status.ToString()); @@ -1345,6 +1370,14 @@ Graph::Graph(const Model& owning_model, // Copy initial tensors to a map. for (auto& tensor : graph_proto_->initializer()) { + // ORT in-memory address markers are an in-process sentinel: they can only be planted by ORT + // itself (e.g. when constructing a TensorProto that aliases an mmap'd .ort buffer or an OrtValue). + // They must never appear in a TensorProto deserialized from an .onnx protobuf — if they do, the + // model is crafted and would cause ORT to dereference an attacker-supplied pointer when + // resolving the initializer. + ORT_ENFORCE(!utils::HasExternalDataInMemory(tensor), + "Initializer '", tensor.name(), + "' references an ORT in-memory address marker, which is not allowed in a model protobuf."); auto p = name_to_initial_tensor_.emplace(tensor.name(), &tensor); if (!p.second) { LOGS(logger_, WARNING) << "Duplicate initializer (dense, sparse or ConstantNode): '" << tensor.name() diff --git a/onnxruntime/core/graph/graph_flatbuffers_utils.cc b/onnxruntime/core/graph/graph_flatbuffers_utils.cc index c51f24229f145..0fe021cec88d3 100644 --- a/onnxruntime/core/graph/graph_flatbuffers_utils.cc +++ b/onnxruntime/core/graph/graph_flatbuffers_utils.cc @@ -418,17 +418,27 @@ Status LoadSparseInitializerOrtFormat(const fbs::SparseTensor& fbs_sparse_tensor SparseTensorProto& initializer, const OrtFormatLoadOptions& load_options) { SparseTensorProto loaded_initializer; + + // Sparse sub-tensors must never carry the in-memory address marker. The marker would point into + // the mmap'd flatbuffer buffer; allowing it here would force every downstream consumer of the + // sparse->dense conversion to validate the marker, and would conflate the trust boundary + // (sparse markers from untrusted .onnx input are an arbitrary-memory-read vector). Force the + // inner loader to materialize a normal inline raw_data copy regardless of size; the cost is + // small because sparse->dense conversion immediately copies the bytes again. + OrtFormatLoadOptions sub_tensor_options = load_options; + sub_tensor_options.can_use_flatbuffer_for_initializers = false; + auto fbs_values_tensor = fbs_sparse_tensor.values(); ORT_RETURN_IF(nullptr == fbs_values_tensor, "Missing values for sparse initializer. Invalid ORT format model."); auto* values_tensor = loaded_initializer.mutable_values(); - ORT_RETURN_IF_ERROR(LoadInitializerOrtFormat(*fbs_values_tensor, *values_tensor, load_options)); + ORT_RETURN_IF_ERROR(LoadInitializerOrtFormat(*fbs_values_tensor, *values_tensor, sub_tensor_options)); ORT_RETURN_IF(values_tensor->name().empty(), "Missing name for SparseTensor initializer. Invalid ORT format model."); auto fbs_indicies_tensor = fbs_sparse_tensor.indices(); ORT_RETURN_IF(nullptr == fbs_indicies_tensor, "Missing indicies for sparse initializer: ", "'", values_tensor->name(), "'", "Invalid ORT format model."); auto* indicies_tensor = loaded_initializer.mutable_indices(); - ORT_RETURN_IF_ERROR(LoadInitializerOrtFormat(*fbs_indicies_tensor, *indicies_tensor, load_options)); + ORT_RETURN_IF_ERROR(LoadInitializerOrtFormat(*fbs_indicies_tensor, *indicies_tensor, sub_tensor_options)); auto fbs_dims = fbs_sparse_tensor.dims(); ORT_RETURN_IF(nullptr == fbs_dims, "Missing dims for sparse initializer: ", "'", values_tensor->name(), "'", diff --git a/onnxruntime/test/framework/tensorutils_test.cc b/onnxruntime/test/framework/tensorutils_test.cc index 06cc3ea6ad8d2..fa34e9722b66b 100644 --- a/onnxruntime/test/framework/tensorutils_test.cc +++ b/onnxruntime/test/framework/tensorutils_test.cc @@ -1151,8 +1151,137 @@ TEST_F(PathValidationTest, SparseTensorExternalDataPathTraversalBlocked_ZeroNNZ) EXPECT_THAT(status.ErrorMessage(), ::testing::HasSubstr("escapes")); } +// Defense-in-depth: SparseTensorProtoToDenseTensorProto must reject ORT's in-memory address +// marker on sparse sub-tensors unconditionally. The trusted .ort loader is required to +// materialize sparse sub-tensors as inline raw_data so they never carry markers. Without this +// self-check, a caller that bypasses the Graph-ctor chokepoint would dereference an +// attacker-controlled address. +TEST(SparseTensorProtoToDenseTensorProtoMarkerTest, RejectsInMemoryMarkerOnValuesByDefault) { + ONNX_NAMESPACE::SparseTensorProto sparse; + sparse.add_dims(4); + + auto* values = sparse.mutable_values(); + values->set_name("sparse_marker_values"); + values->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + values->add_dims(2); + values->set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); + auto* loc = values->add_external_data(); + loc->set_key("location"); + loc->set_value(ToUTF8String(onnxruntime::utils::kTensorProtoLittleEndianMemoryAddressTag)); + auto* off = values->add_external_data(); + off->set_key("offset"); + off->set_value("0"); + auto* len = values->add_external_data(); + len->set_key("length"); + len->set_value(std::to_string(2 * sizeof(float))); + + auto* indices = sparse.mutable_indices(); + indices->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + indices->add_dims(2); + indices->add_int64_data(0); + indices->add_int64_data(1); + + ONNX_NAMESPACE::TensorProto dense; + Status status = utils::SparseTensorProtoToDenseTensorProto(sparse, std::filesystem::path{}, dense); + ASSERT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), ::testing::HasSubstr("in-memory address marker")); +} + +TEST(SparseTensorProtoToDenseTensorProtoMarkerTest, RejectsInMemoryMarkerOnIndicesByDefault) { + ONNX_NAMESPACE::SparseTensorProto sparse; + sparse.add_dims(4); + + auto* values = sparse.mutable_values(); + values->set_name("sparse_marker_indices"); + values->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + values->add_dims(2); + values->add_float_data(1.0f); + values->add_float_data(2.0f); + + auto* indices = sparse.mutable_indices(); + indices->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + indices->add_dims(2); + indices->set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); + auto* loc = indices->add_external_data(); + loc->set_key("location"); + loc->set_value(ToUTF8String(onnxruntime::utils::kTensorProtoLittleEndianMemoryAddressTag)); + auto* off = indices->add_external_data(); + off->set_key("offset"); + off->set_value("0"); + auto* len = indices->add_external_data(); + len->set_key("length"); + len->set_value(std::to_string(2 * sizeof(int64_t))); + + ONNX_NAMESPACE::TensorProto dense; + Status status = utils::SparseTensorProtoToDenseTensorProto(sparse, std::filesystem::path{}, dense); + ASSERT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), ::testing::HasSubstr("in-memory address marker")); +} + #endif // !defined(DISABLE_SPARSE_TENSORS) +// Defense-in-depth: GetExtDataFromTensorProto must reject absolute external paths even when +// called with an empty model_path (e.g. from training checkpoint or custom-op init paths). +// Previously, ValidateExternalDataPath was only invoked from Graph::ConvertInitializersIntoOrtValues, +// so direct callers of GetExtDataFromTensorProto could load arbitrary files. +TEST(GetExtDataFromTensorProtoTest, RejectsAbsoluteExternalPathWithEmptyModelPath) { + ONNX_NAMESPACE::TensorProto tensor_proto; + tensor_proto.set_name("abs_external"); + tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + tensor_proto.add_dims(2); + tensor_proto.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); + + auto* loc = tensor_proto.add_external_data(); + loc->set_key("location"); +#ifdef _WIN32 + loc->set_value("C:\\data.bin"); +#else + loc->set_value("/etc/passwd"); +#endif + + auto* off = tensor_proto.add_external_data(); + off->set_key("offset"); + off->set_value("0"); + + auto* len = tensor_proto.add_external_data(); + len->set_key("length"); + len->set_value(std::to_string(2 * sizeof(float))); + + OrtValue value; + Status status = utils::GetExtDataFromTensorProto(Env::Default(), {}, tensor_proto, value); + ASSERT_FALSE(status.IsOK()) << "Absolute external path must be rejected even with empty model_path."; + EXPECT_THAT(status.ErrorMessage(), ::testing::HasSubstr("Absolute path not allowed")); +} + +// Defense-in-depth: GetExtDataFromTensorProto must reject directory-escaping external paths even +// when the caller passes a non-empty model_path. This guards callers outside Graph::Resolve. +TEST(GetExtDataFromTensorProtoTest, RejectsEscapingExternalPath) { + ONNX_NAMESPACE::TensorProto tensor_proto; + tensor_proto.set_name("escape_external"); + tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + tensor_proto.add_dims(2); + tensor_proto.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); + + auto* loc = tensor_proto.add_external_data(); + loc->set_key("location"); + loc->set_value("../escape.bin"); + + auto* off = tensor_proto.add_external_data(); + off->set_key("offset"); + off->set_value("0"); + + auto* len = tensor_proto.add_external_data(); + len->set_key("length"); + len->set_value(std::to_string(2 * sizeof(float))); + + OrtValue value; + // Pass a synthetic model_path so the validator has a model directory to compare against. + std::filesystem::path model_path = std::filesystem::temp_directory_path() / "sub" / "model.onnx"; + Status status = utils::GetExtDataFromTensorProto(Env::Default(), model_path, tensor_proto, value); + ASSERT_FALSE(status.IsOK()) << "Directory-escaping external path must be rejected."; + EXPECT_THAT(status.ErrorMessage(), ::testing::HasSubstr("escapes")); +} + TEST(TensorProtoUtilsTest, GetNodeProtoLayeringAnnotation) { // Case 1: Annotation exists { diff --git a/onnxruntime/test/ir/graph_test.cc b/onnxruntime/test/ir/graph_test.cc index 1256a39bcd0c7..019f15a46abc5 100644 --- a/onnxruntime/test/ir/graph_test.cc +++ b/onnxruntime/test/ir/graph_test.cc @@ -1362,8 +1362,126 @@ TEST_F(GraphTest, UnusedSparseInitializerIsIgnored) { auto& graph_proto = graph2.ToGraphProto(); ASSERT_TRUE(graph_proto.sparse_initializer().empty()); } + +// Regression test for issue #28617: a SparseTensorProto loaded from a model protobuf must not +// be allowed to carry an ORT in-memory address marker on its values or indices sub-tensors. +// Those markers are an ORT-internal mechanism for trusted in-memory buffers (.ort flatbuffer +// load). Accepting them on a crafted .onnx protobuf would let the model make ORT dereference +// an attacker-supplied pointer during sparse-to-dense conversion. +static void RunRejectInMemoryMarkerOnSparseInitializerTest(bool marker_on_indices, + const onnxruntime::logging::Logger& logger) { + Model model("RejectInMemoryMarkerOnSparseInitializer", false, logger); + auto model_proto = model.ToProto(); + auto* m_graph = model_proto.mutable_graph(); + ConstructASimpleAddGraph(*m_graph, nullptr); + + auto* m_sparse_initializer = m_graph->add_sparse_initializer(); + ConstructSparseTensor("in_memory_marker_sparse", *m_sparse_initializer); + + // Overwrite either values or indices to declare external data pointing at an in-memory marker. + // Allocate a real backing buffer so even an accidental dereference of "offset" stays in-process. + static std::vector backing(64, 0); + auto* sub = marker_on_indices ? m_sparse_initializer->mutable_indices() + : m_sparse_initializer->mutable_values(); + sub->clear_raw_data(); + sub->clear_int64_data(); + sub->clear_float_data(); + sub->set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); + auto* loc = sub->add_external_data(); + loc->set_key("location"); + loc->set_value(ToUTF8String(onnxruntime::utils::kTensorProtoLittleEndianMemoryAddressTag)); + auto* off = sub->add_external_data(); + off->set_key("offset"); + off->set_value(std::to_string(reinterpret_cast(backing.data()))); + auto* len = sub->add_external_data(); + len->set_key("length"); + len->set_value(std::to_string(backing.size())); + + std::string s1; + model_proto.SerializeToString(&s1); + + ModelProto model_proto_1; + ASSERT_TRUE(model_proto_1.ParseFromString(s1)); + + std::shared_ptr p_tmp_model; + // The Graph ctor must reject the marker — Model::Load is expected to return a non-OK status + // (Graph ctor's ORT_THROW is caught at the C++/Status boundary). + ORT_TRY { + auto status = onnxruntime::Model::Load(model_proto_1, p_tmp_model, nullptr, logger); + EXPECT_FALSE(status.IsOK()) << "Loading a model with an in-memory marker on a sparse " + << (marker_on_indices ? "indices" : "values") + << " sub-tensor must fail."; + if (!status.IsOK()) { + EXPECT_THAT(status.ErrorMessage(), + ::testing::HasSubstr("in-memory address marker")); + } + } + ORT_CATCH(const std::exception& ex) { + ORT_HANDLE_EXCEPTION([&]() { + EXPECT_THAT(std::string(ex.what()), + ::testing::HasSubstr("in-memory address marker")); + }); + } +} + +TEST_F(GraphTest, RejectInMemoryMarkerOnSparseInitializerValues) { + RunRejectInMemoryMarkerOnSparseInitializerTest(/*marker_on_indices=*/false, *logger_); +} + +TEST_F(GraphTest, RejectInMemoryMarkerOnSparseInitializerIndices) { + RunRejectInMemoryMarkerOnSparseInitializerTest(/*marker_on_indices=*/true, *logger_); +} #endif // !defined(DISABLE_SPARSE_TENSORS) +// Regression test: ORT in-memory address markers are an in-process sentinel only; they must never +// appear in a dense initializer deserialized from an .onnx protobuf. The Graph ctor must reject +// such a model. +TEST_F(GraphTest, RejectInMemoryMarkerOnDenseInitializer) { + Model model("RejectInMemoryMarkerOnDenseInitializer", false, *logger_); + auto model_proto = model.ToProto(); + auto* m_graph = model_proto.mutable_graph(); + ConstructASimpleAddGraph(*m_graph, nullptr); + + static std::vector backing(64, 0); + + auto* init = m_graph->add_initializer(); + init->set_name("in_memory_marker_dense"); + init->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + init->add_dims(static_cast(backing.size() / sizeof(float))); + init->set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); + auto* loc = init->add_external_data(); + loc->set_key("location"); + loc->set_value(ToUTF8String(onnxruntime::utils::kTensorProtoLittleEndianMemoryAddressTag)); + auto* off = init->add_external_data(); + off->set_key("offset"); + off->set_value(std::to_string(reinterpret_cast(backing.data()))); + auto* len = init->add_external_data(); + len->set_key("length"); + len->set_value(std::to_string(backing.size())); + + std::string s1; + model_proto.SerializeToString(&s1); + + ModelProto model_proto_1; + ASSERT_TRUE(model_proto_1.ParseFromString(s1)); + + std::shared_ptr p_tmp_model; + ORT_TRY { + auto status = onnxruntime::Model::Load(model_proto_1, p_tmp_model, nullptr, *logger_); + EXPECT_FALSE(status.IsOK()) << "Loading a model with an in-memory marker on a dense initializer must fail."; + if (!status.IsOK()) { + EXPECT_THAT(status.ErrorMessage(), + ::testing::HasSubstr("in-memory address marker")); + } + } + ORT_CATCH(const std::exception& ex) { + ORT_HANDLE_EXCEPTION([&]() { + EXPECT_THAT(std::string(ex.what()), + ::testing::HasSubstr("in-memory address marker")); + }); + } +} + TEST_F(GraphTest, GraphConstruction_CheckIsNotAcyclic) { // A cyclic graph // SouceNode diff --git a/onnxruntime/test/optimizer/initializer_test.cc b/onnxruntime/test/optimizer/initializer_test.cc index 391942acfca35..6f340e9a9b734 100644 --- a/onnxruntime/test/optimizer/initializer_test.cc +++ b/onnxruntime/test/optimizer/initializer_test.cc @@ -96,7 +96,10 @@ TEST(OptimizerInitializerTest, LoadExternalData) { // bad model paths EXPECT_THROW(Initializer i(tensor_proto_base, std::filesystem::path()), OnnxRuntimeException); - EXPECT_THROW(Initializer i(tensor_proto_base, ORT_TSTR("invalid/directory")), std::filesystem::filesystem_error); + // ValidateExternalDataPath in GetExtDataFromTensorProto now rejects this earlier with an + // ORT error ("External data path does not exist") instead of letting a downstream + // std::filesystem call throw filesystem_error. + EXPECT_THROW(Initializer i(tensor_proto_base, ORT_TSTR("invalid/directory")), OnnxRuntimeException); // bad length { From 7e82d19d0c8a4085c5dd137ad157ad0f02e04bf8 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 28 May 2026 18:18:59 -0700 Subject: [PATCH 06/10] Flash Attention style tiled computation for CPU GQA quantized KV cache (#28695) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Description Add a flash attention-style tiled computation path to the CPU GroupQueryAttention operator for quantized KV cache (INT8/INT4). Instead of materializing the full `[B, N, S, T]` attention probability matrix, this processes K/V in L2-cache-sized blocks with online softmax — reducing peak memory from O(S×T) to O(S×Bc) per head where Bc is the KV block size. Additionally, implements **flash decoding** for the decode phase (S=1): when `batch×heads < threads`, idle threads are repurposed to partition the KV sequence across parallel workers. Each worker computes partial softmax statistics on its KV chunk, then a lightweight reduce phase merges the partials — achieving 2–5x decode speedup for long sequences. ### Motivation For long-sequence LLM inference with quantized KV cache on CPU: - **Prefill**: The full attention matrix allocation becomes a significant memory bottleneck. With 16 heads and S=4096, the naive path allocates ~1 GB for attention scores alone. The tiled approach reduces peak memory by 13–24x and latency by 1.2–2.7x. - **Decode**: When batch size is small relative to available threads, many threads sit idle. Flash decoding partitions the KV sequence across these idle threads, achieving 2–5x speedup for long KV lengths. ## Key Changes | File | Change | |------|--------| | `onnxruntime/core/mlas/lib/flashattn_qkv.cpp` | MLAS kernel: tiled prefill with online softmax, flash decoding (two-phase KV partitioning), and reduce | | `onnxruntime/core/mlas/inc/mlas_qkv_quant.h` | `MlasFlashAttentionQuantizedKVArgs` struct with `flash_decoding_partials` and `kv_chunk_count` fields | | `onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h` | `ApplyAttentionQuantizedFlash()` with L2-cache-aware block sizing, KV concat, flash decoding setup | | `onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc` | Dispatch logic: activates flash path when no softcap/smooth softmax/output_qk | | `cmake/onnxruntime_mlas.cmake` | Added `flashattn_qkv.cpp` to the MLAS build | | `docs/contrib_ops/cpu/gqa.md` | Documentation with algorithm details, benchmark results, and reproduction steps | | `onnxruntime/test/mlas/bench/bench_qkv_quant.cpp` | MLAS-level C++ benchmark (`BM_GQA_Naive` vs `BM_GQA_Flash`) | | `onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py` | Operator-level Python benchmark | ## Algorithm ### Prefill (S > 1): Tiled Flash Attention Per (batch, head, q_block) tile: 1. **QK GEMM** — `MlasQKGemm` on a block slice of quantized K cache 2. **Causal + local window masking** — Set masked positions to -inf before softmax 3. **Online softmax** — Track running max `m` and sum `l`, rescale accumulated output with `exp(m_old - m_new)` 4. **SV accumulation** — Dequantize V block to FP32, then accumulate weighted V into output ### Decode (S = 1): Flash Decoding When `sequence_length == 1 && batch_size * num_heads < thread_count && kv_chunk_count > 1`: **Phase 1 — Parallel KV scan**: Each idle thread processes a disjoint KV chunk for a (batch, head) pair. For each chunk: compute QK dot products, find local max, compute local softmax sum, and accumulate partial weighted V output. Store per-chunk `(max_score, sum_exp, partial_output[head_size])` into a partials buffer. **Phase 2 — Reduce**: One thread per (batch, head) merges all chunk partials using the log-sum-exp trick: find global max, rescale each chunk's sum and partial output, then normalize by global sum. This is analogous to GPU flash decoding (Dao et al.) but adapted for CPU threading. ### Activation Conditions Flash path activates when ALL of: - `ORT_GQA_DISABLE_FLASH_ATTENTION` env var is not set - `total_sequence_length > 1` - No softcap, no smooth softmax, no output_qk (attention bias IS supported) Flash decoding additionally requires: - `sequence_length == 1` (decode phase) - `batch_size * num_heads < thread_count` (idle threads available) - `kv_chunk_count > 1` (enough KV to partition) ## Benchmark Results Measured on Intel Xeon Platinum 8480C, 96 CPUs, threads=8. MLAS-level C++ benchmark. ### Latency — Prefill (S = T) Shape: B=1, num_heads=16, kv_num_heads=8, head_size=128. | Seq Length | Naive (ms) | Flash (ms) | Speedup | Quant | |---:|---:|---:|---:|:---| | 512 | 9.9 | 8.1 | 1.2x | per-tensor | | 1024 | 44.4 | 27.0 | 1.6x | per-tensor | | 2048 | 190.9 | 116.9 | 1.6x | per-tensor | | 4096 | 1257.8 | 461.6 | 2.7x | per-tensor | | 512 | 10.7 | 10.8 | 1.0x | per-channel | | 1024 | 49.5 | 41.7 | 1.2x | per-channel | | 2048 | 212.1 | 164.1 | 1.3x | per-channel | | 4096 | 1223.9 | 607.8 | 2.0x | per-channel | ### Latency — Decode (S = 1, no flash decoding) Shape: B=1, num_heads=16, kv_num_heads=8, head_size=128. Flash decoding NOT active (batch×heads=16 > threads=8). | Total Seqlen | Naive (us) | Flash (us) | Speedup | Quant | |---:|---:|---:|---:|:---| | 512 | 32 | 22 | 1.4x | per-tensor | | 1024 | 71 | 47 | 1.5x | per-tensor | | 2048 | 120 | 87 | 1.4x | per-tensor | | 4096 | 210 | 174 | 1.2x | per-tensor | | 512 | 53 | 31 | 1.7x | per-channel | | 1024 | 86 | 52 | 1.7x | per-channel | | 2048 | 172 | 97 | 1.8x | per-channel | | 4096 | 299 | 191 | 1.6x | per-channel | ### Latency — Flash Decoding (S = 1, KV partitioned across threads) Shape: B=1, num_heads=4, kv_num_heads=4 (MHA), head_size=128. Flash decoding IS active (batch×heads=4 < threads=8). | Total Seqlen | Naive (us) | Flash (us) | Speedup | Quant | |---:|---:|---:|---:|:---| | 512 | 31 | 25 | 1.2x | per-tensor | | 1024 | 41 | 25 | 1.6x | per-tensor | | 2048 | 67 | 34 | 2.0x | per-tensor | | 4096 | 197 | 54 | 3.7x | per-tensor | | 512 | 25 | 28 | 0.9x | per-channel | | 1024 | 72 | 27 | 2.7x | per-channel | | 2048 | 144 | 37 | 3.9x | per-channel | | 4096 | 304 | 60 | 5.1x | per-channel | ### Peak Memory — Prefill | Seq Length | Naive Peak | Flash Peak | Memory Reduction | |---:|---:|---:|---:| | 2048 (N=16) | +294 MB | +44 MB | 6.7x | | 4096 (N=16) | +1107 MB | +82 MB | 13.5x | | 4096 (N=32) | +2131 MB | +87 MB | 24.5x | **Summary**: Prefill gains 1.2–2.7x latency + 7–24x memory reduction from tiled online softmax. Decode gains 1.2–1.8x from fused dequant+dot alone. Flash decoding adds 2–5x for long sequences when idle threads are available to partition the KV scan. ### How to Reproduce ```bash # Build ORT python tools/ci_build/build.py --build_dir build/cpu --config Release \ --parallel --build_wheel --skip_tests # MLAS-level C++ benchmark: cd build/cpu/Release ./onnxruntime_mlas_benchmark \ --benchmark_filter='BM_GQA_(Naive|Flash)' \ --benchmark_min_time=0.5s \ --benchmark_repetitions=3 \ --benchmark_report_aggregates_only=true ``` ## Testing - All 35 CPU `GroupQueryAttentionTest.*` tests pass (INT8/INT4, per-tensor/per-channel, multi-batch, large head, GQA ratio variants) - Set `ORT_GQA_DISABLE_FLASH_ATTENTION=1` to verify fallback path still works - End-to-end verified with `quantized_kv_cache_cpu_demo.py` - Numerical agreement between flash and naive paths: max diff < 1e-7 --- cmake/onnxruntime_mlas.cmake | 1 + docs/contrib_ops/cpu/gqa.md | 214 +++++- .../contrib_ops/cpu/bert/gqa_attention_base.h | 361 +++++++++- .../cpu/bert/group_query_attention.cc | 25 +- onnxruntime/core/mlas/inc/mlas_qkv_quant.h | 75 ++- onnxruntime/core/mlas/lib/flashattn_qkv.cpp | 622 ++++++++++++++++++ onnxruntime/core/mlas/lib/qkv_quant.cpp | 23 +- onnxruntime/core/mlas/lib/qkv_quant_kernel.h | 5 +- .../core/mlas/lib/qkv_quant_kernel_avx2.cpp | 104 ++- .../mlas/lib/qkv_quant_kernel_avx512vnni.cpp | 104 ++- .../core/mlas/lib/qkv_quant_kernel_neon.cpp | 44 +- .../test/mlas/bench/bench_qkv_quant.cpp | 294 ++++++++- .../test/mlas/unittest/test_qkv_quant.cpp | 341 +++++++++- .../transformers/benchmark_gqa_cpu_flash.py | 300 +++++++++ .../transformers/test_gqa_cpu_quantized.py | 356 +++++++++- 15 files changed, 2762 insertions(+), 107 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/flashattn_qkv.cpp create mode 100644 onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 8c7df780735f1..bc64d394b6062 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -55,6 +55,7 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/qlutgemm.cpp ${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h ${MLAS_SRC_DIR}/flashattn.cpp + ${MLAS_SRC_DIR}/flashattn_qkv.cpp ${MLAS_SRC_DIR}/qkv_quant.cpp ${MLAS_SRC_DIR}/cast.cpp ${MLAS_SRC_DIR}/layernorm.cpp diff --git a/docs/contrib_ops/cpu/gqa.md b/docs/contrib_ops/cpu/gqa.md index 0a144132b5c86..e5a211c9fd11a 100644 --- a/docs/contrib_ops/cpu/gqa.md +++ b/docs/contrib_ops/cpu/gqa.md @@ -17,6 +17,7 @@ Quantized KV-cache GEMM helpers are implemented in MLAS: - `onnxruntime/core/mlas/lib/qkv_quant_kernel_avx2.cpp` - `onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp` - `onnxruntime/core/mlas/lib/qkv_quant_kernel_neon.cpp` +- `onnxruntime/core/mlas/lib/flashattn_qkv.cpp` (flash attention tiled kernel) The operator schema itself is defined in: @@ -47,6 +48,13 @@ At a high level, the CPU kernel executes GroupQueryAttention in these stages: The non-quantized and quantized paths share the surrounding validation, masking, softmax, and output flow. Their main difference is how the K/V cache is stored and read during QK and SV GEMMs. +The quantized path has two execution strategies: + +- **Naive (full materialization)**: Computes the full `[S, T]` attention score matrix, applies masking and softmax, then computes the SV product. Simple but memory-intensive for long sequences. +- **Flash Attention (tiled, online softmax)**: Processes K/V in L2-cache-sized blocks using the online softmax algorithm (Milakov & Gimelshein, 2018). Avoids materializing the full attention matrix, reducing peak memory from O(S×T) to O(S×Bc) per head. Multi-threaded via the MLAS thread pool. + +The flash path is selected by default when conditions are met (see below). Set `ORT_GQA_DISABLE_FLASH_ATTENTION=1` to force the naive path. + ## Supported Cache Modes ### Non-quantized cache @@ -85,7 +93,11 @@ For INT4, two signed 4-bit values are stored in each byte. The packed head dimen During quantized execution, new key/value vectors are quantized on write into the present cache. Existing past-cache data and newly written present-cache data are then consumed by MLAS quantized GEMM helpers. -## QK GEMM +## Naive Path: QK GEMM + Softmax + SV GEMM + +The naive (full materialization) path executes attention as three separate stages: + +### QK GEMM The QK stage computes: @@ -102,7 +114,7 @@ For quantized K cache, the CPU path calls `MlasQKGemm` with: The default MLAS contract is exact with respect to the FP32 query operand: only the K cache is dequantized on the fly. The query row is not quantized by default. -## Softmax and Masking +### Softmax and Masking After QK GEMM, the CPU path applies the same attention-score processing used by the non-quantized path, including supported combinations of: @@ -115,7 +127,7 @@ After QK GEMM, the CPU path applies the same attention-score processing used by The quantized cache mode does not change these score-processing semantics. -## SV GEMM +### SV GEMM The SV stage computes: @@ -132,6 +144,66 @@ For quantized V cache, the CPU path calls `MlasSVGemm` with: As with QK GEMM, the default MLAS contract preserves the FP32 left-hand operand and dequantizes only the cached V values on the fly. +## Flash Attention Path + +The flash attention path (`MlasFlashAttentionQuantizedKV`) processes K/V in blocks with online softmax, fusing QK, masking, softmax, and SV into a single tiled loop. This avoids the O(S×T) memory allocation for the full attention matrix. + +### Algorithm + +For each (batch, head, q_block) tile: + +1. **QK GEMM** — `MlasQKGemm` on a block slice of quantized K cache (Bc rows at a time) +1b. **Attention bias** — Add the corresponding tile of the bias tensor (if present) to QK scores +2. **Causal + local window masking** — Set masked positions to −∞ before softmax +3. **Online softmax** — Track running max `m` and sum `l`, rescale accumulated output with `exp(m_old − m_new)` +4. **Fused SV accumulate** — `MlasSVGemm(..., Beta=1.0)` dequantizes V on the fly and accumulates `softmax(QK_block) × V_block` into the output in a single pass (no intermediate FP32 buffer) +5. **Finalize** — Normalize accumulated output by `1/l` after all KV blocks are processed + +### Activation Conditions + +The flash path is selected when ALL of the following hold: + +- `ORT_GQA_DISABLE_FLASH_ATTENTION` environment variable is not set (or set to `0`) +- `total_sequence_length > 1` +- No softcap +- No smooth softmax +- No head sink +- No output QK capture + +Attention bias is fully supported in the flash path (applied per-tile after QK GEMM). The bias tensor shape `[B|1, N|1, S, T]` supports broadcast along both batch and head dimensions. + +When any condition is not met, the kernel falls back to the naive full-materialization path. + +### Block Size Selection + +Block sizes are chosen based on L2 cache size: + +- `kv_block_size (Bc)`: Sized so that a full KV block's scores + dequantized V fit within L2. Typical values: 128–256. +- `q_block_size (Br)`: Sized for the query tile. Typical value: 64. + +### Threading + +The flash kernel parallelizes across `(batch, head, q_block)` tiles using the ORT intra-op thread pool. Each thread gets a private working buffer containing space for: + +- `l[Br]` and `m[Br]` — running softmax statistics +- `scores[Br × Bc]` — QK scores for current KV block +- `temp_output[Br × H]` — accumulated output + +The V dequantization temp buffer was eliminated by fusing dequantization into `MlasSVGemm` with `Beta=1.0` (accumulate mode). This reduces per-thread buffer size by `Bc × H × 4` bytes (e.g., 64 KB for Bc=128, H=128). + +### Flash Decoding (Decode Optimization) + +For decode steps (`sequence_length == 1`), the standard `(batch, head, q_block)` partitioning yields only `batch × num_heads` tasks, which can underutilize thread pools on machines with many cores (e.g., 96 threads with batch=1, num_heads=32 produces only 32 tasks). + +When `batch × num_heads < thread_count` and `kv_chunk_count > 1`, the kernel switches to a **flash decoding** strategy that also partitions along the KV sequence dimension: + +- **Phase 1** (parallel over `batch × num_heads × kv_chunk_count` tasks): Each thread computes partial attention for one KV chunk, producing per-chunk `(m, l, S_exp × V)` stored in a partials buffer. +- **Phase 2** (parallel over `batch × num_heads` tasks): Merge partials using log-sum-exp rescaling: `output = Σ_c(exp(m_c − m_global) × partial_c) / Σ_c(exp(m_c − m_global) × l_c)`. + +The partials buffer is allocated alongside the per-thread scratch in a single allocation: +- Per-thread scratch: `scores[Bc]` (one float per KV block element) +- Partials: `batch × num_heads × kv_chunks × (2 + H)` floats (m, l, and partial output per chunk) + ## MLAS Dispatch Paths MLAS selects the best available quantized KV-cache GEMM implementation through the platform dispatch table. @@ -168,7 +240,7 @@ CPU GroupQueryAttention coverage is split across operator-level and MLAS-level t - `onnxruntime/test/mlas/unittest/test_qkv_quant.cpp` - MLAS `MlasKVQuantize`, `MlasKVDequantize`, `MlasQKGemm`, and `MlasSVGemm` contract tests. -The MLAS benchmark for quantized KV-cache GEMM is: +The MLAS benchmark for quantized KV-cache GEMM and flash attention is: - `onnxruntime/test/mlas/bench/bench_qkv_quant.cpp` @@ -223,6 +295,23 @@ ORT_MLAS_QKGEMM_S8_APPROX_VNNI=1 ./onnxruntime_mlas_benchmark \ --benchmark_report_aggregates_only=true ``` +Run flash vs naive full-attention benchmark: + +```bash +cd build/cpu_test/Release +./onnxruntime_mlas_benchmark \ + --benchmark_filter='BM_GQA_(Naive|Flash)' \ + --benchmark_min_time=0.5s \ + --benchmark_repetitions=3 \ + --benchmark_report_aggregates_only=true +``` + +To force the naive path at the operator level (for A/B testing during inference): + +```bash +ORT_GQA_DISABLE_FLASH_ATTENTION=1 ./your_inference_app +``` + ### Updated benchmark results The following results were measured on an Intel Xeon Platinum 8480C, 96 CPUs, using the CPU Release benchmark binary. Shape: `M=1`, `N=512`, `K=128`, INT8 per-tensor QKGemm. @@ -236,6 +325,110 @@ The following results were measured on an Intel Xeon Platinum 8480C, 96 CPUs, us For comparison, the earlier PR description reported the approximate AVX512 VNNI path at 1,938 ns for this shape, with scalar at 30,179 ns and AVX2 at 4,219 ns. The default AVX512 path is now the exact FP32 fused-dequant implementation, so it is slower than approximate VNNI but preserves the `MlasQKGemm` FP32-query contract. +### Flash Attention vs Naive benchmark results + +Measured on Intel Xeon Platinum 8480C, 96 CPUs. INT8 quantized KV cache, threads=8. + +Two benchmark levels are reported: +- **Operator-level** (`benchmark_gqa_cpu_flash.py`): Measures the full GQA operator via `InferenceSession`, including KV cache concatenation, quantization of new K/V, and Python/C++ boundary overhead. +- **MLAS kernel-level** (`bench_qkv_quant.cpp`): Measures only the attention kernel (QK+softmax+SV), isolating the algorithmic gain from operator overhead. + +```bash +# Operator-level Python benchmark: +cd /tmp +PYTHONPATH=build/cpu/Release python \ + onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py --warmup 5 --repeats 20 + +# MLAS kernel-level C++ benchmark: +cd build/cpu/Release +./onnxruntime_mlas_benchmark \ + --benchmark_filter='BM_GQA_(Naive|Flash)' \ + --benchmark_min_time=0.5s \ + --benchmark_repetitions=3 \ + --benchmark_report_aggregates_only=true +``` + +#### Latency — Prefill (S = T, prompt phase) + +Shape: B=1, num_heads=16, kv_num_heads=8, head_size=128, INT8 per-tensor. + +| Seq Length | Naive (ms) | Flash (ms) | Speedup | Source | +|---:|---:|---:|---:|:---| +| 512 | 7.7 | 8.9 | 0.9x | operator | +| 1024 | 36.8 | 30.2 | 1.2x | operator | +| 2048 | 157.9 | 110.2 | 1.4x | operator | +| 4096 | 790.6 | 427.1 | 1.9x | operator | +| 512 | 9.9 | 8.1 | 1.2x | MLAS kernel | +| 1024 | 44.4 | 27.0 | 1.6x | MLAS kernel | +| 2048 | 190.9 | 116.9 | 1.6x | MLAS kernel | +| 4096 | 1257.8 | 461.6 | 2.7x | MLAS kernel | + +The operator-level naive path is faster than the MLAS-level naive at small S because the naive path's QK GEMM batches all heads in one call, amortizing thread dispatch. At larger S, the flash kernel's O(S×Bc) tiling wins decisively. + +MLAS kernel-level per-channel results: + +| Seq Length | Naive (ms) | Flash (ms) | Speedup | Source | +|---:|---:|---:|---:|:---| +| 512 | 10.7 | 10.8 | 1.0x | MLAS kernel | +| 1024 | 49.5 | 41.7 | 1.2x | MLAS kernel | +| 2048 | 212.1 | 164.1 | 1.3x | MLAS kernel | +| 4096 | 1223.9 | 607.8 | 2.0x | MLAS kernel | + +#### Latency — Decode (S = 1, token generation) + +Shape: B=1, num_heads=16, kv_num_heads=8, head_size=128, INT8 per-tensor. +Flash decoding is NOT active for this config (batch×heads=16 > threads=8). + +| Total Seqlen | Naive | Flash | Speedup | Source | +|---:|---:|---:|---:|:---| +| 513 | 0.133 ms | 0.149 ms | 0.9x | operator | +| 1025 | 0.258 ms | 0.224 ms | 1.2x | operator | +| 2049 | 0.453 ms | 0.394 ms | 1.2x | operator | +| 4097 | 0.681 ms | 0.679 ms | 1.0x | operator | +| 512 | 32 us | 22 us | 1.4x | MLAS kernel | +| 1024 | 71 us | 47 us | 1.5x | MLAS kernel | +| 2048 | 120 us | 87 us | 1.4x | MLAS kernel | +| 4096 | 210 us | 174 us | 1.2x | MLAS kernel | + +At the MLAS kernel level, the flash path is consistently 1.2–1.5x faster for decode due to fused single-pass KV access (better cache locality). At the operator level, the gain is partially masked by KV cache concatenation overhead (~100us), which dominates at short sequences but becomes less significant at longer ones. + +MLAS kernel-level per-channel decode results: + +| Total Seqlen | Naive (us) | Flash (us) | Speedup | Source | +|---:|---:|---:|---:|:---| +| 512 | 53 | 31 | 1.7x | MLAS kernel | +| 1024 | 86 | 52 | 1.7x | MLAS kernel | +| 2048 | 172 | 97 | 1.8x | MLAS kernel | +| 4096 | 299 | 191 | 1.6x | MLAS kernel | + +#### Latency — Flash Decoding (S = 1, KV partitioned across threads) + +Shape: B=1, num_heads=4, kv_num_heads=4 (MHA), head_size=128, threads=8. +Flash decoding IS active (batch×heads=4 < threads=8, KV partitioned across idle threads). + +| Total Seqlen | Naive (us) | Flash (us) | Speedup | Quant | +|---:|---:|---:|---:|:---| +| 512 | 31 | 25 | 1.2x | per-tensor | +| 1024 | 41 | 25 | 1.6x | per-tensor | +| 2048 | 67 | 34 | 2.0x | per-tensor | +| 4096 | 197 | 54 | 3.7x | per-tensor | +| 512 | 25 | 28 | 0.9x | per-channel | +| 1024 | 72 | 27 | 2.7x | per-channel | +| 2048 | 144 | 37 | 3.9x | per-channel | +| 4096 | 304 | 60 | 5.1x | per-channel | + +(Source: MLAS kernel-level benchmark) + +#### Peak Memory — Prefill (S = T, prompt phase) + +| Seq Length | Naive Peak (MB) | Flash Peak (MB) | Memory Reduction | +|---:|---:|---:|---:| +| 2048 | +294 | +44 | 6.7x | +| 4096 | +1107 | +82 | 13.5x | +| 4096 (N=32) | +2131 | +87 | 24.5x | + +**Summary**: The flash path's primary benefit for prefill is **memory reduction** — avoiding the full O(N×S×T) attention matrix. For S=4096 with 16 heads, the naive path allocates ~1 GB for attention scores while the flash path uses ~80 MB regardless of sequence length. The prefill latency speedup (1.2–2.7x at kernel level, 1.2–1.9x at operator level) comes from improved cache locality. For decode, the tiled kernel provides 1.2–1.8x kernel-level speedup from fused single-pass KV access; at operator level the gain is visible for T≥1024 but masked by KV concat overhead at shorter sequences. When flash decoding is active (batch×heads < threads), KV partitioning across idle threads yields an additional 2–5x speedup for long sequences. + ## Current CPU Limitations The current CPU GroupQueryAttention implementation has a few important limitations: @@ -246,7 +439,8 @@ The current CPU GroupQueryAttention implementation has a few important limitatio - INT4 cache storage uses packed `uint8` bytes and requires consumers to use the packed head dimension. - The default AVX512 quantized KV-cache GEMM path preserves FP32 query and attention-probability operands; the approximate VNNI QK path is opt-in only. - Hardware dispatch affects performance, but should not change default numeric semantics. -- The MLAS quantized GEMM helpers operate on one per-batch/per-head tile at a time; outer parallelism is managed by the GQA kernel. +- The flash attention path does not support softcap, smooth softmax, head sink, or QK output capture. These features fall back to the naive path. +- The MLAS quantized GEMM helpers operate on one per-batch/per-head tile at a time; outer parallelism is managed by the GQA kernel (or by the flash attention kernel internally). ## Future Work @@ -254,7 +448,6 @@ Further optimization opportunities include: - Improve the exact AVX512 INT8 per-tensor QK path without quantizing the FP32 query, for example by processing multiple K-cache rows per query row while keeping FP32 FMA semantics. - Add AVX512-specific exact micro-kernels for common decode shapes such as `M=1`, `N=512/2048`, and `K=64/128`. -- Add dispatch-specific benchmark coverage for prefill shapes (`M > 1`) and longer cache lengths. - Add dedicated accuracy/performance tests for the approximate VNNI opt-in path before enabling it in any production configuration. - Reduce temporary copies in quantized cache concatenation when past and present buffers cannot be shared directly. - Explore prepacking or layout transforms for long-lived quantized KV caches when the cache update pattern makes that worthwhile. @@ -279,7 +472,14 @@ CPU features that are limited or not implemented relative to the broader operato - quantizes new K/V values into the present cache - concatenates past and present cache chunks when needed - calls `MlasQKGemm` and `MlasSVGemm` +- `GroupQueryAttentionBase::ApplyAttentionQuantizedFlash(...)` + - concatenates new K/V into present cache (parallel over batch × kv_heads) + - invokes `MlasFlashAttentionQuantizedKV` with L2-cache-aware block sizes - `MlasQKGemm(...)` - computes FP32 query times quantized K cache transpose - `MlasSVGemm(...)` - - computes FP32 attention probabilities times quantized V cache + - computes `C = Beta*C + A*dequant(B)` where A is FP32 attention probabilities and B is quantized V cache + - `Beta=0` (overwrite) for naive path; `Beta=1.0` (accumulate) for flash path +- `MlasFlashAttentionQuantizedKV(...)` + - flash attention kernel with online softmax, tiled QK/SV over quantized KV cache + - parallelizes across (batch, head, q_block) tiles via thread pool diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index 7a30667befffd..12f61cddea18c 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -4,6 +4,7 @@ #pragma once #include +#include #include #include "contrib_ops/cpu/bert/attention_base.h" #include "contrib_ops/cpu/bert/attention_common.h" @@ -13,6 +14,9 @@ #include "core/common/safeint.h" #include "core/framework/op_kernel.h" #include "core/mlas/inc/mlas_qkv_quant.h" +#include "core/platform/env.h" +#include "core/platform/env_var_utils.h" +#include "core/platform/threadpool.h" #include "core/providers/cpu/mlas_backend_kernel_selector_config_utils.h" namespace onnxruntime { @@ -93,6 +97,8 @@ class GQAAttentionBase { kv_cache_bit_width_ = static_cast(info.GetAttrOrDefault("kv_cache_bit_width", 0)); kv_quant_enabled_ = (k_quant_type_ != KVQuantizationType::NONE); + disable_gqa_flash_ = ParseEnvironmentVariableWithDefault("ORT_GQA_DISABLE_FLASH_ATTENTION", false); + SetupMlasBackendKernelSelectorFromConfigOptions(mlas_backend_kernel_selector_config_, info.GetConfigOptions()); } @@ -111,6 +117,7 @@ class GQAAttentionBase { KVQuantizationType v_quant_type_; int kv_cache_bit_width_; bool kv_quant_enabled_; + bool disable_gqa_flash_; template Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH @@ -536,11 +543,363 @@ class GQAAttentionBase { MlasSVGemm(sequence_length, head_size, total_seqlen, attention_probs + probs_offset, seqlen_present_kv_cache, v_quantized, quant_type, head_v_scale, - output_current, hidden_size, nullptr); + output_current, hidden_size, 0.0f, nullptr); + } + }); + } + + return Status::OK(); + } + + // Flash Attention style tiled computation for quantized KV cache. + // Avoids materializing the full [B, N, S, T] attention probability matrix. + // Uses online softmax with KV block tiling for reduced memory usage. + Status ApplyAttentionQuantizedFlash( + const float* Q, // Q data [B, N, S, H] BNSH + const float* K, // K data [B, N_kv, L, H] or nullptr for packed_qkv + const float* V, // V data [B, N_kv, L, H] or nullptr for packed_qkv + const Tensor* attention_bias, // additive bias [B|1, N|1, S, T] or nullptr + const Tensor* past_key, // past K (uint8_t) + const Tensor* past_value, // past V (uint8_t) + Tensor* output, // output [B, S, N*H] float + Tensor* present_key, // present K (uint8_t) + Tensor* present_value, // present V (uint8_t) + const Tensor* seqlens_k, + const float* k_scale, + const float* v_scale, + MLAS_KV_QUANT_TYPE quant_type, + GroupQueryAttentionParameters& parameters, + AllocatorPtr allocator, + OpKernelContext* context) const { + const bool is_prompt = parameters.is_first_prompt; + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int head_size = parameters.head_size; + const int hidden_size = parameters.hidden_size; + const bool packed_qkv = parameters.is_packed_qkv; + + auto* tp = context->GetOperatorThreadPool(); + const size_t packed_row_bytes = MlasKVQuantPackedRowBytes(quant_type, head_size); + + int seqlen_past_kv_cache = 0; + if (past_key != nullptr && past_value != nullptr) { + seqlen_past_kv_cache = static_cast(past_key->Shape().GetDims()[2]); + } + int seqlen_present_kv_cache = present_key != nullptr + ? static_cast(present_key->Shape().GetDims()[2]) + : parameters.total_sequence_length; + + if (kv_sequence_length == 0) { + ORT_ENFORCE(parameters.total_sequence_length <= seqlen_past_kv_cache, + "total_seqlen (", parameters.total_sequence_length, ") exceeds past buffer size (", + seqlen_past_kv_cache, ") in shared KV mode"); + } + + ORT_RETURN_IF(present_key == nullptr || present_value == nullptr, + "present_key and present_value must be provided for quantized KV cache"); + + // Access cache data as raw bytes + const uint8_t* past_key_data = nullptr; + uint8_t* present_key_data = nullptr; + const uint8_t* past_value_data = nullptr; + uint8_t* present_value_data = nullptr; + if (kv_cache_bit_width_ == 4) { + past_key_data = past_key != nullptr ? past_key->Data() : nullptr; + present_key_data = present_key->MutableData(); + past_value_data = past_value != nullptr ? past_value->Data() : nullptr; + present_value_data = present_value->MutableData(); + } else { + past_key_data = past_key != nullptr ? reinterpret_cast(past_key->Data()) : nullptr; + present_key_data = reinterpret_cast(present_key->MutableData()); + past_value_data = past_value != nullptr ? reinterpret_cast(past_value->Data()) : nullptr; + present_value_data = reinterpret_cast(present_value->MutableData()); + } + + bool past_present_share_buffer = (past_key_data == present_key_data) && + (past_value_data == present_value_data); + + const bool per_channel = (quant_type == MLAS_KV_QUANT_TYPE::S8_PerChannel || + quant_type == MLAS_KV_QUANT_TYPE::S4_PerChannel); + + const int32_t* seqlens_k_data = seqlens_k->Data(); + + // Attention bias setup + const float* attention_bias_data = nullptr; + int attention_bias_seqlen_stride = 0; + bool attention_bias_broadcast_batch = true; + bool attention_bias_broadcast_head = true; + if (attention_bias != nullptr) { + attention_bias_data = attention_bias->Data(); + auto bias_shape = attention_bias->Shape().GetDims(); + attention_bias_seqlen_stride = static_cast(bias_shape[3]); + attention_bias_broadcast_batch = (bias_shape[0] == 1); + attention_bias_broadcast_head = (bias_shape[1] == 1); + } + + // K/V base pointers (FP32, new tokens) + const float* k_base = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K; + const float* v_base = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; + + const ptrdiff_t packed_batch_stride = + packed_qkv ? SafeInt(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size + : SafeInt(0); + const size_t kv_input_chunk_length = kv_sequence_length * head_size; + const size_t past_buff_chunk_bytes = SafeInt(seqlen_past_kv_cache) * packed_row_bytes; + const size_t present_buff_chunk_bytes = SafeInt(seqlen_present_kv_cache) * packed_row_bytes; + + // ---- Phase 1: Concat new K/V into present cache ---- + // We must do this first so the flash attention kernel can read the full present cache. + if (present_key_data && !past_present_share_buffer) { + memset(present_key_data, 0, + SafeInt(batch_size) * kv_num_heads_ * present_buff_chunk_bytes); + memset(present_value_data, 0, + SafeInt(batch_size) * kv_num_heads_ * present_buff_chunk_bytes); + } + + // Concat K and V caches (parallelize over batch * kv_num_heads) + { + const size_t concat_loop_len = batch_size * kv_num_heads_; + TensorOpCost concat_cost; + concat_cost.compute_cycles = static_cast(kv_sequence_length * head_size); + concat_cost.bytes_loaded = static_cast(past_buff_chunk_bytes + kv_sequence_length * head_size * sizeof(float)); + concat_cost.bytes_stored = static_cast(present_buff_chunk_bytes); + + ThreadPool::TryParallelFor(tp, concat_loop_len, concat_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t kv_idx = begin; kv_idx != end; ++kv_idx) { + const size_t batch_index = kv_idx / kv_num_heads_; + const size_t kv_head_index = kv_idx % kv_num_heads_; + const size_t total_seqlen = SafeInt(seqlens_k_data[batch_index]) + 1; + + size_t past_seqlen; + if (past_key == nullptr) { + past_seqlen = 0; + } else if (kv_sequence_length == 0) { + past_seqlen = total_seqlen; + } else if (is_prompt) { + past_seqlen = 0; + } else { + past_seqlen = total_seqlen - sequence_length; + } + const size_t past_chunk_bytes = past_seqlen * packed_row_bytes; + + const float* head_k_scale = per_channel + ? k_scale + kv_head_index * head_size + : k_scale; + const float* head_v_scale = per_channel + ? v_scale + kv_head_index * head_size + : v_scale; + + // Concat K + const float* k_new; + if (packed_qkv) { + k_new = k_base + packed_batch_stride * batch_index + + kv_input_chunk_length * kv_head_index; + } else { + k_new = k_base + kv_input_chunk_length * kv_idx; + } + ConcatQuantStateChunkGQA( + past_key_data, k_new, present_key_data, + present_buff_chunk_bytes, past_buff_chunk_bytes, + past_chunk_bytes, kv_sequence_length, head_size, head_size, + quant_type, head_k_scale, past_present_share_buffer, kv_idx); + + // Concat V + const float* v_new; + if (packed_qkv) { + v_new = v_base + packed_batch_stride * batch_index + + kv_input_chunk_length * kv_head_index; + } else { + v_new = v_base + kv_input_chunk_length * kv_idx; + } + ConcatQuantStateChunkGQA( + past_value_data, v_new, present_value_data, + present_buff_chunk_bytes, past_buff_chunk_bytes, + past_chunk_bytes, kv_sequence_length, head_size, head_size, + quant_type, head_v_scale, past_present_share_buffer, kv_idx); } }); } + // ---- Phase 2: Flash Attention with quantized KV cache ---- + // Compute L2-aware block sizes (same formula as MHA flash attention) + const auto& env = Env::Default(); + int l2_cache_size = env.GetL2CacheSize(); + + // For quantized KV: effective bytes per KV element for cache considerations + // We dequantize V blocks to FP32, so working set per KV row = head_size * sizeof(float) + // K is accessed via MlasQKGemm which internally dequantizes; for block sizing purposes + // treat it as FP32 working set. + // + // Working set in L2 per tile: + // Q slice: [Br, head_size] floats + // Scores: [Br, Bc] floats + // V dequant: [Bc, head_size] floats + // Temp output: [Br, head_size] floats + // Total ~ (2*Br + Bc) * head_size + Br * Bc + // Approximation: use same formula as FP32 flash attention + int kv_block_size = l2_cache_size / (static_cast(sizeof(float)) * 4 * (head_size + head_size)); + kv_block_size = std::max(kv_block_size, 1); + int q_block_size = std::min(kv_block_size, 2 * head_size); + + // The flash kernel uses a single (past_seqlen, total_seqlen) pair for all batch items. + // When batch items have different seqlens_k (ragged), we must fall back to per-batch + // invocation so each batch item gets its own correct causal offset. + int max_total_seqlen = 0; + int min_total_seqlen = std::numeric_limits::max(); + int common_past_seqlen = 0; + for (int b = 0; b < batch_size; ++b) { + int total_sl = seqlens_k_data[b] + 1; + max_total_seqlen = std::max(max_total_seqlen, total_sl); + min_total_seqlen = std::min(min_total_seqlen, total_sl); + } + const bool ragged_seqlens = (max_total_seqlen != min_total_seqlen); + + if (ragged_seqlens) { + // Ragged seqlens: each batch item has its own total_seqlen (and therefore + // past_seqlen). Must use per-batch invocation regardless of past_key/prompt state. + common_past_seqlen = -1; // sentinel: per-batch + } else if (past_key == nullptr || is_prompt) { + common_past_seqlen = 0; + } else if (kv_sequence_length == 0) { + // Shared buffer mode: each batch item has its own past_seqlen. + common_past_seqlen = -1; // sentinel: per-batch + } else { + common_past_seqlen = max_total_seqlen - sequence_length; + } + + // Cap block sizes + kv_block_size = std::min(kv_block_size, max_total_seqlen); + q_block_size = std::min(q_block_size, sequence_length); + + // Allocate per-thread buffers for flash attention + int thread_count = concurrency::ThreadPool::DegreeOfParallelism(tp); + thread_count = std::max(thread_count, 1); + + // Flash decoding: for decode (sequence_length==1), partition KV across threads + // to improve parallelism when batch*heads < thread_count. + const int kv_chunk_count = (max_total_seqlen + kv_block_size - 1) / kv_block_size; + const bool use_flash_decoding = (sequence_length == 1 && + batch_size * num_heads_ < thread_count && + kv_chunk_count > 1); + + size_t buffer_size_per_thread; + size_t partials_buffer_bytes = 0; + if (use_flash_decoding) { + // Flash decoding: per-thread scratch only needs scores[kv_block_size] + buffer_size_per_thread = static_cast(kv_block_size) * sizeof(float); + // Partials: [batch * num_heads * kv_chunk_count * (2 + head_size)] floats + partials_buffer_bytes = static_cast(batch_size) * num_heads_ * + kv_chunk_count * (2 + head_size) * sizeof(float); + } else { + buffer_size_per_thread = + (static_cast(q_block_size) * 2 + // l + m + static_cast(q_block_size) * static_cast(kv_block_size) + // scores + static_cast(q_block_size) * static_cast(head_size)) * // temp_output + sizeof(float); + } + size_t total_buffer_bytes = buffer_size_per_thread * thread_count + partials_buffer_bytes; + auto flash_buffer_alloc = allocator->Alloc(total_buffer_bytes); + BufferUniquePtr flash_buffer(flash_buffer_alloc, BufferDeleter(allocator)); + + // Partials buffer is placed after per-thread scratch + float* partials_ptr = use_flash_decoding + ? reinterpret_cast(reinterpret_cast(flash_buffer_alloc) + + buffer_size_per_thread * thread_count) + : nullptr; + + // If all batch items share the same past_seqlen, use the unified flash kernel. + // Otherwise, fall back to per-batch invocation. + if (common_past_seqlen >= 0) { + MlasFlashAttentionQuantizedKVArgs args; + args.batch_size = batch_size; + args.num_heads = num_heads_; + args.kv_num_heads = kv_num_heads_; + args.sequence_length = sequence_length; + args.total_seqlen = max_total_seqlen; + args.head_size = head_size; + args.past_seqlen = common_past_seqlen; + args.local_window_size = local_window_size_; + args.seqlen_present_kv = seqlen_present_kv_cache; + args.q_block_size = q_block_size; + args.kv_block_size = kv_block_size; + args.scale = scale_ == 0.0f ? 1.0f / sqrt(static_cast(head_size)) : scale_; + args.quant_type = quant_type; + args.per_channel_k = per_channel; + args.per_channel_v = per_channel; + args.thread_count = thread_count; + args.buffer = reinterpret_cast(flash_buffer_alloc); + args.buffer_size_per_thread = buffer_size_per_thread; + args.query = Q; + args.k_cache = present_key_data; + args.v_cache = present_value_data; + args.k_scale = k_scale; + args.v_scale = v_scale; + args.output = output->MutableData(); + args.attention_bias = attention_bias_data; + args.attention_bias_seqlen_stride = attention_bias_seqlen_stride; + args.attention_bias_broadcast_batch = attention_bias_broadcast_batch; + args.attention_bias_broadcast_head = attention_bias_broadcast_head; + args.flash_decoding_partials = partials_ptr; + args.kv_chunk_count = kv_chunk_count; + + MlasFlashAttentionQuantizedKV(&args, tp); + } else { + // Per-batch handling for variable past_seqlen (shared KV buffer mode or ragged seqlens) + for (int b = 0; b < batch_size; ++b) { + int total_sl = seqlens_k_data[b] + 1; + // For prompt/no-past cases, past_seqlen is 0; otherwise derive from total_sl. + int batch_past_seqlen = (past_key == nullptr || is_prompt) + ? 0 + : std::max(0, total_sl - sequence_length); + + MlasFlashAttentionQuantizedKVArgs args; + args.batch_size = 1; + args.num_heads = num_heads_; + args.kv_num_heads = kv_num_heads_; + args.sequence_length = sequence_length; + args.total_seqlen = total_sl; + args.head_size = head_size; + args.past_seqlen = batch_past_seqlen; + args.local_window_size = local_window_size_; + args.seqlen_present_kv = seqlen_present_kv_cache; + args.q_block_size = q_block_size; + args.kv_block_size = std::min(kv_block_size, total_sl); + args.scale = scale_ == 0.0f ? 1.0f / sqrt(static_cast(head_size)) : scale_; + args.quant_type = quant_type; + args.per_channel_k = per_channel; + args.per_channel_v = per_channel; + args.thread_count = thread_count; + args.buffer = reinterpret_cast(flash_buffer_alloc); + args.buffer_size_per_thread = buffer_size_per_thread; + + // Offset Q and output for this batch + args.query = Q + static_cast(b) * num_heads_ * sequence_length * head_size; + args.k_cache = present_key_data + + static_cast(b) * kv_num_heads_ * seqlen_present_kv_cache * packed_row_bytes; + args.v_cache = present_value_data + + static_cast(b) * kv_num_heads_ * seqlen_present_kv_cache * packed_row_bytes; + args.k_scale = k_scale; + args.v_scale = v_scale; + args.output = output->MutableData() + + static_cast(b) * sequence_length * hidden_size; + + // Slice attention bias for this batch (the kernel sees batch_size=1, so batch_idx=0 inside) + const float* batch_bias = attention_bias_data; + if (attention_bias_data != nullptr && !attention_bias_broadcast_batch) { + batch_bias += static_cast(b) * num_heads_ * sequence_length * attention_bias_seqlen_stride; + } + args.attention_bias = batch_bias; + args.attention_bias_seqlen_stride = attention_bias_seqlen_stride; + args.attention_bias_broadcast_batch = true; // batch offset handled above + args.attention_bias_broadcast_head = attention_bias_broadcast_head; + args.flash_decoding_partials = nullptr; // per-batch doesn't use flash decoding + args.kv_chunk_count = 0; + + MlasFlashAttentionQuantizedKV(&args, tp); + } + } + return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index 4df5f6a349599..1b9e4c3a6a5cd 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -294,13 +294,34 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { if constexpr (std::is_same_v) { const float* k_data_q = packed_qkv ? nullptr : k_rotary; const float* v_data_q = packed_qkv ? nullptr : V.Get().Data(); + auto mlas_quant_type = ToMlasKVQuantType(k_quant_type_, kv_cache_bit_width_); + + // Use flash attention path when: + // 1. Total sequence length is long enough to benefit from tiling + // 2. No features that flash path doesn't support (softcap, smooth softmax, output_qk) + const bool use_flash = !disable_gqa_flash_ && + parameters.total_sequence_length > 1 && + softcap_ == 0.0f && + !use_smooth_softmax_ && + head_sink_data == nullptr && + output_qk == nullptr; + + if (use_flash) { + return ApplyAttentionQuantizedFlash( + q_rotary, k_data_q, v_data_q, + attention_bias, + past_key, past_value, + output, present_k, present_v, seqlens_k, + k_scale->Data(), v_scale->Data(), + mlas_quant_type, parameters, allocator, context); + } + return ApplyAttentionQuantized( q_rotary, k_data_q, v_data_q, head_sink_data, attention_bias, past_key, past_value, output, present_k, present_v, output_qk, seqlens_k, k_scale->Data(), v_scale->Data(), - ToMlasKVQuantType(k_quant_type_, kv_cache_bit_width_), - parameters, allocator, context); + mlas_quant_type, parameters, allocator, context); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Quantized KV cache requires float Q dtype"); diff --git a/onnxruntime/core/mlas/inc/mlas_qkv_quant.h b/onnxruntime/core/mlas/inc/mlas_qkv_quant.h index ed1a9bfdbfba6..f6a5a48e6ccc7 100644 --- a/onnxruntime/core/mlas/inc/mlas_qkv_quant.h +++ b/onnxruntime/core/mlas/inc/mlas_qkv_quant.h @@ -199,7 +199,7 @@ MlasQKGemm( /** * @brief Softmax-times-V GEMM with a quantized V cache. * - * C[M, N] = A[M, K] * B[K, N] + * C[M, N] = Beta * C[M, N] + A[M, K] * B[K, N] * * where: * - A is FP32 row-major, shape [M, K] (attention probabilities), stride lda. @@ -207,8 +207,8 @@ MlasQKGemm( * with K = total_sequence_length, N = head_size), packed row-major over * rows. Each row occupies * MlasKVQuantPackedRowBytes(QuantType, N) bytes. - * - C is FP32 row-major, shape [M, N], stride ldc (>= N). The kernel - * overwrites C (no accumulate). + * - C is FP32 row-major, shape [M, N], stride ldc (>= N). + * When Beta == 0, C is overwritten. When Beta != 0, C is accumulated. * - PER_CHANNEL scales are length N and apply along the N (head_size) axis. * * @param M Query token count. @@ -221,6 +221,7 @@ MlasQKGemm( * @param Scales Scale buffer (single scalar or length-N vector). * @param C Output buffer (FP32). * @param ldc Leading dimension of C in elements. + * @param Beta Scalar multiplier for existing C values. 0 = overwrite. * @param ThreadPool Optional thread pool. */ void @@ -236,5 +237,73 @@ MlasSVGemm( const float* Scales, float* C, size_t ldc, + float Beta, + MLAS_THREADPOOL* ThreadPool + ); + +/** + * @brief Arguments for the Flash Attention kernel with quantized KV cache. + * + * This kernel implements the online-softmax tiled Flash Attention algorithm + * operating directly on INT8/INT4 quantized K and V cache buffers. + * It avoids materializing the full [S, T] attention probability matrix. + */ +struct MlasFlashAttentionQuantizedKVArgs { + int batch_size; + int num_heads; // Q heads + int kv_num_heads; // KV heads (for GQA sharing) + int sequence_length; // Q sequence length (new tokens) + int total_seqlen; // Total KV sequence length (past + new) + int head_size; + int past_seqlen; // For computing causal positions + int local_window_size; // -1 = disabled + int seqlen_present_kv; // Buffer dimension for present KV (may be > total_seqlen) + int q_block_size; // Br (query block size) + int kv_block_size; // Bc (KV block size) + float scale; // 1/sqrt(head_size) or user-specified + + MLAS_KV_QUANT_TYPE quant_type; + bool per_channel_k; // Whether K uses per-channel scales + bool per_channel_v; // Whether V uses per-channel scales + + int thread_count; + float* buffer; + size_t buffer_size_per_thread; + + const float* query; // [B, N, S, H] FP32 + const uint8_t* k_cache; // [B, kv_N, seqlen_present, packed_row_bytes] quantized + const uint8_t* v_cache; // [B, kv_N, seqlen_present, packed_row_bytes] quantized + const float* k_scale; // Scalar or per-channel scales for K + const float* v_scale; // Scalar or per-channel scales for V + float* output; // [B, S, N, H] FP32 + + // Attention bias (additive, applied after QK GEMM before masking/softmax). + // Shape: [B|1, N|1, S, T] where dimensions of size 1 are broadcast. + const float* attention_bias; // nullptr if no bias + int attention_bias_seqlen_stride; // stride along the T (total_seqlen) dimension = shape[3] + bool attention_bias_broadcast_batch; // true if shape[0] == 1 + bool attention_bias_broadcast_head; // true if shape[1] == 1 + + // Flash decoding fields (used when sequence_length == 1 and KV is split across threads). + // Partials buffer stores per-(batch, head, kv_chunk) intermediate results: + // [m_partial, l_partial, output_partial[head_size]] for each chunk. + float* flash_decoding_partials; // nullptr to disable flash decoding + int kv_chunk_count; // number of KV chunks = ceil(total_seqlen / kv_block_size) +}; + +/** + * @brief Flash Attention with quantized KV cache. + * + * Implements tiled attention with online softmax, processing KV in blocks + * to avoid materializing the full attention matrix. Supports causal masking + * and local window attention. + * + * @param args Pointer to argument structure. + * @param ThreadPool Optional thread pool for parallelization. + */ +void +MLASCALL +MlasFlashAttentionQuantizedKV( + MlasFlashAttentionQuantizedKVArgs* args, MLAS_THREADPOOL* ThreadPool ); diff --git a/onnxruntime/core/mlas/lib/flashattn_qkv.cpp b/onnxruntime/core/mlas/lib/flashattn_qkv.cpp new file mode 100644 index 0000000000000..364011fe26e26 --- /dev/null +++ b/onnxruntime/core/mlas/lib/flashattn_qkv.cpp @@ -0,0 +1,622 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + flashattn_qkv.cpp + +Abstract: + + Flash Attention kernel for quantized KV cache (INT8/INT4). + + Adapts the online-softmax tiled algorithm from flashattn.cpp to operate + on quantized K/V buffers using MlasQKGemm (for Q×K^T) and + MlasSVGemm with Beta=1.0 (for fused dequant + S×V accumulation). + + Supports causal masking and local window attention. + +--*/ + +#include +#include +#include +#include + +#include "mlasi.h" +#include "mlas_qkv_quant.h" + +void +MlasFlashAttentionQuantizedKVThreaded( + void* argptr, + std::ptrdiff_t thread_id +) +{ + const MlasFlashAttentionQuantizedKVArgs* args = + reinterpret_cast(argptr); + + const ptrdiff_t q_block_size = static_cast(args->q_block_size); + const ptrdiff_t kv_block_size = static_cast(args->kv_block_size); + const ptrdiff_t batch_size = static_cast(args->batch_size); + const ptrdiff_t num_heads = static_cast(args->num_heads); + const ptrdiff_t kv_num_heads = static_cast(args->kv_num_heads); + const ptrdiff_t sequence_length = static_cast(args->sequence_length); + const ptrdiff_t total_seqlen = static_cast(args->total_seqlen); + const ptrdiff_t head_size = static_cast(args->head_size); + const ptrdiff_t past_seqlen = static_cast(args->past_seqlen); + const ptrdiff_t local_window_size = static_cast(args->local_window_size); + const float scale = args->scale; + const MLAS_KV_QUANT_TYPE quant_type = args->quant_type; + + float* buffer = args->buffer; + const ptrdiff_t buffer_size_per_thread = static_cast(args->buffer_size_per_thread); + const ptrdiff_t thread_count = static_cast(args->thread_count); + + const size_t packed_row_bytes = MlasKVQuantPackedRowBytes(quant_type, static_cast(head_size)); + const size_t kv_num_heads_factor = static_cast(num_heads / kv_num_heads); + +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + auto&& mlas_platform = GetMlasPlatform(); +#endif + + // Total tasks: one per (batch, head, q_block) + const ptrdiff_t q_chunk_count = (sequence_length + q_block_size - 1) / q_block_size; + const ptrdiff_t total_task_count = batch_size * num_heads * q_chunk_count; + + ptrdiff_t task_start = 0; + ptrdiff_t task_end = 0; + ptrdiff_t quotient = total_task_count / thread_count; + ptrdiff_t remainder = total_task_count % thread_count; + if (thread_id < remainder) { + task_start = (quotient + 1) * thread_id; + task_end = task_start + quotient + 1; + } else { + task_start = quotient * thread_id + remainder; + task_end = task_start + quotient; + } + + for (ptrdiff_t task_index = task_start; task_index < task_end; ++task_index) { + ptrdiff_t batch_idx = task_index; + ptrdiff_t q_idx = (batch_idx % q_chunk_count) * q_block_size; + batch_idx /= q_chunk_count; + ptrdiff_t head_idx = batch_idx % num_heads; + batch_idx /= num_heads; + + // Per-thread buffer layout: + // l[q_block_size] - running sum for online softmax + // m[q_block_size] - running max for online softmax + // scores[q_block_size * kv_block_size] - QK scores (S) + // temp_output[q_block_size * head_size] - accumulated output + char* buffer_ptr = reinterpret_cast(buffer) + thread_id * buffer_size_per_thread; + float* l = reinterpret_cast(buffer_ptr); + float* m = l + q_block_size; + float* scores = m + q_block_size; + float* temp_output = scores + q_block_size * kv_block_size; + + // Initialize running state + for (ptrdiff_t t = 0; t < q_block_size; ++t) { + m[t] = std::numeric_limits::lowest(); + l[t] = 0.0f; + } + memset(temp_output, 0, static_cast(q_block_size * head_size) * sizeof(float)); + + const size_t row_size_q = static_cast(std::min(q_block_size, sequence_length - q_idx)); + + // Determine KV head index for GQA head sharing + const size_t kv_head_idx = static_cast(head_idx) / kv_num_heads_factor; + + // Pointers into quantized K/V caches + // K cache layout: [batch, kv_num_heads, seqlen_present, packed_head_bytes] + const size_t k_batch_head_offset = + (static_cast(batch_idx) * static_cast(kv_num_heads) + kv_head_idx) * + static_cast(args->seqlen_present_kv) * packed_row_bytes; + const uint8_t* k_cache_head = args->k_cache + k_batch_head_offset; + + const size_t v_batch_head_offset = + (static_cast(batch_idx) * static_cast(kv_num_heads) + kv_head_idx) * + static_cast(args->seqlen_present_kv) * packed_row_bytes; + const uint8_t* v_cache_head = args->v_cache + v_batch_head_offset; + + // K/V scale pointers + const float* head_k_scale = args->per_channel_k + ? args->k_scale + kv_head_idx * static_cast(head_size) + : args->k_scale; + const float* head_v_scale = args->per_channel_v + ? args->v_scale + kv_head_idx * static_cast(head_size) + : args->v_scale; + + // Q pointer: layout [batch, num_heads, seq, head_size] or packed + const float* q_ptr = args->query + + (static_cast(batch_idx) * static_cast(num_heads) + + static_cast(head_idx)) * static_cast(sequence_length) * static_cast(head_size) + + static_cast(q_idx) * static_cast(head_size); + + // Iterate over KV blocks + for (ptrdiff_t ir = 0; ir < total_seqlen; ir += kv_block_size) { + const size_t row_size_kv = static_cast(std::min(kv_block_size, total_seqlen - ir)); + + // Step 1: QK^T GEMM with quantized K block + // K cache at row offset ir: pointer arithmetic on packed rows + const uint8_t* k_block = k_cache_head + static_cast(ir) * packed_row_bytes; + + MlasQKGemm( + row_size_q, // M + row_size_kv, // N + static_cast(head_size), // K + scale, // Alpha + q_ptr, // A (FP32 query) + static_cast(head_size), // lda + k_block, // B (quantized K block) + quant_type, + head_k_scale, + scores, // C (output scores) + row_size_kv, // ldc + nullptr // no thread pool (already threaded) + ); + + // Step 1b: Apply attention bias (additive) if present + if (args->attention_bias != nullptr) { + const ptrdiff_t bias_seqlen_stride = + static_cast(args->attention_bias_seqlen_stride); + const ptrdiff_t bias_matrix_size = + static_cast(sequence_length) * bias_seqlen_stride; + ptrdiff_t bias_offset = 0; + if (!args->attention_bias_broadcast_batch) { + bias_offset += static_cast(batch_idx) * + static_cast(num_heads) * bias_matrix_size; + } + if (!args->attention_bias_broadcast_head) { + bias_offset += static_cast(head_idx) * bias_matrix_size; + } + // Add bias tile: bias[q_idx + irow, ir + jcol] + for (ptrdiff_t irow = 0; irow < static_cast(row_size_q); ++irow) { + const float* bias_row = args->attention_bias + bias_offset + + (q_idx + irow) * bias_seqlen_stride + ir; + float* s_row = scores + irow * static_cast(row_size_kv); + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + s_row[jcol] += bias_row[jcol]; + } + } + } + + // Step 2: Apply causal mask and Step 3: Online softmax update + for (ptrdiff_t irow = 0; irow < static_cast(row_size_q); ++irow) { + float* p = scores + irow * static_cast(row_size_kv); + const ptrdiff_t global_q_pos = past_seqlen + q_idx + irow; + const ptrdiff_t causal_limit = global_q_pos + 1; // can attend to positions [0, causal_limit) + + // Apply causal masking + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + ptrdiff_t kv_pos = ir + jcol; + if (kv_pos >= causal_limit) { + p[jcol] = std::numeric_limits::lowest(); + } + } + + // Apply local window masking if enabled + if (local_window_size >= 0) { + const ptrdiff_t window_start = + (causal_limit > local_window_size) ? (causal_limit - local_window_size) : 0; + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + ptrdiff_t kv_pos = ir + jcol; + if (kv_pos < window_start) { + p[jcol] = std::numeric_limits::lowest(); + } + } + } + + // Online softmax: find row max, update running max +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + float rowmax = mlas_platform.ReduceMaximumF32Kernel(p, row_size_kv); +#else + float rowmax = MlasReduceMaximumF32Kernel(p, row_size_kv); +#endif + + // If the entire row is masked (all scores are -inf), zero the scores + // so SVGemm contributes nothing and skip the softmax state update. + if (rowmax == std::numeric_limits::lowest()) { + memset(p, 0, row_size_kv * sizeof(float)); + continue; + } + + float m_old = m[irow]; + m[irow] = std::max(m[irow], rowmax); + float m_diff = m_old - m[irow]; // <= 0 + + // Compute exp(score - m_new) for each element + float negmax = -m[irow]; +#if defined(MLAS_TARGET_AMD64) + float rowsum = mlas_platform.ComputeSumExpF32Kernel(p, p, row_size_kv, &negmax); +#else + float rowsum = MlasComputeSumExpF32Kernel(p, p, row_size_kv, &negmax); +#endif + + // Rescale previous state + if (ir != 0) { + float exp_diff = std::exp(m_diff); + l[irow] = exp_diff * l[irow] + rowsum; + + // Rescale accumulated output + float* out_row = temp_output + irow * head_size; + for (ptrdiff_t icol = 0; icol < head_size; ++icol) { + out_row[icol] *= exp_diff; + } + } else { + l[irow] = rowsum; + } + } + + // Step 4: Accumulate O += S_exp * V_block using fused dequant+GEMM + const uint8_t* v_block = v_cache_head + static_cast(ir) * packed_row_bytes; + MlasSVGemm( + row_size_q, // M + static_cast(head_size), // N + row_size_kv, // K + scores, // A (exp softmax scores) + row_size_kv, // lda + v_block, // B (quantized V block) + quant_type, + head_v_scale, + temp_output, // C (accumulated output) + static_cast(head_size), // ldc + 1.0f, // Beta (accumulate) + nullptr // no thread pool (already threaded) + ); + } + + // Final: normalize output by l (softmax denominator) + // Output layout: [batch, sequence_length, num_heads, head_size] + float* output_row = args->output + + (static_cast(batch_idx) * static_cast(sequence_length) + + static_cast(q_idx)) * static_cast(num_heads) * static_cast(head_size) + + static_cast(head_idx) * static_cast(head_size); + const ptrdiff_t output_row_stride = num_heads * head_size; + + for (ptrdiff_t irow = 0; irow < static_cast(row_size_q); ++irow) { + float inv_l = (l[irow] > 0.0f) ? (1.0f / l[irow]) : 0.0f; + float* src = temp_output + irow * head_size; + for (ptrdiff_t icol = 0; icol < head_size; ++icol) { + output_row[icol] = src[icol] * inv_l; + } + output_row += output_row_stride; + } + } +} + +// +// Flash Decoding: Phase 1 - parallel partial attention over (batch, head, kv_chunk). +// Each task computes attention for one KV chunk and stores (m, l, partial_output) +// into the partials buffer. +// +void +MlasFlashDecodingQuantizedKVThreaded( + void* argptr, + std::ptrdiff_t thread_id +) +{ + const MlasFlashAttentionQuantizedKVArgs* args = + reinterpret_cast(argptr); + + const ptrdiff_t kv_block_size = static_cast(args->kv_block_size); + const ptrdiff_t batch_size = static_cast(args->batch_size); + const ptrdiff_t num_heads = static_cast(args->num_heads); + const ptrdiff_t kv_num_heads = static_cast(args->kv_num_heads); + const ptrdiff_t total_seqlen = static_cast(args->total_seqlen); + const ptrdiff_t head_size = static_cast(args->head_size); + const ptrdiff_t past_seqlen = static_cast(args->past_seqlen); + const ptrdiff_t local_window_size = static_cast(args->local_window_size); + const float scale = args->scale; + const MLAS_KV_QUANT_TYPE quant_type = args->quant_type; + + float* buffer = args->buffer; + const ptrdiff_t buffer_size_per_thread = static_cast(args->buffer_size_per_thread); + const ptrdiff_t thread_count = static_cast(args->thread_count); + + const size_t packed_row_bytes = MlasKVQuantPackedRowBytes(quant_type, static_cast(head_size)); + const size_t kv_num_heads_factor = static_cast(num_heads / kv_num_heads); + + const ptrdiff_t kv_chunk_count = static_cast(args->kv_chunk_count); + // Partials layout per entry: [m, l, output[head_size]] + const ptrdiff_t partial_stride = 2 + head_size; + +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + auto&& mlas_platform = GetMlasPlatform(); +#endif + + // Total tasks: (batch, head, kv_chunk) + const ptrdiff_t total_task_count = batch_size * num_heads * kv_chunk_count; + + ptrdiff_t task_start = 0; + ptrdiff_t task_end = 0; + ptrdiff_t quotient = total_task_count / thread_count; + ptrdiff_t remainder = total_task_count % thread_count; + if (thread_id < remainder) { + task_start = (quotient + 1) * thread_id; + task_end = task_start + quotient + 1; + } else { + task_start = quotient * thread_id + remainder; + task_end = task_start + quotient; + } + + for (ptrdiff_t task_index = task_start; task_index < task_end; ++task_index) { + // Decompose task_index into (batch_idx, head_idx, kv_chunk_idx) + ptrdiff_t tmp = task_index; + ptrdiff_t kv_chunk_idx = tmp % kv_chunk_count; + tmp /= kv_chunk_count; + ptrdiff_t head_idx = tmp % num_heads; + ptrdiff_t batch_idx = tmp / num_heads; + + // Per-thread scratch buffer: just scores[kv_block_size] + char* buffer_ptr = reinterpret_cast(buffer) + thread_id * buffer_size_per_thread; + float* scores = reinterpret_cast(buffer_ptr); + + // KV block range for this chunk + const ptrdiff_t ir = kv_chunk_idx * kv_block_size; + const size_t row_size_kv = static_cast(std::min(kv_block_size, total_seqlen - ir)); + + // Determine KV head index for GQA head sharing + const size_t kv_head_idx = static_cast(head_idx) / kv_num_heads_factor; + + // K/V cache pointers + const size_t k_batch_head_offset = + (static_cast(batch_idx) * static_cast(kv_num_heads) + kv_head_idx) * + static_cast(args->seqlen_present_kv) * packed_row_bytes; + const uint8_t* k_cache_head = args->k_cache + k_batch_head_offset; + + const size_t v_batch_head_offset = + (static_cast(batch_idx) * static_cast(kv_num_heads) + kv_head_idx) * + static_cast(args->seqlen_present_kv) * packed_row_bytes; + const uint8_t* v_cache_head = args->v_cache + v_batch_head_offset; + + // K/V scale pointers + const float* head_k_scale = args->per_channel_k + ? args->k_scale + kv_head_idx * static_cast(head_size) + : args->k_scale; + const float* head_v_scale = args->per_channel_v + ? args->v_scale + kv_head_idx * static_cast(head_size) + : args->v_scale; + + // Q pointer: layout [batch, num_heads, 1, head_size] (sequence_length=1) + const float* q_ptr = args->query + + (static_cast(batch_idx) * static_cast(num_heads) + + static_cast(head_idx)) * static_cast(head_size); + + // Step 1: QK^T GEMM for this KV chunk + const uint8_t* k_block = k_cache_head + static_cast(ir) * packed_row_bytes; + MlasQKGemm( + 1, // M (single query row) + row_size_kv, // N + static_cast(head_size), // K + scale, // Alpha + q_ptr, // A (FP32 query) + static_cast(head_size), // lda + k_block, // B (quantized K block) + quant_type, + head_k_scale, + scores, // C (output scores) + row_size_kv, // ldc + nullptr + ); + + // Step 1b: Apply attention bias if present + if (args->attention_bias != nullptr) { + const ptrdiff_t bias_seqlen_stride = + static_cast(args->attention_bias_seqlen_stride); + const ptrdiff_t bias_matrix_size = bias_seqlen_stride; // S=1 + ptrdiff_t bias_offset = 0; + if (!args->attention_bias_broadcast_batch) { + bias_offset += static_cast(batch_idx) * + static_cast(num_heads) * bias_matrix_size; + } + if (!args->attention_bias_broadcast_head) { + bias_offset += static_cast(head_idx) * bias_matrix_size; + } + const float* bias_row = args->attention_bias + bias_offset + ir; + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + scores[jcol] += bias_row[jcol]; + } + } + + // Step 2: Apply causal mask + const ptrdiff_t global_q_pos = past_seqlen; // sequence_length=1, q_idx=0 + const ptrdiff_t causal_limit = global_q_pos + 1; + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + ptrdiff_t kv_pos = ir + jcol; + if (kv_pos >= causal_limit) { + scores[jcol] = std::numeric_limits::lowest(); + } + } + + // Apply local window masking if enabled + if (local_window_size >= 0) { + const ptrdiff_t window_start = + (causal_limit > local_window_size) ? (causal_limit - local_window_size) : 0; + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + ptrdiff_t kv_pos = ir + jcol; + if (kv_pos < window_start) { + scores[jcol] = std::numeric_limits::lowest(); + } + } + } + + // Step 3: Compute local softmax statistics (m, l) and exp scores +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + float rowmax = mlas_platform.ReduceMaximumF32Kernel(scores, row_size_kv); +#else + float rowmax = MlasReduceMaximumF32Kernel(scores, row_size_kv); +#endif + + // Pointer to this task's partial in the partials buffer + const ptrdiff_t partial_index = + (batch_idx * num_heads + head_idx) * kv_chunk_count + kv_chunk_idx; + float* partial = args->flash_decoding_partials + partial_index * partial_stride; + float* partial_m = partial; + float* partial_l = partial + 1; + float* partial_output = partial + 2; + + if (rowmax == std::numeric_limits::lowest()) { + // Entire chunk is masked: store sentinel + *partial_m = std::numeric_limits::lowest(); + *partial_l = 0.0f; + memset(partial_output, 0, static_cast(head_size) * sizeof(float)); + continue; + } + + *partial_m = rowmax; + float negmax = -rowmax; +#if defined(MLAS_TARGET_AMD64) + float rowsum = mlas_platform.ComputeSumExpF32Kernel(scores, scores, row_size_kv, &negmax); +#else + float rowsum = MlasComputeSumExpF32Kernel(scores, scores, row_size_kv, &negmax); +#endif + *partial_l = rowsum; + + // Step 4: S_exp * V_block -> partial_output + const uint8_t* v_block = v_cache_head + static_cast(ir) * packed_row_bytes; + memset(partial_output, 0, static_cast(head_size) * sizeof(float)); + MlasSVGemm( + 1, // M + static_cast(head_size), // N + row_size_kv, // K + scores, // A (exp softmax scores) + row_size_kv, // lda + v_block, // B (quantized V block) + quant_type, + head_v_scale, + partial_output, // C (output for this chunk) + static_cast(head_size), // ldc + 0.0f, // Beta=0 (overwrite) + nullptr + ); + } +} + +// +// Flash Decoding: Phase 2 - reduce partials for each (batch, head) into final output. +// +void +MlasFlashDecodingReduceThreaded( + void* argptr, + std::ptrdiff_t thread_id +) +{ + const MlasFlashAttentionQuantizedKVArgs* args = + reinterpret_cast(argptr); + + const ptrdiff_t batch_size = static_cast(args->batch_size); + const ptrdiff_t num_heads = static_cast(args->num_heads); + const ptrdiff_t head_size = static_cast(args->head_size); + const ptrdiff_t kv_chunk_count = static_cast(args->kv_chunk_count); + const ptrdiff_t thread_count = static_cast(args->thread_count); + const ptrdiff_t partial_stride = 2 + head_size; + + // Total reduction tasks: one per (batch, head) + const ptrdiff_t total_task_count = batch_size * num_heads; + + ptrdiff_t task_start = 0; + ptrdiff_t task_end = 0; + ptrdiff_t quotient = total_task_count / thread_count; + ptrdiff_t remainder = total_task_count % thread_count; + if (thread_id < remainder) { + task_start = (quotient + 1) * thread_id; + task_end = task_start + quotient + 1; + } else { + task_start = quotient * thread_id + remainder; + task_end = task_start + quotient; + } + + for (ptrdiff_t task_index = task_start; task_index < task_end; ++task_index) { + ptrdiff_t head_idx = task_index % num_heads; + ptrdiff_t batch_idx = task_index / num_heads; + + // Pointer to this (batch, head)'s partials: kv_chunk_count entries + const float* partials_base = args->flash_decoding_partials + + task_index * kv_chunk_count * partial_stride; + + // Find global max across all chunks + float global_m = std::numeric_limits::lowest(); + for (ptrdiff_t c = 0; c < kv_chunk_count; ++c) { + float chunk_m = partials_base[c * partial_stride]; + global_m = std::max(global_m, chunk_m); + } + + // If all chunks are masked, output zeros + if (global_m == std::numeric_limits::lowest()) { + float* output_ptr = args->output + + static_cast(batch_idx) * static_cast(num_heads) * static_cast(head_size) + + static_cast(head_idx) * static_cast(head_size); + memset(output_ptr, 0, static_cast(head_size) * sizeof(float)); + continue; + } + + // Accumulate rescaled outputs and l values + float global_l = 0.0f; + // Use the output location directly for accumulation + // Output layout: [batch, sequence_length=1, num_heads, head_size] + float* output_ptr = args->output + + static_cast(batch_idx) * static_cast(num_heads) * static_cast(head_size) + + static_cast(head_idx) * static_cast(head_size); + memset(output_ptr, 0, static_cast(head_size) * sizeof(float)); + + for (ptrdiff_t c = 0; c < kv_chunk_count; ++c) { + const float* partial = partials_base + c * partial_stride; + float chunk_m = partial[0]; + float chunk_l = partial[1]; + const float* chunk_output = partial + 2; + + if (chunk_l <= 0.0f) { + continue; // masked chunk contributes nothing + } + + float rescale = std::exp(chunk_m - global_m); + global_l += rescale * chunk_l; + + // partial_output = S_exp * V where sum(S_exp) = l_c (unnormalized). + // Rescale by exp(m_c - global_m) to align all chunks to the same max. + for (ptrdiff_t i = 0; i < head_size; ++i) { + output_ptr[i] += rescale * chunk_output[i]; + } + } + + // output = sum_c(rescale_c * partial_output_c) / global_l + float inv_l = (global_l > 0.0f) ? (1.0f / global_l) : 0.0f; + for (ptrdiff_t i = 0; i < head_size; ++i) { + output_ptr[i] *= inv_l; + } + } +} + +void +MLASCALL +MlasFlashAttentionQuantizedKV( + MlasFlashAttentionQuantizedKVArgs* args, + MLAS_THREADPOOL* ThreadPool +) +{ + if (args->flash_decoding_partials != nullptr && args->sequence_length == 1) { + // Flash decoding: two-phase approach. + // Phase 1: parallel partial computation over (batch, head, kv_chunk). + MlasExecuteThreaded( + MlasFlashDecodingQuantizedKVThreaded, + static_cast(args), + static_cast(args->thread_count), + ThreadPool + ); + // Phase 2: reduce partials into final output (parallel over batch*heads). + MlasExecuteThreaded( + MlasFlashDecodingReduceThreaded, + static_cast(args), + static_cast(args->thread_count), + ThreadPool + ); + } else { + MlasExecuteThreaded( + MlasFlashAttentionQuantizedKVThreaded, + static_cast(args), + static_cast(args->thread_count), + ThreadPool + ); + } +} diff --git a/onnxruntime/core/mlas/lib/qkv_quant.cpp b/onnxruntime/core/mlas/lib/qkv_quant.cpp index 81fba6bc7cec4..c414324a0493f 100644 --- a/onnxruntime/core/mlas/lib/qkv_quant.cpp +++ b/onnxruntime/core/mlas/lib/qkv_quant.cpp @@ -356,14 +356,23 @@ MlasSVGemm( const float* Scales, float* C, size_t ldc, + float Beta, MLAS_THREADPOOL* ThreadPool) { if (M == 0 || N == 0) { return; } if (K == 0) { - for (size_t m = 0; m < M; ++m) { - std::memset(C + m * ldc, 0, N * sizeof(float)); + if (Beta == 0.0f) { + for (size_t m = 0; m < M; ++m) { + std::memset(C + m * ldc, 0, N * sizeof(float)); + } + } else if (Beta != 1.0f) { + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + C[m * ldc + n] *= Beta; + } + } } return; } @@ -373,7 +382,7 @@ MlasSVGemm( // const auto* Dispatch = GetMlasPlatform().KVQuantGemmDispatch; if (Dispatch != nullptr && Dispatch->SVGemm != nullptr) { - Dispatch->SVGemm(M, N, K, A, lda, B, QuantType, Scales, C, ldc); + Dispatch->SVGemm(M, N, K, A, lda, B, QuantType, Scales, C, ldc, Beta); return; } @@ -393,7 +402,13 @@ MlasSVGemm( const size_t m = static_cast(m_idx); const float* a_row = A + m * lda; float* c_row = C + m * ldc; - std::memset(c_row, 0, N * sizeof(float)); + if (Beta == 0.0f) { + std::memset(c_row, 0, N * sizeof(float)); + } else if (Beta != 1.0f) { + for (size_t n = 0; n < N; ++n) { + c_row[n] *= Beta; + } + } // Per-row scratch for one dequantized B row of length N. float b_dequant[1024]; diff --git a/onnxruntime/core/mlas/lib/qkv_quant_kernel.h b/onnxruntime/core/mlas/lib/qkv_quant_kernel.h index 5c4e93bb334c3..ebd990703472d 100644 --- a/onnxruntime/core/mlas/lib/qkv_quant_kernel.h +++ b/onnxruntime/core/mlas/lib/qkv_quant_kernel.h @@ -53,7 +53,7 @@ struct MLAS_KV_QUANT_GEMM_DISPATCH { QKGemm_Fn* QKGemm = nullptr; /** - * S*V GEMM kernel: C[M,N] = A[M,K] * B[K,N] + * S*V GEMM kernel: C[M,N] = Beta * C[M,N] + A[M,K] * B[K,N] * * B is quantized (INT8 or INT4), logically [K, N] in packed row-major. */ @@ -67,7 +67,8 @@ struct MLAS_KV_QUANT_GEMM_DISPATCH { MLAS_KV_QUANT_TYPE QuantType, const float* Scales, float* C, - size_t ldc + size_t ldc, + float Beta ); SVGemm_Fn* SVGemm = nullptr; diff --git a/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx2.cpp index 8bec2d350afa5..d7bb01deec2ed 100644 --- a/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx2.cpp @@ -268,7 +268,7 @@ QKGemm_Avx2( } // -// SVGemm: C[M,N] = A[M,K] * B[K,N] +// SVGemm: C[M,N] = Beta * C[M,N] + A[M,K] * B[K,N] // B is [K,N] packed row-major. // // Fused approach: dequantize each B[k,:] element directly into the FMA with @@ -285,7 +285,8 @@ SVGemm_Avx2( MLAS_KV_QUANT_TYPE QuantType, const float* Scales, float* C, - size_t ldc) + size_t ldc, + float Beta) { const size_t row_bytes = MlasKVQuantPackedRowBytes(QuantType, N); const auto* B_bytes = static_cast(B); @@ -298,13 +299,26 @@ SVGemm_Avx2( float* c_row = C + m * ldc; const float* a_row = A + m * lda; - // Zero output - size_t n = 0; - for (; n < vec_end_n; n += 8) { - _mm256_storeu_ps(c_row + n, _mm256_setzero_ps()); - } - for (; n < N; ++n) { - c_row[n] = 0.0f; + // Initialize output + if (Beta == 0.0f) { + size_t n = 0; + for (; n < vec_end_n; n += 8) { + _mm256_storeu_ps(c_row + n, _mm256_setzero_ps()); + } + for (; n < N; ++n) { + c_row[n] = 0.0f; + } + } else if (Beta != 1.0f) { + __m256 beta_vec = _mm256_broadcast_ss(&Beta); + size_t n = 0; + for (; n < vec_end_n; n += 8) { + __m256 c_vec = _mm256_loadu_ps(c_row + n); + c_vec = _mm256_mul_ps(c_vec, beta_vec); + _mm256_storeu_ps(c_row + n, c_vec); + } + for (; n < N; ++n) { + c_row[n] *= Beta; + } } if (!int4) { @@ -315,7 +329,7 @@ SVGemm_Avx2( const float a_val = a_row[k]; __m256 a_broadcast = _mm256_broadcast_ss(&a_val); - n = 0; + size_t n = 0; for (; n < vec_end_n; n += 8) { __m128i raw = _mm_loadl_epi64(reinterpret_cast(b_row + n)); __m256i i32 = _mm256_cvtepi8_epi32(raw); @@ -331,35 +345,59 @@ SVGemm_Avx2( } } } else { - // Per-tensor: accumulate unscaled dot products, then scale the output row once. - for (size_t k = 0; k < K; ++k) { - const int8_t* b_row = reinterpret_cast(B_bytes + k * row_bytes); - const float a_val = a_row[k]; - __m256 a_broadcast = _mm256_broadcast_ss(&a_val); + // Per-tensor: when Beta==0, accumulate unscaled then scale once at end. + // When Beta!=0, C already has scaled values so fold scale into a_val. + if (Beta == 0.0f) { + for (size_t k = 0; k < K; ++k) { + const int8_t* b_row = reinterpret_cast(B_bytes + k * row_bytes); + const float a_val = a_row[k]; + __m256 a_broadcast = _mm256_broadcast_ss(&a_val); + + size_t n = 0; + for (; n < vec_end_n; n += 8) { + __m128i raw = _mm_loadl_epi64(reinterpret_cast(b_row + n)); + __m256i i32 = _mm256_cvtepi8_epi32(raw); + __m256 bf = _mm256_cvtepi32_ps(i32); + __m256 c_vec = _mm256_loadu_ps(c_row + n); + c_vec = _mm256_fmadd_ps(a_broadcast, bf, c_vec); + _mm256_storeu_ps(c_row + n, c_vec); + } + for (; n < N; ++n) { + c_row[n] += a_val * static_cast(b_row[n]); + } + } - n = 0; + __m256 scale_vec = _mm256_broadcast_ss(Scales); + size_t n = 0; for (; n < vec_end_n; n += 8) { - __m128i raw = _mm_loadl_epi64(reinterpret_cast(b_row + n)); - __m256i i32 = _mm256_cvtepi8_epi32(raw); - __m256 bf = _mm256_cvtepi32_ps(i32); __m256 c_vec = _mm256_loadu_ps(c_row + n); - c_vec = _mm256_fmadd_ps(a_broadcast, bf, c_vec); + c_vec = _mm256_mul_ps(c_vec, scale_vec); _mm256_storeu_ps(c_row + n, c_vec); } for (; n < N; ++n) { - c_row[n] += a_val * static_cast(b_row[n]); + c_row[n] *= Scales[0]; + } + } else { + // Beta!=0: fold scale into a_val to avoid separate pass + const float tensor_scale = Scales[0]; + for (size_t k = 0; k < K; ++k) { + const int8_t* b_row = reinterpret_cast(B_bytes + k * row_bytes); + const float a_val = a_row[k] * tensor_scale; + __m256 a_broadcast = _mm256_broadcast_ss(&a_val); + + size_t n = 0; + for (; n < vec_end_n; n += 8) { + __m128i raw = _mm_loadl_epi64(reinterpret_cast(b_row + n)); + __m256i i32 = _mm256_cvtepi8_epi32(raw); + __m256 bf = _mm256_cvtepi32_ps(i32); + __m256 c_vec = _mm256_loadu_ps(c_row + n); + c_vec = _mm256_fmadd_ps(a_broadcast, bf, c_vec); + _mm256_storeu_ps(c_row + n, c_vec); + } + for (; n < N; ++n) { + c_row[n] += a_val * static_cast(b_row[n]); + } } - } - - __m256 scale_vec = _mm256_broadcast_ss(Scales); - n = 0; - for (; n < vec_end_n; n += 8) { - __m256 c_vec = _mm256_loadu_ps(c_row + n); - c_vec = _mm256_mul_ps(c_vec, scale_vec); - _mm256_storeu_ps(c_row + n, c_vec); - } - for (; n < N; ++n) { - c_row[n] *= Scales[0]; } } } else { @@ -369,7 +407,7 @@ SVGemm_Avx2( const float a_val = a_row[k]; __m256 a_broadcast = _mm256_broadcast_ss(&a_val); - n = 0; + size_t n = 0; for (; n < vec_end_n; n += 8) { __m256 bf = DequantInt4x8(b_row, n, per_channel, Scales); __m256 c_vec = _mm256_loadu_ps(c_row + n); diff --git a/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp index fa5aff0165897..16e82f19c3711 100644 --- a/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp +++ b/onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp @@ -512,7 +512,7 @@ QKGemm_Avx512Vnni( } // ============================================================================ -// SVGemm: C[M,N] = A[M,K] * B[K,N] +// SVGemm: C[M,N] = Beta * C[M,N] + A[M,K] * B[K,N] // B is [K,N] packed row-major. // // For SVGemm, A is attention weights (FP32) and B is V-cache (quantized). @@ -532,7 +532,8 @@ SVGemm_Avx512Vnni( MLAS_KV_QUANT_TYPE QuantType, const float* Scales, float* C, - size_t ldc) + size_t ldc, + float Beta) { const size_t row_bytes = MlasKVQuantPackedRowBytes(QuantType, N); const auto* B_bytes = static_cast(B); @@ -545,13 +546,26 @@ SVGemm_Avx512Vnni( float* c_row = C + m * ldc; const float* a_row = A + m * lda; - // Zero output using 512-bit stores - size_t n = 0; - for (; n < vec_end_n; n += 16) { - _mm512_storeu_ps(c_row + n, _mm512_setzero_ps()); - } - for (; n < N; ++n) { - c_row[n] = 0.0f; + // Initialize output + if (Beta == 0.0f) { + size_t n = 0; + for (; n < vec_end_n; n += 16) { + _mm512_storeu_ps(c_row + n, _mm512_setzero_ps()); + } + for (; n < N; ++n) { + c_row[n] = 0.0f; + } + } else if (Beta != 1.0f) { + __m512 beta_vec = _mm512_set1_ps(Beta); + size_t n = 0; + for (; n < vec_end_n; n += 16) { + __m512 c_vec = _mm512_loadu_ps(c_row + n); + c_vec = _mm512_mul_ps(c_vec, beta_vec); + _mm512_storeu_ps(c_row + n, c_vec); + } + for (; n < N; ++n) { + c_row[n] *= Beta; + } } if (!int4) { @@ -562,7 +576,7 @@ SVGemm_Avx512Vnni( const float a_val = a_row[k]; __m512 a_broadcast = _mm512_set1_ps(a_val); - n = 0; + size_t n = 0; for (; n < vec_end_n; n += 16) { __m128i raw = _mm_loadu_si128(reinterpret_cast(b_row + n)); __m512i i32 = _mm512_cvtepi8_epi32(raw); @@ -578,35 +592,59 @@ SVGemm_Avx512Vnni( } } } else { - // Per-tensor: accumulate unscaled dot products, then scale the output row once. - for (size_t k = 0; k < K; ++k) { - const int8_t* b_row = reinterpret_cast(B_bytes + k * row_bytes); - const float a_val = a_row[k]; - __m512 a_broadcast = _mm512_set1_ps(a_val); + // Per-tensor: when Beta==0, accumulate unscaled then scale once at end. + // When Beta!=0, fold scale into a_val. + if (Beta == 0.0f) { + for (size_t k = 0; k < K; ++k) { + const int8_t* b_row = reinterpret_cast(B_bytes + k * row_bytes); + const float a_val = a_row[k]; + __m512 a_broadcast = _mm512_set1_ps(a_val); + + size_t n = 0; + for (; n < vec_end_n; n += 16) { + __m128i raw = _mm_loadu_si128(reinterpret_cast(b_row + n)); + __m512i i32 = _mm512_cvtepi8_epi32(raw); + __m512 bf = _mm512_cvtepi32_ps(i32); + __m512 c_vec = _mm512_loadu_ps(c_row + n); + c_vec = _mm512_fmadd_ps(a_broadcast, bf, c_vec); + _mm512_storeu_ps(c_row + n, c_vec); + } + for (; n < N; ++n) { + c_row[n] += a_val * static_cast(b_row[n]); + } + } - n = 0; + __m512 scale_vec = _mm512_set1_ps(Scales[0]); + size_t n = 0; for (; n < vec_end_n; n += 16) { - __m128i raw = _mm_loadu_si128(reinterpret_cast(b_row + n)); - __m512i i32 = _mm512_cvtepi8_epi32(raw); - __m512 bf = _mm512_cvtepi32_ps(i32); __m512 c_vec = _mm512_loadu_ps(c_row + n); - c_vec = _mm512_fmadd_ps(a_broadcast, bf, c_vec); + c_vec = _mm512_mul_ps(c_vec, scale_vec); _mm512_storeu_ps(c_row + n, c_vec); } for (; n < N; ++n) { - c_row[n] += a_val * static_cast(b_row[n]); + c_row[n] *= Scales[0]; + } + } else { + // Beta!=0: fold scale into a_val + const float tensor_scale = Scales[0]; + for (size_t k = 0; k < K; ++k) { + const int8_t* b_row = reinterpret_cast(B_bytes + k * row_bytes); + const float a_val = a_row[k] * tensor_scale; + __m512 a_broadcast = _mm512_set1_ps(a_val); + + size_t n = 0; + for (; n < vec_end_n; n += 16) { + __m128i raw = _mm_loadu_si128(reinterpret_cast(b_row + n)); + __m512i i32 = _mm512_cvtepi8_epi32(raw); + __m512 bf = _mm512_cvtepi32_ps(i32); + __m512 c_vec = _mm512_loadu_ps(c_row + n); + c_vec = _mm512_fmadd_ps(a_broadcast, bf, c_vec); + _mm512_storeu_ps(c_row + n, c_vec); + } + for (; n < N; ++n) { + c_row[n] += a_val * static_cast(b_row[n]); + } } - } - - __m512 scale_vec = _mm512_set1_ps(Scales[0]); - n = 0; - for (; n < vec_end_n; n += 16) { - __m512 c_vec = _mm512_loadu_ps(c_row + n); - c_vec = _mm512_mul_ps(c_vec, scale_vec); - _mm512_storeu_ps(c_row + n, c_vec); - } - for (; n < N; ++n) { - c_row[n] *= Scales[0]; } } } else { @@ -616,7 +654,7 @@ SVGemm_Avx512Vnni( const float a_val = a_row[k]; __m512 a_broadcast = _mm512_set1_ps(a_val); - n = 0; + size_t n = 0; for (; n < vec_end_n; n += 16) { __m512 bf = DequantInt4x16_Avx512(b_row, n, per_channel, Scales); __m512 c_vec = _mm512_loadu_ps(c_row + n); diff --git a/onnxruntime/core/mlas/lib/qkv_quant_kernel_neon.cpp b/onnxruntime/core/mlas/lib/qkv_quant_kernel_neon.cpp index 1aabbd8ca39cb..070b1243955cd 100644 --- a/onnxruntime/core/mlas/lib/qkv_quant_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/qkv_quant_kernel_neon.cpp @@ -244,7 +244,7 @@ QKGemm_Neon( } // -// SVGemm: C[M,N] = A[M,K] * B[K,N] +// SVGemm: C[M,N] = Beta * C[M,N] + A[M,K] * B[K,N] // void SVGemm_Neon( @@ -257,7 +257,8 @@ SVGemm_Neon( MLAS_KV_QUANT_TYPE QuantType, const float* Scales, float* C, - size_t ldc) + size_t ldc, + float Beta) { const size_t row_bytes = MlasKVQuantPackedRowBytes(QuantType, N); const auto* B_bytes = static_cast(B); @@ -277,23 +278,40 @@ SVGemm_Neon( float* c_row = C + m * ldc; const float* a_row = A + m * lda; - // Zero output - size_t n = 0; - for (; n < vec_end_n; n += 4) { - vst1q_f32(c_row + n, vdupq_n_f32(0.0f)); - } - for (; n < N; ++n) { - c_row[n] = 0.0f; + // Initialize output + if (Beta == 0.0f) { + size_t n = 0; + for (; n < vec_end_n; n += 4) { + vst1q_f32(c_row + n, vdupq_n_f32(0.0f)); + } + for (; n < N; ++n) { + c_row[n] = 0.0f; + } + } else if (Beta != 1.0f) { + float32x4_t beta_vec = vdupq_n_f32(Beta); + size_t n = 0; + for (; n < vec_end_n; n += 4) { + float32x4_t c_vec = vld1q_f32(c_row + n); + c_vec = vmulq_f32(c_vec, beta_vec); + vst1q_f32(c_row + n, c_vec); + } + for (; n < N; ++n) { + c_row[n] *= Beta; + } } + // When Beta != 0 and per-tensor, we must apply scale inline during + // dequantization (can't defer scaling since C already has scaled values). + const bool apply_scale_inline = per_channel || (Beta != 0.0f); + for (size_t k = 0; k < K; ++k) { const uint8_t* b_row_packed = B_bytes + k * row_bytes; - DequantRow_Neon(b_row_packed, b_buf, N, QuantType, Scales, per_channel); + DequantRow_Neon(b_row_packed, b_buf, N, QuantType, Scales, apply_scale_inline); const float a_val = a_row[k]; float32x4_t a_broadcast = vdupq_n_f32(a_val); - n = 0; + size_t n = 0; for (; n < vec_end_n; n += 4) { float32x4_t c_vec = vld1q_f32(c_row + n); float32x4_t b_vec = vld1q_f32(b_buf + n); @@ -305,9 +323,9 @@ SVGemm_Neon( } } - if (!per_channel) { + if (!apply_scale_inline) { const float32x4_t scale_vec = vdupq_n_f32(Scales[0]); - n = 0; + size_t n = 0; for (; n < vec_end_n; n += 4) { float32x4_t c_vec = vld1q_f32(c_row + n); c_vec = vmulq_f32(c_vec, scale_vec); diff --git a/onnxruntime/test/mlas/bench/bench_qkv_quant.cpp b/onnxruntime/test/mlas/bench/bench_qkv_quant.cpp index 63b6a3eb212d0..23ca591ba6ed2 100644 --- a/onnxruntime/test/mlas/bench/bench_qkv_quant.cpp +++ b/onnxruntime/test/mlas/bench/bench_qkv_quant.cpp @@ -13,6 +13,7 @@ #include "mlas_qkv_quant.h" #include "core/mlas/lib/mlasi.h" #include "core/mlas/lib/qkv_quant_kernel.h" +#include "core/util/thread_utils.h" #include "benchmark/benchmark.h" #include "bench_util.h" @@ -127,10 +128,10 @@ static void BM_SVGemm(benchmark::State& state) { std::vector C(M * N, 0.0f); // Warmup - MlasSVGemm(M, N, K, A.data(), K, B_quant.data(), qt, scales.data(), C.data(), N, nullptr); + MlasSVGemm(M, N, K, A.data(), K, B_quant.data(), qt, scales.data(), C.data(), N, 0.0f, nullptr); for (auto _ : state) { - MlasSVGemm(M, N, K, A.data(), K, B_quant.data(), qt, scales.data(), C.data(), N, nullptr); + MlasSVGemm(M, N, K, A.data(), K, B_quant.data(), qt, scales.data(), C.data(), N, 0.0f, nullptr); } state.SetItemsProcessed(static_cast(state.iterations()) * M * N * K * 2); @@ -225,10 +226,10 @@ static void BM_SVGemm_Scalar(benchmark::State& state) { auto* saved_dispatch = platform.KVQuantGemmDispatch; platform.KVQuantGemmDispatch = nullptr; - MlasSVGemm(M, N, K, A.data(), K, B_quant.data(), qt, scales.data(), C.data(), N, nullptr); + MlasSVGemm(M, N, K, A.data(), K, B_quant.data(), qt, scales.data(), C.data(), N, 0.0f, nullptr); for (auto _ : state) { - MlasSVGemm(M, N, K, A.data(), K, B_quant.data(), qt, scales.data(), C.data(), N, nullptr); + MlasSVGemm(M, N, K, A.data(), K, B_quant.data(), qt, scales.data(), C.data(), N, 0.0f, nullptr); } platform.KVQuantGemmDispatch = saved_dispatch; @@ -305,10 +306,10 @@ static void BM_SVGemm_Avx2(benchmark::State& state) { auto* saved_dispatch = platform.KVQuantGemmDispatch; platform.KVQuantGemmDispatch = &MlasKVQuantGemmDispatchAvx2; - MlasSVGemm(M, N, K, A.data(), K, B_quant.data(), qt, scales.data(), C.data(), N, nullptr); + MlasSVGemm(M, N, K, A.data(), K, B_quant.data(), qt, scales.data(), C.data(), N, 0.0f, nullptr); for (auto _ : state) { - MlasSVGemm(M, N, K, A.data(), K, B_quant.data(), qt, scales.data(), C.data(), N, nullptr); + MlasSVGemm(M, N, K, A.data(), K, B_quant.data(), qt, scales.data(), C.data(), N, 0.0f, nullptr); } platform.KVQuantGemmDispatch = saved_dispatch; @@ -322,3 +323,284 @@ BENCHMARK(BM_QKGemm_Avx2)->Apply(ScalarArgs)->UseRealTime(); BENCHMARK(BM_SVGemm_Avx2)->Apply(ScalarArgs)->UseRealTime(); #endif // MLAS_TARGET_AMD64 || MLAS_TARGET_IX86 + +// +// Flash Attention vs Naive (full materialization) benchmark. +// Compares MlasFlashAttentionQuantizedKV against the manual +// QKGemm + softmax + SVGemm pipeline for realistic GQA shapes. +// +// Args: batch_size, num_heads, kv_num_heads, seq_len, total_seqlen, head_size, QuantType +// + +static MLAS_THREADPOOL* GetBenchThreadPool() { + static OrtThreadPoolParams tpo; + static bool init = [&]() { + tpo.thread_pool_size = 8; + tpo.auto_set_affinity = true; + return true; + }(); + (void)init; + static std::unique_ptr tp( + onnxruntime::concurrency::CreateThreadPool(&onnxruntime::Env::Default(), + tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP)); + return tp.get(); +} + +// Naive path: QKGemm + row-wise softmax + SVGemm (full attention matrix materialized) +static void BM_GQA_Naive(benchmark::State& state) { + const int batch_size = static_cast(state.range(0)); + const int num_heads = static_cast(state.range(1)); + const int kv_num_heads = static_cast(state.range(2)); + const int seq_len = static_cast(state.range(3)); + const int total_seqlen = static_cast(state.range(4)); + const int head_size = static_cast(state.range(5)); + const auto qt = static_cast(state.range(6)); + + const int groups = num_heads / kv_num_heads; + const float scale = 1.0f / std::sqrt(static_cast(head_size)); + + // Allocate query [B, N, S, H] + auto query = RandomFloats(static_cast(batch_size) * num_heads * seq_len * head_size, 42); + + // Allocate and quantize K cache [B, kv_N, T, H] + auto k_fp = RandomFloats(static_cast(batch_size) * kv_num_heads * total_seqlen * head_size, 123); + auto v_fp = RandomFloats(static_cast(batch_size) * kv_num_heads * total_seqlen * head_size, 456); + + size_t k_row_bytes = MlasKVQuantPackedRowBytes(qt, head_size); + size_t v_row_bytes = MlasKVQuantPackedRowBytes(qt, head_size); + size_t k_cache_size = static_cast(batch_size) * kv_num_heads * total_seqlen * k_row_bytes; + size_t v_cache_size = static_cast(batch_size) * kv_num_heads * total_seqlen * v_row_bytes; + + std::vector k_cache(k_cache_size); + std::vector v_cache(v_cache_size); + + bool per_channel = (qt == MLAS_KV_QUANT_TYPE::S8_PerChannel || qt == MLAS_KV_QUANT_TYPE::S4_PerChannel); + size_t num_scales = per_channel ? static_cast(kv_num_heads * head_size) : 1; + std::vector k_scale(num_scales, 0.01f); + std::vector v_scale(num_scales, 0.01f); + + // Quantize K and V caches per kv-head + for (int b = 0; b < batch_size; ++b) { + for (int h = 0; h < kv_num_heads; ++h) { + size_t offset_fp = (static_cast(b) * kv_num_heads + h) * total_seqlen * head_size; + size_t offset_q = (static_cast(b) * kv_num_heads + h) * total_seqlen * k_row_bytes; + MlasKVQuantize(k_fp.data() + offset_fp, k_cache.data() + offset_q, + total_seqlen, head_size, head_size, qt, + per_channel ? k_scale.data() + h * head_size : k_scale.data(), nullptr); + offset_q = (static_cast(b) * kv_num_heads + h) * total_seqlen * v_row_bytes; + MlasKVQuantize(v_fp.data() + offset_fp, v_cache.data() + offset_q, + total_seqlen, head_size, head_size, qt, + per_channel ? v_scale.data() + h * head_size : v_scale.data(), nullptr); + } + } + + // Allocate working buffers: scores[B*N, S, T] (one per head) + output[B, S, N, H] + std::vector scores(static_cast(batch_size) * num_heads * seq_len * total_seqlen); + std::vector output(static_cast(batch_size) * seq_len * num_heads * head_size, 0.0f); + + auto* tp = GetBenchThreadPool(); + const ptrdiff_t loop_len = batch_size * num_heads; + + for (auto _ : state) { + // Pass 1: QK GEMM + Softmax (matches operator's first TryParallelFor) + onnxruntime::concurrency::ThreadPool::TrySimpleParallelFor( + tp, loop_len, [&](std::ptrdiff_t i) { + const int b = static_cast(i) / num_heads; + const int h = static_cast(i) % num_heads; + const int kv_h = h / groups; + float* my_scores = scores.data() + static_cast(i) * seq_len * total_seqlen; + const float* q_ptr = query.data() + (static_cast(b) * num_heads + h) * seq_len * head_size; + const uint8_t* k_ptr = k_cache.data() + (static_cast(b) * kv_num_heads + kv_h) * total_seqlen * k_row_bytes; + + // QK GEMM: scores[S, T] = scale * Q[S,H] * K[T,H]^T + MlasQKGemm(seq_len, total_seqlen, head_size, scale, + q_ptr, head_size, k_ptr, qt, + per_channel ? k_scale.data() + kv_h * head_size : k_scale.data(), + my_scores, total_seqlen, nullptr); + + // Causal masking + MLAS-optimized softmax (matches operator) + for (int s = 0; s < seq_len; ++s) { + float* row = my_scores + s * total_seqlen; + int valid_len = total_seqlen - seq_len + s + 1; + // Zero out future positions (operator sets them to 0 before softmax) + for (int t = valid_len; t < total_seqlen; ++t) row[t] = 0.f; + // Use MLAS optimized softmax on valid range only + MlasComputeSoftmax(row, row, static_cast(1), + static_cast(valid_len), false, false, 0.0f, nullptr); + } + }); + + // Pass 2: SV GEMM (matches operator's second TryParallelFor) + onnxruntime::concurrency::ThreadPool::TrySimpleParallelFor( + tp, loop_len, [&](std::ptrdiff_t i) { + const int b = static_cast(i) / num_heads; + const int h = static_cast(i) % num_heads; + const int kv_h = h / groups; + float* my_scores = scores.data() + static_cast(i) * seq_len * total_seqlen; + const uint8_t* v_ptr = v_cache.data() + (static_cast(b) * kv_num_heads + kv_h) * total_seqlen * v_row_bytes; + float* out_ptr = output.data() + (static_cast(b) * seq_len * num_heads + h) * head_size; + + // SV GEMM: out[S, H] = scores[S,T] * V[T,H] + MlasSVGemm(seq_len, head_size, total_seqlen, + my_scores, total_seqlen, v_ptr, qt, + per_channel ? v_scale.data() + kv_h * head_size : v_scale.data(), + out_ptr, num_heads * head_size, 0.0f, nullptr); + }); + benchmark::DoNotOptimize(output.data()); + } + + int64_t flops = static_cast(batch_size) * num_heads * seq_len * + (2LL * total_seqlen * head_size + 2LL * total_seqlen * head_size); + state.SetItemsProcessed(static_cast(state.iterations()) * flops); +} + +// Flash path: MlasFlashAttentionQuantizedKV (tiled, online softmax) +static void BM_GQA_Flash(benchmark::State& state) { + const int batch_size = static_cast(state.range(0)); + const int num_heads = static_cast(state.range(1)); + const int kv_num_heads = static_cast(state.range(2)); + const int seq_len = static_cast(state.range(3)); + const int total_seqlen = static_cast(state.range(4)); + const int head_size = static_cast(state.range(5)); + const auto qt = static_cast(state.range(6)); + + const float scale = 1.0f / std::sqrt(static_cast(head_size)); + bool per_channel = (qt == MLAS_KV_QUANT_TYPE::S8_PerChannel || qt == MLAS_KV_QUANT_TYPE::S4_PerChannel); + + // Allocate query [B, N, S, H] in BNSH layout + auto query = RandomFloats(static_cast(batch_size) * num_heads * seq_len * head_size, 42); + + // Allocate and quantize K/V caches + auto k_fp = RandomFloats(static_cast(batch_size) * kv_num_heads * total_seqlen * head_size, 123); + auto v_fp = RandomFloats(static_cast(batch_size) * kv_num_heads * total_seqlen * head_size, 456); + + size_t k_row_bytes = MlasKVQuantPackedRowBytes(qt, head_size); + size_t v_row_bytes = MlasKVQuantPackedRowBytes(qt, head_size); + std::vector k_cache(static_cast(batch_size) * kv_num_heads * total_seqlen * k_row_bytes); + std::vector v_cache(static_cast(batch_size) * kv_num_heads * total_seqlen * v_row_bytes); + + size_t num_scales = per_channel ? static_cast(kv_num_heads * head_size) : 1; + std::vector k_scale(num_scales, 0.01f); + std::vector v_scale(num_scales, 0.01f); + + for (int b = 0; b < batch_size; ++b) { + for (int h = 0; h < kv_num_heads; ++h) { + size_t offset_fp = (static_cast(b) * kv_num_heads + h) * total_seqlen * head_size; + size_t offset_q = (static_cast(b) * kv_num_heads + h) * total_seqlen * k_row_bytes; + MlasKVQuantize(k_fp.data() + offset_fp, k_cache.data() + offset_q, + total_seqlen, head_size, head_size, qt, + per_channel ? k_scale.data() + h * head_size : k_scale.data(), nullptr); + offset_q = (static_cast(b) * kv_num_heads + h) * total_seqlen * v_row_bytes; + MlasKVQuantize(v_fp.data() + offset_fp, v_cache.data() + offset_q, + total_seqlen, head_size, head_size, qt, + per_channel ? v_scale.data() + h * head_size : v_scale.data(), nullptr); + } + } + + // Output [B, S, N, H] + std::vector output(static_cast(batch_size) * seq_len * num_heads * head_size, 0.0f); + + // Fixed block sizes for reproducible benchmarks (operator computes from L2 cache size) + int q_block_size = 64; + int kv_block_size = 256; + + // Thread pool + auto* tp = GetBenchThreadPool(); + int thread_count = 8; + + // Flash decoding: for decode (seq_len=1), partition KV across threads + int kv_chunk_count = (total_seqlen + kv_block_size - 1) / kv_block_size; + bool use_flash_decoding = (seq_len == 1 && + batch_size * num_heads < thread_count && + kv_chunk_count > 1); + + // Working buffer + size_t buffer_size_per_thread; + size_t partials_buffer_bytes = 0; + if (use_flash_decoding) { + buffer_size_per_thread = static_cast(kv_block_size) * sizeof(float); + partials_buffer_bytes = static_cast(batch_size) * num_heads * + kv_chunk_count * (2 + head_size) * sizeof(float); + } else { + buffer_size_per_thread = + (static_cast(q_block_size) * 2 + // l + m + static_cast(q_block_size) * static_cast(kv_block_size) + // scores + static_cast(q_block_size) * static_cast(head_size)) * // temp_output + sizeof(float); + } + size_t total_buffer_floats = (buffer_size_per_thread * thread_count + partials_buffer_bytes) / sizeof(float); + std::vector buffer(total_buffer_floats); + float* partials_ptr = use_flash_decoding + ? buffer.data() + (buffer_size_per_thread * thread_count) / sizeof(float) + : nullptr; + + MlasFlashAttentionQuantizedKVArgs args{}; + args.batch_size = batch_size; + args.num_heads = num_heads; + args.kv_num_heads = kv_num_heads; + args.sequence_length = seq_len; + args.total_seqlen = total_seqlen; + args.head_size = head_size; + args.past_seqlen = total_seqlen - seq_len; + args.local_window_size = -1; + args.seqlen_present_kv = total_seqlen; + args.q_block_size = q_block_size; + args.kv_block_size = kv_block_size; + args.scale = scale; + args.quant_type = qt; + args.per_channel_k = per_channel; + args.per_channel_v = per_channel; + args.thread_count = thread_count; + args.buffer = buffer.data(); + args.buffer_size_per_thread = buffer_size_per_thread; + args.query = query.data(); + args.k_cache = k_cache.data(); + args.v_cache = v_cache.data(); + args.k_scale = k_scale.data(); + args.v_scale = v_scale.data(); + args.output = output.data(); + args.attention_bias = nullptr; + args.attention_bias_seqlen_stride = 0; + args.attention_bias_broadcast_batch = true; + args.attention_bias_broadcast_head = true; + args.flash_decoding_partials = partials_ptr; + args.kv_chunk_count = kv_chunk_count; + + // Warmup + MlasFlashAttentionQuantizedKV(&args, tp); + + for (auto _ : state) { + MlasFlashAttentionQuantizedKV(&args, tp); + benchmark::DoNotOptimize(output.data()); + } + + int64_t flops = static_cast(batch_size) * num_heads * seq_len * + (2LL * total_seqlen * head_size + 2LL * total_seqlen * head_size); + state.SetItemsProcessed(static_cast(state.iterations()) * flops); +} + +// Flash vs Naive benchmark configurations +// Args: batch, num_heads, kv_num_heads, seq_len, total_seqlen, head_size, QuantType +static void FlashGQAArgs(benchmark::internal::Benchmark* b) { + b->ArgNames({"B", "N", "N_kv", "S", "T", "H", "QType"}); + // INT8 per-tensor (qt=0), INT8 per-channel (qt=1) + for (int qt : {0, 1}) { + // Prompt (prefill): seq_len = total_seqlen + for (int T : {512, 1024, 2048, 4096}) { + b->Args({1, 16, 8, T, T, 128, qt}); // B=1, GQA ratio 2 + } + // Decode: seq_len=1, past grows + for (int T : {512, 1024, 2048, 4096}) { + b->Args({1, 16, 8, 1, T, 128, qt}); // B=1, decode + } + // Larger batch decode + b->Args({4, 16, 8, 1, 2048, 128, qt}); + // Flash decoding cases: B*N < thread_count (8), triggers KV partitioning + for (int T : {512, 1024, 2048, 4096}) { + b->Args({1, 4, 4, 1, T, 128, qt}); // B=1, N=4, flash decoding enabled + } + } +} + +BENCHMARK(BM_GQA_Naive)->Apply(FlashGQAArgs)->UseRealTime(); +BENCHMARK(BM_GQA_Flash)->Apply(FlashGQAArgs)->UseRealTime(); diff --git a/onnxruntime/test/mlas/unittest/test_qkv_quant.cpp b/onnxruntime/test/mlas/unittest/test_qkv_quant.cpp index be13d4b489115..5f0b18fa2cac8 100644 --- a/onnxruntime/test/mlas/unittest/test_qkv_quant.cpp +++ b/onnxruntime/test/mlas/unittest/test_qkv_quant.cpp @@ -251,7 +251,7 @@ class MlasKVQuantTest : public MlasTestBase { RefSVGemm(A, BDequant, CRef, M, N, K, K, N); // Quantized: MlasSVGemm - MlasSVGemm(M, N, K, A, K, BQuant, QuantType, scales, C, N, nullptr); + MlasSVGemm(M, N, K, A, K, BQuant, QuantType, scales, C, N, 0.0f, nullptr); float atol = IsInt4(QuantType) ? 0.15f : 0.02f; float rtol = IsInt4(QuantType) ? 0.1f : 0.01f; @@ -322,3 +322,342 @@ static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_exe } return count; }); + +// +// Focused test for MlasFlashAttentionQuantizedKV: +// Validates the tiled online-softmax kernel against a naive reference pipeline +// (MlasQKGemm + softmax + MlasSVGemm) across INT8/INT4, per-tensor/per-channel. +// +class MlasFlashAttentionQuantizedKVTest : public MlasTestBase { + private: + MatrixGuardBuffer BufferQ; + MatrixGuardBuffer BufferOutput; + MatrixGuardBuffer BufferOutputRef; + MatrixGuardBuffer BufferScores; + MatrixGuardBuffer BufferProbs; + MatrixGuardBuffer BufferScalesK; + MatrixGuardBuffer BufferScalesV; + MatrixGuardBuffer BufferKFP32; + MatrixGuardBuffer BufferVFP32; + MatrixGuardBuffer BufferFlash; + MatrixGuardBuffer BufferPartials; + MatrixGuardBuffer BufferKQuant; + MatrixGuardBuffer BufferVQuant; + + void FillRandom(float* buf, size_t n, unsigned seed, float lo = -0.5f, float hi = 0.5f) { + std::default_random_engine gen(seed); + std::uniform_real_distribution dist(lo, hi); + for (size_t i = 0; i < n; i++) { + buf[i] = dist(gen); + } + } + + bool IsInt4(MLAS_KV_QUANT_TYPE qt) { + return qt == MLAS_KV_QUANT_TYPE::S4_PerTensor || qt == MLAS_KV_QUANT_TYPE::S4_PerChannel; + } + + bool IsPerChannel(MLAS_KV_QUANT_TYPE qt) { + return qt == MLAS_KV_QUANT_TYPE::S8_PerChannel || qt == MLAS_KV_QUANT_TYPE::S4_PerChannel; + } + + void ComputeScales(const float* data, size_t rows, size_t cols, MLAS_KV_QUANT_TYPE qt, float* scales) { + float qmax = IsInt4(qt) ? 7.0f : 127.0f; + if (IsPerChannel(qt)) { + for (size_t c = 0; c < cols; c++) { + float amax = 0.0f; + for (size_t r = 0; r < rows; r++) { + amax = std::max(amax, std::fabs(data[r * cols + c])); + } + scales[c] = (amax > 1e-6f) ? (amax / qmax) : 1.0f; + } + } else { + float amax = 0.0f; + for (size_t i = 0; i < rows * cols; i++) { + amax = std::max(amax, std::fabs(data[i])); + } + scales[0] = (amax > 1e-6f) ? (amax / qmax) : 1.0f; + } + } + + // Naive reference: for a single (batch=1, head=1) attention computation + // Q[seq_len, head_size], K[total_seqlen, head_size], V[total_seqlen, head_size] + // -> output[seq_len, head_size] + // Uses quantized K/V via MlasQKGemm + softmax + MlasSVGemm. + void NaiveReference( + const float* Q, size_t seq_len, size_t total_seqlen, size_t head_size, + const uint8_t* k_quant, const uint8_t* v_quant, + MLAS_KV_QUANT_TYPE quant_type, const float* k_scale, const float* v_scale, + float scale, int past_seqlen, float* output) { + float* scores = BufferScores.GetBuffer(seq_len * total_seqlen); + float* probs = BufferProbs.GetBuffer(seq_len * total_seqlen); + + // QK^T + MlasQKGemm(seq_len, total_seqlen, head_size, scale, + Q, head_size, k_quant, quant_type, k_scale, + scores, total_seqlen, nullptr); + + // Causal mask + softmax + for (size_t q_s = 0; q_s < seq_len; q_s++) { + size_t causal_limit = static_cast(past_seqlen) + q_s + 1; + // Apply causal mask + for (size_t kv_s = 0; kv_s < total_seqlen; kv_s++) { + if (kv_s >= causal_limit) { + scores[q_s * total_seqlen + kv_s] = -std::numeric_limits::infinity(); + } + } + // Softmax + float max_val = -std::numeric_limits::infinity(); + for (size_t kv_s = 0; kv_s < total_seqlen; kv_s++) { + max_val = std::max(max_val, scores[q_s * total_seqlen + kv_s]); + } + float sum_exp = 0.0f; + for (size_t kv_s = 0; kv_s < total_seqlen; kv_s++) { + probs[q_s * total_seqlen + kv_s] = std::exp(scores[q_s * total_seqlen + kv_s] - max_val); + sum_exp += probs[q_s * total_seqlen + kv_s]; + } + for (size_t kv_s = 0; kv_s < total_seqlen; kv_s++) { + probs[q_s * total_seqlen + kv_s] /= sum_exp; + } + } + + // SV GEMM + MlasSVGemm(seq_len, head_size, total_seqlen, + probs, total_seqlen, v_quant, quant_type, v_scale, + output, head_size, 0.0f, nullptr); + } + + void TestFlashAttention(size_t seq_len, size_t total_seqlen, size_t head_size, + MLAS_KV_QUANT_TYPE quant_type) { + const size_t k_num_scales = IsPerChannel(quant_type) ? head_size : 1; + const size_t v_num_scales = IsPerChannel(quant_type) ? head_size : 1; + const size_t packed_row_bytes = MlasKVQuantPackedRowBytes(quant_type, head_size); + const float scale = 1.0f / std::sqrt(static_cast(head_size)); + const int past_seqlen = static_cast(total_seqlen - seq_len); + + // Allocate and fill + float* Q = BufferQ.GetBuffer(seq_len * head_size); + float* K_fp32 = BufferKFP32.GetBuffer(total_seqlen * head_size); + float* V_fp32 = BufferVFP32.GetBuffer(total_seqlen * head_size); + float* k_scale = BufferScalesK.GetBuffer(k_num_scales); + float* v_scale = BufferScalesV.GetBuffer(v_num_scales); + float* output_flash = BufferOutput.GetBuffer(seq_len * head_size); + float* output_ref = BufferOutputRef.GetBuffer(seq_len * head_size); + + unsigned seed = static_cast(seq_len * 1000 + total_seqlen * 10 + head_size); + FillRandom(Q, seq_len * head_size, seed); + FillRandom(K_fp32, total_seqlen * head_size, seed + 1); + FillRandom(V_fp32, total_seqlen * head_size, seed + 2); + + ComputeScales(K_fp32, total_seqlen, head_size, quant_type, k_scale); + ComputeScales(V_fp32, total_seqlen, head_size, quant_type, v_scale); + + // Quantize K and V + uint8_t* k_quant = BufferKQuant.GetBuffer(total_seqlen * packed_row_bytes); + uint8_t* v_quant = BufferVQuant.GetBuffer(total_seqlen * packed_row_bytes); + MlasKVQuantize(K_fp32, k_quant, total_seqlen, head_size, head_size, quant_type, k_scale, nullptr); + MlasKVQuantize(V_fp32, v_quant, total_seqlen, head_size, head_size, quant_type, v_scale, nullptr); + + // Naive reference + NaiveReference(Q, seq_len, total_seqlen, head_size, + k_quant, v_quant, quant_type, k_scale, v_scale, + scale, past_seqlen, output_ref); + + // Flash attention + int q_block_size = std::min(static_cast(seq_len), 16); + int kv_block_size = std::min(static_cast(total_seqlen), 32); + + size_t buffer_size_per_thread = + (static_cast(q_block_size) * 2 + + static_cast(q_block_size) * static_cast(kv_block_size) + + static_cast(q_block_size) * static_cast(head_size)) * + sizeof(float); + float* flash_buffer = BufferFlash.GetBuffer(buffer_size_per_thread / sizeof(float)); + + MlasFlashAttentionQuantizedKVArgs args; + args.batch_size = 1; + args.num_heads = 1; + args.kv_num_heads = 1; + args.sequence_length = static_cast(seq_len); + args.total_seqlen = static_cast(total_seqlen); + args.head_size = static_cast(head_size); + args.past_seqlen = past_seqlen; + args.local_window_size = -1; + args.seqlen_present_kv = static_cast(total_seqlen); + args.q_block_size = q_block_size; + args.kv_block_size = kv_block_size; + args.scale = scale; + args.quant_type = quant_type; + args.per_channel_k = IsPerChannel(quant_type); + args.per_channel_v = IsPerChannel(quant_type); + args.thread_count = 1; + args.buffer = flash_buffer; + args.buffer_size_per_thread = buffer_size_per_thread; + args.query = Q; + args.k_cache = k_quant; + args.v_cache = v_quant; + args.k_scale = k_scale; + args.v_scale = v_scale; + args.output = output_flash; + args.attention_bias = nullptr; + args.attention_bias_seqlen_stride = 0; + args.attention_bias_broadcast_batch = true; + args.attention_bias_broadcast_head = true; + args.flash_decoding_partials = nullptr; + args.kv_chunk_count = 0; + + MlasFlashAttentionQuantizedKV(&args, nullptr); + + // Compare: flash uses ComputeSumExpF32Kernel (SIMD polynomial approx) while + // NaiveReference uses std::exp. Tolerance accounts for accumulation order + // differences across platforms/ISAs. + float atol = IsInt4(quant_type) ? 1e-3f : 1e-4f; + for (size_t i = 0; i < seq_len * head_size; i++) { + float diff = std::fabs(output_flash[i] - output_ref[i]); + ASSERT_LE(diff, atol) + << "FlashAttention vs Naive mismatch at [" << i / head_size << ", " << i % head_size + << "], flash=" << output_flash[i] << " ref=" << output_ref[i] + << " seq_len=" << seq_len << " total_seqlen=" << total_seqlen + << " head_size=" << head_size + << " qt=" << static_cast(quant_type); + } + } + + // Test flash decoding path: sequence_length=1 with KV split across chunks + void TestFlashDecoding(size_t total_seqlen, size_t head_size, + MLAS_KV_QUANT_TYPE quant_type) { + const size_t seq_len = 1; + const size_t k_num_scales = IsPerChannel(quant_type) ? head_size : 1; + const size_t v_num_scales = IsPerChannel(quant_type) ? head_size : 1; + const size_t packed_row_bytes = MlasKVQuantPackedRowBytes(quant_type, head_size); + const float scale = 1.0f / std::sqrt(static_cast(head_size)); + const int past_seqlen = static_cast(total_seqlen - 1); + + // Allocate and fill + float* Q = BufferQ.GetBuffer(head_size); + float* K_fp32 = BufferKFP32.GetBuffer(total_seqlen * head_size); + float* V_fp32 = BufferVFP32.GetBuffer(total_seqlen * head_size); + float* k_scale_buf = BufferScalesK.GetBuffer(k_num_scales); + float* v_scale_buf = BufferScalesV.GetBuffer(v_num_scales); + float* output_flash = BufferOutput.GetBuffer(head_size); + float* output_ref = BufferOutputRef.GetBuffer(head_size); + + unsigned seed = static_cast(total_seqlen * 100 + head_size * 7); + FillRandom(Q, head_size, seed); + FillRandom(K_fp32, total_seqlen * head_size, seed + 1); + FillRandom(V_fp32, total_seqlen * head_size, seed + 2); + + ComputeScales(K_fp32, total_seqlen, head_size, quant_type, k_scale_buf); + ComputeScales(V_fp32, total_seqlen, head_size, quant_type, v_scale_buf); + + // Quantize K and V + uint8_t* k_quant = BufferKQuant.GetBuffer(total_seqlen * packed_row_bytes); + uint8_t* v_quant = BufferVQuant.GetBuffer(total_seqlen * packed_row_bytes); + MlasKVQuantize(K_fp32, k_quant, total_seqlen, head_size, head_size, quant_type, k_scale_buf, nullptr); + MlasKVQuantize(V_fp32, v_quant, total_seqlen, head_size, head_size, quant_type, v_scale_buf, nullptr); + + // Naive reference + NaiveReference(Q, seq_len, total_seqlen, head_size, + k_quant, v_quant, quant_type, k_scale_buf, v_scale_buf, + scale, past_seqlen, output_ref); + + // Flash decoding: use small kv_block_size to get multiple chunks + int kv_block_size = std::min(static_cast(total_seqlen), 16); + int kv_chunk_count = (static_cast(total_seqlen) + kv_block_size - 1) / kv_block_size; + + // Per-thread scratch: scores[kv_block_size] + size_t buffer_size_per_thread = static_cast(kv_block_size) * sizeof(float); + float* flash_buffer = BufferFlash.GetBuffer(buffer_size_per_thread / sizeof(float)); + + // Partials buffer: [1 batch * 1 head * kv_chunk_count * (2 + head_size)] + size_t partials_count = static_cast(kv_chunk_count) * (2 + head_size); + float* partials = BufferPartials.GetBuffer(partials_count); + + MlasFlashAttentionQuantizedKVArgs args; + args.batch_size = 1; + args.num_heads = 1; + args.kv_num_heads = 1; + args.sequence_length = 1; + args.total_seqlen = static_cast(total_seqlen); + args.head_size = static_cast(head_size); + args.past_seqlen = past_seqlen; + args.local_window_size = -1; + args.seqlen_present_kv = static_cast(total_seqlen); + args.q_block_size = 1; + args.kv_block_size = kv_block_size; + args.scale = scale; + args.quant_type = quant_type; + args.per_channel_k = IsPerChannel(quant_type); + args.per_channel_v = IsPerChannel(quant_type); + args.thread_count = 1; + args.buffer = flash_buffer; + args.buffer_size_per_thread = buffer_size_per_thread; + args.query = Q; + args.k_cache = k_quant; + args.v_cache = v_quant; + args.k_scale = k_scale_buf; + args.v_scale = v_scale_buf; + args.output = output_flash; + args.attention_bias = nullptr; + args.attention_bias_seqlen_stride = 0; + args.attention_bias_broadcast_batch = true; + args.attention_bias_broadcast_head = true; + args.flash_decoding_partials = partials; + args.kv_chunk_count = kv_chunk_count; + + MlasFlashAttentionQuantizedKV(&args, nullptr); + + // Compare: flash decoding has an extra reduce phase with exp rescaling, + // so tolerance is slightly larger than the single-pass flash attention test. + float atol = IsInt4(quant_type) ? 1e-3f : 1e-4f; + for (size_t i = 0; i < head_size; i++) { + float diff = std::fabs(output_flash[i] - output_ref[i]); + ASSERT_LE(diff, atol) + << "FlashDecoding vs Naive mismatch at [" << i + << "], flash=" << output_flash[i] << " ref=" << output_ref[i] + << " total_seqlen=" << total_seqlen + << " head_size=" << head_size + << " qt=" << static_cast(quant_type); + } + } + + public: + static const char* GetTestSuiteName() { + static const std::string suite_name("FlashAttentionQuantizedKV"); + return suite_name.c_str(); + } + + void ExecuteShort(void) override { + static const MLAS_KV_QUANT_TYPE AllQuantTypes[] = { + MLAS_KV_QUANT_TYPE::S8_PerTensor, + MLAS_KV_QUANT_TYPE::S8_PerChannel, + MLAS_KV_QUANT_TYPE::S4_PerTensor, + MLAS_KV_QUANT_TYPE::S4_PerChannel, + }; + + for (auto qt : AllQuantTypes) { + size_t min_head = size_t{4}; + for (size_t seq_len : {size_t{1}, size_t{4}, size_t{16}}) { + for (size_t total_seqlen : {size_t{4}, size_t{32}, size_t{64}}) { + if (total_seqlen < seq_len) continue; + for (size_t head_size : {min_head, size_t{32}, size_t{64}}) { + TestFlashAttention(seq_len, total_seqlen, head_size, qt); + } + } + } + // Flash decoding tests (sequence_length=1 with KV split into chunks) + for (size_t total_seqlen : {size_t{4}, size_t{32}, size_t{64}, size_t{128}}) { + for (size_t head_size : {min_head, size_t{32}, size_t{64}}) { + TestFlashDecoding(total_seqlen, head_size, qt); + } + } + } + } +}; + +static UNUSED_VARIABLE bool added_flash_to_main = AddTestRegister([](bool is_short_execute) { + size_t count = 0; + if (is_short_execute) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + } + return count; +}); diff --git a/onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py b/onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py new file mode 100644 index 0000000000000..77ac08cf50d6c --- /dev/null +++ b/onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py @@ -0,0 +1,300 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +""" +Benchmark CPU GroupQueryAttention: Flash Attention vs Naive (full materialization). + +Runs the actual GQA operator via InferenceSession, toggling between flash and +naive paths using the ORT_GQA_DISABLE_FLASH_ATTENTION environment variable. + +Usage: + python benchmark_gqa_cpu_flash.py + python benchmark_gqa_cpu_flash.py --decode_only + python benchmark_gqa_cpu_flash.py --prompt_only +""" + +import argparse +import os +import time + +import numpy as np +from onnx import TensorProto, helper + +from onnxruntime import InferenceSession, SessionOptions + + +def create_quantized_gqa_graph( + batch_size, + seq_len, + num_heads, + kv_num_heads, + head_size, + quant_type, + bit_width, + buffer_seq_len=None, +): + """Create an ONNX graph for GroupQueryAttention with quantized KV cache.""" + if buffer_seq_len is None: + buffer_seq_len = seq_len + + hidden_size = num_heads * head_size + kv_hidden_size = kv_num_heads * head_size + packed_head_size = head_size // 2 if bit_width == 4 else head_size + cache_ort_type = TensorProto.UINT8 if bit_width == 4 else TensorProto.INT8 + + inputs = [ + "query", + "key", + "value", + "past_key", + "past_value", + "seqlens_k", + "total_sequence_length", + "", + "", + "", + "", + "", # cos, sin, position_ids, attention_bias, head_sink + "k_scale", + "v_scale", + ] + while inputs and inputs[-1] == "": + inputs.pop() + + node = helper.make_node( + op_type="GroupQueryAttention", + inputs=inputs, + outputs=["output", "present_key", "present_value"], + name="GroupQueryAttention_0", + num_heads=num_heads, + kv_num_heads=kv_num_heads, + k_quant_type=quant_type, + v_quant_type=quant_type, + kv_cache_bit_width=bit_width, + domain="com.microsoft", + ) + + graph_input = [ + helper.make_tensor_value_info("query", TensorProto.FLOAT, [batch_size, seq_len, hidden_size]), + helper.make_tensor_value_info("key", TensorProto.FLOAT, [batch_size, seq_len, kv_hidden_size]), + helper.make_tensor_value_info("value", TensorProto.FLOAT, [batch_size, seq_len, kv_hidden_size]), + helper.make_tensor_value_info( + "past_key", cache_ort_type, [batch_size, kv_num_heads, buffer_seq_len, packed_head_size] + ), + helper.make_tensor_value_info( + "past_value", cache_ort_type, [batch_size, kv_num_heads, buffer_seq_len, packed_head_size] + ), + helper.make_tensor_value_info("seqlens_k", TensorProto.INT32, [batch_size]), + helper.make_tensor_value_info("total_sequence_length", TensorProto.INT32, [1]), + helper.make_tensor_value_info("k_scale", TensorProto.FLOAT, None), + helper.make_tensor_value_info("v_scale", TensorProto.FLOAT, None), + ] + + graph_output = [ + helper.make_tensor_value_info("output", TensorProto.FLOAT, [batch_size, seq_len, hidden_size]), + helper.make_tensor_value_info( + "present_key", cache_ort_type, [batch_size, kv_num_heads, buffer_seq_len, packed_head_size] + ), + helper.make_tensor_value_info( + "present_value", cache_ort_type, [batch_size, kv_num_heads, buffer_seq_len, packed_head_size] + ), + ] + + graph = helper.make_graph([node], "BenchGQA", graph_input, graph_output) + model = helper.make_model(graph) + return model.SerializeToString() + + +def benchmark_gqa( + batch_size, + seq_len, + num_heads, + kv_num_heads, + head_size, + quant_type, + bit_width, + past_seq_len=0, + warmup=5, + repeats=20, +): + """Benchmark a single GQA configuration. Returns elapsed time in ms.""" + hidden_size = num_heads * head_size + kv_hidden_size = kv_num_heads * head_size + packed_head_size = head_size // 2 if bit_width == 4 else head_size + + total_seqlen = past_seq_len + seq_len + buffer_seq_len = total_seqlen + + onnx_model_str = create_quantized_gqa_graph( + batch_size, + seq_len, + num_heads, + kv_num_heads, + head_size, + quant_type, + bit_width, + buffer_seq_len=buffer_seq_len, + ) + + sess_options = SessionOptions() + sess_options.intra_op_num_threads = 8 + sess = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) + + # Generate inputs + np.random.seed(42) + query = np.random.uniform(-0.5, 0.5, (batch_size, seq_len, hidden_size)).astype(np.float32) + key = np.random.uniform(-0.5, 0.5, (batch_size, seq_len, kv_hidden_size)).astype(np.float32) + value = np.random.uniform(-0.5, 0.5, (batch_size, seq_len, kv_hidden_size)).astype(np.float32) + + cache_dtype = np.uint8 if bit_width == 4 else np.int8 + past_k = np.random.randint( + 0, 255, (batch_size, kv_num_heads, buffer_seq_len, packed_head_size), dtype=np.uint8 + ).view(cache_dtype) + past_v = np.random.randint( + 0, 255, (batch_size, kv_num_heads, buffer_seq_len, packed_head_size), dtype=np.uint8 + ).view(cache_dtype) + + seqlens_k = np.array([total_seqlen - 1] * batch_size, dtype=np.int32) + total_seq = np.array([total_seqlen], dtype=np.int32) + + per_channel = quant_type == "PER_CHANNEL" + scale_size = kv_num_heads * head_size if per_channel else 1 + k_scale = np.full(scale_size, 0.01, dtype=np.float32) + v_scale = np.full(scale_size, 0.01, dtype=np.float32) + + feeds = { + "query": query, + "key": key, + "value": value, + "past_key": past_k, + "past_value": past_v, + "seqlens_k": seqlens_k, + "total_sequence_length": total_seq, + "k_scale": k_scale, + "v_scale": v_scale, + } + + # Warmup + for _ in range(warmup): + sess.run(None, feeds) + + # Benchmark + start = time.perf_counter() + for _ in range(repeats): + sess.run(None, feeds) + elapsed_ms = (time.perf_counter() - start) / repeats * 1000.0 + + return elapsed_ms + + +def run_benchmarks(args): + """Run flash vs naive benchmarks for various configurations.""" + + configs = [] + + if not args.decode_only: + # Prefill configurations: seq_len = total_seqlen (prompt phase) + for total_seqlen in [512, 1024, 2048, 4096]: + configs.append( + { + "label": f"Prefill S={total_seqlen}", + "batch_size": 1, + "seq_len": total_seqlen, + "num_heads": 16, + "kv_num_heads": 8, + "head_size": 128, + "quant_type": "PER_TENSOR", + "bit_width": 8, + "past_seq_len": 0, + } + ) + + if not args.prompt_only: + # Decode configurations: seq_len=1, varying past + for past_seqlen in [512, 1024, 2048, 4096]: + configs.append( + { + "label": f"Decode T={past_seqlen + 1}", + "batch_size": 1, + "seq_len": 1, + "num_heads": 16, + "kv_num_heads": 8, + "head_size": 128, + "quant_type": "PER_TENSOR", + "bit_width": 8, + "past_seq_len": past_seqlen, + } + ) + + if not args.decode_only and not args.prompt_only: + # Batch decode + configs.append( + { + "label": "Decode B=4 T=2049", + "batch_size": 4, + "seq_len": 1, + "num_heads": 16, + "kv_num_heads": 8, + "head_size": 128, + "quant_type": "PER_TENSOR", + "bit_width": 8, + "past_seq_len": 2048, + } + ) + # INT4 prefill + configs.append( + { + "label": "Prefill S=2048 INT4", + "batch_size": 1, + "seq_len": 2048, + "num_heads": 16, + "kv_num_heads": 8, + "head_size": 128, + "quant_type": "PER_TENSOR", + "bit_width": 4, + "past_seq_len": 0, + } + ) + + warmup = args.warmup + repeats = args.repeats + + # Save and restore env var to avoid side effects on callers + saved_env = os.environ.get("ORT_GQA_DISABLE_FLASH_ATTENTION") + + print("\nBenchmark: CPU GroupQueryAttention — Flash vs Naive") + print(f"Threads: {8}, Warmup: {warmup}, Repeats: {repeats}") + print(f"{'Config':<25} {'Naive (ms)':>12} {'Flash (ms)':>12} {'Speedup':>10}") + print("-" * 62) + + for cfg in configs: + label = cfg.pop("label") + + # Flash path (default) + os.environ.pop("ORT_GQA_DISABLE_FLASH_ATTENTION", None) + flash_ms = benchmark_gqa(**cfg, warmup=warmup, repeats=repeats) + + # Naive path (disabled flash) + os.environ["ORT_GQA_DISABLE_FLASH_ATTENTION"] = "1" + naive_ms = benchmark_gqa(**cfg, warmup=warmup, repeats=repeats) + + speedup = naive_ms / flash_ms if flash_ms > 0 else float("inf") + print(f"{label:<25} {naive_ms:>10.3f}ms {flash_ms:>10.3f}ms {speedup:>8.2f}x") + + # Restore original env state + if saved_env is not None: + os.environ["ORT_GQA_DISABLE_FLASH_ATTENTION"] = saved_env + else: + os.environ.pop("ORT_GQA_DISABLE_FLASH_ATTENTION", None) + print() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Benchmark GQA flash vs naive on CPU") + parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations") + parser.add_argument("--repeats", type=int, default=20, help="Measurement iterations") + parser.add_argument("--decode_only", action="store_true", help="Only run decode benchmarks") + parser.add_argument("--prompt_only", action="store_true", help="Only run prompt benchmarks") + args = parser.parse_args() + run_benchmarks(args) diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu_quantized.py b/onnxruntime/test/python/transformers/test_gqa_cpu_quantized.py index 3224a07451534..4a4d3e6ff43e8 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu_quantized.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu_quantized.py @@ -108,8 +108,10 @@ def dequantize_int4_per_channel(packed_uint8, scale, kv_num_heads, head_size): # ---- Reference attention ---- -def reference_gqa(q_input, k_input, v_input, num_heads, kv_num_heads, head_size, causal=True): - """Reference FP32 GQA: q[B,S,num_heads*H], k[B,N,S_kv,H], v[B,N,S_kv,H] -> out[B,S,num_heads*H].""" +def reference_gqa(q_input, k_input, v_input, num_heads, kv_num_heads, head_size, causal=True, attention_bias=None): + """Reference FP32 GQA: q[B,S,num_heads*H], k[B,N,S_kv,H], v[B,N,S_kv,H] -> out[B,S,num_heads*H]. + attention_bias: [B|1, num_heads|1, S, S_kv] or None. + """ batch, seq, _ = q_input.shape s_kv = k_input.shape[2] groups = num_heads // kv_num_heads @@ -128,6 +130,11 @@ def reference_gqa(q_input, k_input, v_input, num_heads, kv_num_heads, head_size, logits = np.zeros(s_kv, dtype=np.float32) for k_s in range(s_kv): logits[k_s] = np.dot(q_bnsh[b, h, q_s], k_input[b, kv_h, k_s]) * scale + # Attention bias + if attention_bias is not None: + bias_b = 0 if attention_bias.shape[0] == 1 else b + bias_h = 0 if attention_bias.shape[1] == 1 else h + logits[:s_kv] += attention_bias[bias_b, bias_h, q_s, :s_kv] # Causal mask if causal: for k_s in range(q_s + 1, s_kv): @@ -244,6 +251,103 @@ def create_quantized_gqa_graph( return model.SerializeToString() +def create_quantized_gqa_graph_with_bias( + batch_size, + seq_len, + num_heads, + kv_num_heads, + head_size, + quant_type, + bit_width, + bias_batch_size, + bias_num_heads, + total_seqlen, + buffer_seq_len=None, +): + """Create an ONNX graph for GroupQueryAttention with quantized KV cache and attention bias.""" + if buffer_seq_len is None: + buffer_seq_len = seq_len + + hidden_size = num_heads * head_size + kv_hidden_size = kv_num_heads * head_size + packed_head_size = head_size // 2 if bit_width == 4 else head_size + + cache_ort_type = TensorProto.UINT8 if bit_width == 4 else TensorProto.INT8 + + past_kv_seqlen = buffer_seq_len + present_kv_seqlen = buffer_seq_len + + # Inputs (attention_bias at index 10) + inputs = [ + "query", + "key", + "value", + "past_key", + "past_value", + "seqlens_k", + "total_sequence_length", + "", # cos_cache + "", # sin_cache + "", # position_ids + "attention_bias", + "", # head_sink + "k_scale", + "v_scale", + ] + + # Remove trailing empty strings + while inputs and inputs[-1] == "": + inputs.pop() + + node = helper.make_node( + op_type="GroupQueryAttention", + inputs=inputs, + outputs=["output", "present_key", "present_value"], + name="GroupQueryAttention_0", + num_heads=num_heads, + kv_num_heads=kv_num_heads, + k_quant_type=quant_type, + v_quant_type=quant_type, + kv_cache_bit_width=bit_width, + domain="com.microsoft", + ) + + # Graph inputs + graph_input = [ + helper.make_tensor_value_info("query", TensorProto.FLOAT, [batch_size, seq_len, hidden_size]), + helper.make_tensor_value_info("key", TensorProto.FLOAT, [batch_size, seq_len, kv_hidden_size]), + helper.make_tensor_value_info("value", TensorProto.FLOAT, [batch_size, seq_len, kv_hidden_size]), + helper.make_tensor_value_info( + "past_key", cache_ort_type, [batch_size, kv_num_heads, past_kv_seqlen, packed_head_size] + ), + helper.make_tensor_value_info( + "past_value", cache_ort_type, [batch_size, kv_num_heads, past_kv_seqlen, packed_head_size] + ), + helper.make_tensor_value_info("seqlens_k", TensorProto.INT32, [batch_size]), + helper.make_tensor_value_info("total_sequence_length", TensorProto.INT32, [1]), + helper.make_tensor_value_info( + "attention_bias", TensorProto.FLOAT, [bias_batch_size, bias_num_heads, seq_len, total_seqlen] + ), + helper.make_tensor_value_info("k_scale", TensorProto.FLOAT, None), + helper.make_tensor_value_info("v_scale", TensorProto.FLOAT, None), + ] + + # Graph outputs + graph_output = [ + helper.make_tensor_value_info("output", TensorProto.FLOAT, [batch_size, seq_len, hidden_size]), + helper.make_tensor_value_info( + "present_key", cache_ort_type, [batch_size, kv_num_heads, present_kv_seqlen, packed_head_size] + ), + helper.make_tensor_value_info( + "present_value", cache_ort_type, [batch_size, kv_num_heads, present_kv_seqlen, packed_head_size] + ), + ] + + graph = helper.make_graph([node], "QuantizedGQA_Bias_Graph", graph_input, graph_output) + model = helper.make_model(graph) + return model.SerializeToString() + + # ---- Test runner ---- @@ -517,5 +621,253 @@ def test_int4_long_sequence(self): ) +def run_quantized_gqa_bias_test( + batch_size, + seq_len, + num_heads, + kv_num_heads, + head_size, + quant_type, + bit_width, + bias_broadcast_batch=False, + bias_broadcast_head=False, + atol=None, +): + """Run a quantized GQA test with attention bias and compare against reference.""" + np.random.seed(123) + + hidden_size = num_heads * head_size + kv_hidden_size = kv_num_heads * head_size + + query = np.random.uniform(-0.5, 0.5, (batch_size, seq_len, hidden_size)).astype(np.float32) + key_input = np.random.uniform(-0.5, 0.5, (batch_size, seq_len, kv_hidden_size)).astype(np.float32) + value_input = np.random.uniform(-0.5, 0.5, (batch_size, seq_len, kv_hidden_size)).astype(np.float32) + + # Reshape K/V to BNSH for quantization + k_bnsh = key_input.reshape(batch_size, seq_len, kv_num_heads, head_size).transpose(0, 2, 1, 3) + v_bnsh = value_input.reshape(batch_size, seq_len, kv_num_heads, head_size).transpose(0, 2, 1, 3) + + # Compute scales + if bit_width == 8: + if quant_type == "PER_TENSOR": + _, k_scale = quantize_int8_per_tensor(k_bnsh) + _, v_scale = quantize_int8_per_tensor(v_bnsh) + else: + _, k_scale = quantize_int8_per_channel(k_bnsh) + _, v_scale = quantize_int8_per_channel(v_bnsh) + else: + if quant_type == "PER_TENSOR": + _, k_scale = quantize_int4_per_tensor(k_bnsh) + _, v_scale = quantize_int4_per_tensor(v_bnsh) + else: + _, k_scale = quantize_int4_per_channel(k_bnsh) + _, v_scale = quantize_int4_per_channel(v_bnsh) + + # Empty past (prompt) + packed_head_size = head_size // 2 if bit_width == 4 else head_size + if bit_width == 4: + past_k = np.zeros((batch_size, kv_num_heads, seq_len, packed_head_size), dtype=np.uint8) + past_v = np.zeros((batch_size, kv_num_heads, seq_len, packed_head_size), dtype=np.uint8) + else: + past_k = np.zeros((batch_size, kv_num_heads, seq_len, packed_head_size), dtype=np.int8) + past_v = np.zeros((batch_size, kv_num_heads, seq_len, packed_head_size), dtype=np.int8) + + seqlens_k = np.array([seq_len - 1] * batch_size, dtype=np.int32) + total_seq = np.array([seq_len], dtype=np.int32) + + # Generate attention bias + bias_batch = 1 if bias_broadcast_batch else batch_size + bias_heads = 1 if bias_broadcast_head else num_heads + attention_bias = np.random.uniform(-1.0, 1.0, (bias_batch, bias_heads, seq_len, seq_len)).astype(np.float32) + + # Build and run ONNX model + onnx_model_str = create_quantized_gqa_graph_with_bias( + batch_size, + seq_len, + num_heads, + kv_num_heads, + head_size, + quant_type, + bit_width, + bias_batch_size=bias_batch, + bias_num_heads=bias_heads, + total_seqlen=seq_len, + ) + sess_options = SessionOptions() + sess = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) + + feeds = { + "query": query, + "key": key_input, + "value": value_input, + "past_key": past_k, + "past_value": past_v, + "seqlens_k": seqlens_k, + "total_sequence_length": total_seq, + "attention_bias": attention_bias, + "k_scale": k_scale, + "v_scale": v_scale, + } + + outputs = sess.run(None, feeds) + out_ort = outputs[0] + + # Compute reference with quantized K/V + if bit_width == 8 and quant_type == "PER_TENSOR": + k_q = np.clip(np.round(k_bnsh / k_scale[0]), -128, 127).astype(np.int8) + v_q = np.clip(np.round(v_bnsh / v_scale[0]), -128, 127).astype(np.int8) + k_deq = dequantize_int8_per_tensor(k_q, k_scale[0]) + v_deq = dequantize_int8_per_tensor(v_q, v_scale[0]) + elif bit_width == 8 and quant_type == "PER_CHANNEL": + k_q = np.clip(np.round(k_bnsh / k_scale.reshape(1, kv_num_heads, 1, head_size)), -128, 127).astype(np.int8) + v_q = np.clip(np.round(v_bnsh / v_scale.reshape(1, kv_num_heads, 1, head_size)), -128, 127).astype(np.int8) + k_deq = dequantize_int8_per_channel(k_q, k_scale, kv_num_heads, head_size) + v_deq = dequantize_int8_per_channel(v_q, v_scale, kv_num_heads, head_size) + elif bit_width == 4 and quant_type == "PER_TENSOR": + k_q = np.clip(np.round(k_bnsh / k_scale[0]), -8, 7).astype(np.int8) + v_q = np.clip(np.round(v_bnsh / v_scale[0]), -8, 7).astype(np.int8) + k_deq = k_q.astype(np.float32) * k_scale[0] + v_deq = v_q.astype(np.float32) * v_scale[0] + elif bit_width == 4 and quant_type == "PER_CHANNEL": + k_q = np.clip(np.round(k_bnsh / k_scale.reshape(1, kv_num_heads, 1, head_size)), -8, 7).astype(np.int8) + v_q = np.clip(np.round(v_bnsh / v_scale.reshape(1, kv_num_heads, 1, head_size)), -8, 7).astype(np.int8) + k_deq = k_q.astype(np.float32) * k_scale.reshape(1, kv_num_heads, 1, head_size) + v_deq = v_q.astype(np.float32) * v_scale.reshape(1, kv_num_heads, 1, head_size) + else: + raise ValueError(f"Unsupported config: bit_width={bit_width}, quant_type={quant_type}") + + out_ref = reference_gqa( + query, k_deq, v_deq, num_heads, kv_num_heads, head_size, causal=True, attention_bias=attention_bias + ) + + if atol is None: + atol = 0.15 if bit_width == 4 else 0.05 + + if np.any(np.isnan(out_ort)): + raise AssertionError(f"NaN in output (quant={quant_type}, bit={bit_width}, bias test)") + if np.allclose(out_ort, 0.0): + raise AssertionError(f"Output is all zeros (quant={quant_type}, bit={bit_width}, bias test)") + + np.testing.assert_allclose( + out_ort, + out_ref, + atol=atol, + rtol=0.1, + err_msg=f"Quantized GQA + bias mismatch (quant={quant_type}, bit={bit_width})", + ) + + +class TestGQACPUQuantizedKVWithBias(unittest.TestCase): + """Test CPU GroupQueryAttention with quantized KV cache and attention bias.""" + + def test_int8_per_tensor_bias(self): + run_quantized_gqa_bias_test( + batch_size=1, + seq_len=8, + num_heads=2, + kv_num_heads=1, + head_size=16, + quant_type="PER_TENSOR", + bit_width=8, + ) + + def test_int8_per_channel_bias(self): + run_quantized_gqa_bias_test( + batch_size=1, + seq_len=8, + num_heads=2, + kv_num_heads=1, + head_size=16, + quant_type="PER_CHANNEL", + bit_width=8, + ) + + def test_int4_per_tensor_bias(self): + run_quantized_gqa_bias_test( + batch_size=1, + seq_len=8, + num_heads=2, + kv_num_heads=1, + head_size=16, + quant_type="PER_TENSOR", + bit_width=4, + ) + + def test_int4_per_channel_bias(self): + run_quantized_gqa_bias_test( + batch_size=1, + seq_len=8, + num_heads=2, + kv_num_heads=1, + head_size=16, + quant_type="PER_CHANNEL", + bit_width=4, + ) + + def test_int8_bias_broadcast_batch(self): + """Bias shape [1, N, S, T] with batch_size > 1.""" + run_quantized_gqa_bias_test( + batch_size=2, + seq_len=8, + num_heads=4, + kv_num_heads=2, + head_size=16, + quant_type="PER_TENSOR", + bit_width=8, + bias_broadcast_batch=True, + ) + + def test_int8_bias_broadcast_head(self): + """Bias shape [B, 1, S, T] with num_heads > 1.""" + run_quantized_gqa_bias_test( + batch_size=1, + seq_len=8, + num_heads=4, + kv_num_heads=2, + head_size=16, + quant_type="PER_TENSOR", + bit_width=8, + bias_broadcast_head=True, + ) + + def test_int8_bias_broadcast_both(self): + """Bias shape [1, 1, S, T] with batch_size > 1 and num_heads > 1.""" + run_quantized_gqa_bias_test( + batch_size=2, + seq_len=8, + num_heads=4, + kv_num_heads=2, + head_size=16, + quant_type="PER_TENSOR", + bit_width=8, + bias_broadcast_batch=True, + bias_broadcast_head=True, + ) + + def test_int8_bias_large(self): + """Larger test to exercise flash attention path with bias.""" + run_quantized_gqa_bias_test( + batch_size=2, + seq_len=32, + num_heads=4, + kv_num_heads=2, + head_size=64, + quant_type="PER_TENSOR", + bit_width=8, + ) + + def test_int4_bias_large(self): + """Larger test with INT4 to exercise flash attention path with bias.""" + run_quantized_gqa_bias_test( + batch_size=2, + seq_len=32, + num_heads=4, + kv_num_heads=2, + head_size=64, + quant_type="PER_CHANNEL", + bit_width=4, + ) + + if __name__ == "__main__": unittest.main() From abbc2abc646e1f45e8534cec5991abdcabf9e889 Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Fri, 29 May 2026 02:41:48 +0000 Subject: [PATCH 07/10] Add `tools/python/compile_contributors.py` `--paths` option (#28710) ### Description Add a new `--paths` option to `compile_contributors.py` to limit git history queries using pathspecs. Apply the path filter to both base and target git log collection and log the active path filter in logs.txt. ### Motivation and Context Allow `compile_contributors.py` to be used for releases where relevant changes are largely limited to a subset of the codebase. E.g., we can limit the paths to WebGPU EP-related files for the WebGPU plugin EP release. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- tools/python/compile_contributors.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/tools/python/compile_contributors.py b/tools/python/compile_contributors.py index bb02c2807d08c..92ba59747493e 100644 --- a/tools/python/compile_contributors.py +++ b/tools/python/compile_contributors.py @@ -11,10 +11,16 @@ Usage: python compile_contributors.py [--base ] [--target ] [--dir ] + [--paths [ ...]] Example: python compile_contributors.py --base origin/rel-1.23.2 --target origin/rel-1.24.1 --dir rel-1.24.1_report + # Limit to commits that touch selected areas (replace with your paths): + # Using git pathspec syntax, ":(top)" anchors each path at repository root. + python compile_contributors.py --base origin/main~500 --target origin/main \ + --paths ":(top)path/to/component_a" ":(top)path/to/component_b" + Outputs: - detail.csv: Detailed breakdown of PRs, authors, and commit links. - logs.txt: Processing logs and summary (professional humans-only contributor list for release notes). @@ -314,6 +320,19 @@ def main(): parser.add_argument("--target", default="origin/rel-1.24.1", help="Target branch/commit to compare to") parser.add_argument("--dir", default="contributors", help="Output directory for reports and logs") parser.add_argument("--scan-depth", type=int, default=200, help="Depth to scan base/meta-PRs for deduplication") + parser.add_argument( + "--paths", + nargs="+", + default=None, + metavar="PATH", + help=( + "Optional list of paths (git pathspec) to limit history to. " + "Only commits that touch one of these paths are considered. " + "Note: when a 'Cherry-pick round' meta-PR is included because at " + "least one of its cherry-picks touched these paths, all its " + "sub-PRs are still expanded regardless of paths." + ), + ) args = parser.parse_args() # Early validation @@ -324,6 +343,9 @@ def main(): branch_target = args.target output_dir = args.dir scan_depth = args.scan_depth + # Build a pathspec suffix (e.g. ["--", "onnxruntime/core/providers/webgpu", ...]) once, + # so it can be appended to each `git log` invocation below. + paths_args = (["--", *args.paths]) if args.paths else [] if not os.path.exists(output_dir): os.makedirs(output_dir) @@ -331,10 +353,12 @@ def main(): logs_path = os.path.join(output_dir, "logs.txt") with open(logs_path, "w", encoding="utf-8") as log_file: log_event(f"Starting comparison: {branch_base} -> {branch_target}", log_file) + if args.paths: + log_event(f"Limiting history to paths: {args.paths}", log_file) # 1. Fetch base branch PRs (scan depth controlled by scan_depth) log_event(f"Fetching base branch history for {branch_base} (last {scan_depth})...", log_file) - log_base = run_command(["git", "log", branch_base, "-n", str(scan_depth), "--oneline"]) + log_base = run_command(["git", "log", branch_base, "-n", str(scan_depth), "--oneline", *paths_args]) if log_base is None: log_event( f"Error: Could not fetch history for base ref '{branch_base}'. Please check if the ref exists.", @@ -348,7 +372,7 @@ def main(): # 2. Fetch target branch PRs (only those not in base) log_event(f"Fetching target branch history: {branch_base}..{branch_target}...", log_file) # Using A..B syntax for git log - log_target = run_command(["git", "log", f"{branch_base}..{branch_target}", "--oneline"]) + log_target = run_command(["git", "log", f"{branch_base}..{branch_target}", "--oneline", *paths_args]) if log_target is None: log_event( f"Error: Could not fetch history for range '{branch_base}..{branch_target}'. Please check if the refs exist.", From 97e7b2a7c371a14418b349027fbe2756230fb09b Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Fri, 29 May 2026 12:34:52 -0400 Subject: [PATCH 08/10] Update CUDA 12.8 to 13.0 in CI workflows (#28458) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description Update CUDA version from 12.8 to 13.0 across all CUDA CI workflow files: - `linux_cuda_ci.yml` — `--cuda_version`, `--cuda_home`, `--cudnn_home` in build and test jobs - `linux_cuda_plugin_ci.yml` — same flags - `windows_cuda.yml` — SDK download URL, PATH entries, `--cuda_home` in build and test jobs - `windows_cuda_plugin.yml` — same as above ### Motivation and Context PRs that break CUDA 13 builds pass CI today because all pipelines target CUDA 12.8. This moves CI to CUDA 13.0 for build-time coverage. --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: tianleiwu <30328909+tianleiwu@users.noreply.github.com> Co-authored-by: Tianlei Wu --- .github/workflows/linux_cuda_ci.yml | 21 +++++++++++----- .github/workflows/linux_cuda_plugin_ci.yml | 28 ++++++++++++++++------ .github/workflows/windows_cuda.yml | 8 +++++++ .github/workflows/windows_cuda_plugin.yml | 8 +++++++ cmake/onnxruntime_python.cmake | 1 + 5 files changed, 53 insertions(+), 13 deletions(-) diff --git a/.github/workflows/linux_cuda_ci.yml b/.github/workflows/linux_cuda_ci.yml index f50c0064dd956..89dcf718e8bf5 100644 --- a/.github/workflows/linux_cuda_ci.yml +++ b/.github/workflows/linux_cuda_ci.yml @@ -27,9 +27,9 @@ jobs: build_config: Release architecture: x64 dockerfile_path: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda - docker_build_args: '--build-arg BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc14:20251017.1' - docker_image_repo: onnxruntimecuda12manylinuxbuild - extra_build_flags: '--use_binskim_compliant_compile_flags --build_wheel --parallel --nvcc_threads 4 --flash_nvcc_threads 4 --cuda_version=12.8 --cuda_home=/usr/local/cuda-12.8 --cudnn_home=/usr/local/cuda-12.8 --enable_cuda_profiling --build_java --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 onnxruntime_BUILD_UNIT_TESTS=ON onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON' + docker_build_args: '--build-arg BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda13_x64_almalinux8_gcc14:20251107.1' + docker_image_repo: onnxruntimecuda13manylinuxbuild + extra_build_flags: '--use_binskim_compliant_compile_flags --build_wheel --parallel --nvcc_threads 4 --flash_nvcc_threads 4 --cuda_version=13.0 --cuda_home=/usr/local/cuda-13.0 --cudnn_home=/usr/local/cuda-13.0 --enable_cuda_profiling --build_java --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 onnxruntime_BUILD_UNIT_TESTS=ON onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON' python_path_prefix: 'PATH=/opt/python/cp310-cp310/bin:$PATH' run_tests: false # <<< Do not run tests in this job upload_build_output: true # <<< Upload the build/Release directory @@ -57,8 +57,8 @@ jobs: id: build_docker_image_step with: dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda - image-name: ghcr.io/microsoft/onnxruntime/onnxruntimecuda12manylinuxbuild - build-args: '--build-arg BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc14:20251017.1' + image-name: ghcr.io/microsoft/onnxruntime/onnxruntimecuda13manylinuxbuild + build-args: '--build-arg BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda13_x64_almalinux8_gcc14:20251107.1' push: true azure-container-registry-name: onnxruntimebuildcache env: @@ -91,6 +91,15 @@ jobs: echo "Warning: perms.txt not found in artifact." fi + # Verify the GPU is accessible inside Docker before running the full test suite. + # If the NVIDIA Container Toolkit fails to expose /dev/nvidia* devices, + # tests will fail with "CUDA failure 100" and waste 10+ minutes. + - name: Verify GPU access in Docker + run: | + docker run --rm --gpus all \ + "${{ steps.build_docker_image_step.outputs.full-image-name }}" \ + nvidia-smi + # --- Run Tests using the downloaded build --- # The run-build-script-in-docker action mounts ${{ runner.temp }} to /onnxruntime_src/build # So build.py --build_dir build/Release inside the container correctly finds the artifacts. @@ -102,5 +111,5 @@ jobs: build_config: Release mode: 'test' # Set mode to test execution_providers: 'cuda' - extra_build_flags: '--use_binskim_compliant_compile_flags --cuda_version=12.8 --cuda_home=/usr/local/cuda-12.8 --cudnn_home=/usr/local/cuda-12.8 --enable_cuda_profiling --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 onnxruntime_BUILD_UNIT_TESTS=ON onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON' + extra_build_flags: '--use_binskim_compliant_compile_flags --cuda_version=13.0 --cuda_home=/usr/local/cuda-13.0 --cudnn_home=/usr/local/cuda-13.0 --enable_cuda_profiling --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 onnxruntime_BUILD_UNIT_TESTS=ON onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON' python_path_prefix: 'PATH=/opt/python/cp310-cp310/bin:$PATH' diff --git a/.github/workflows/linux_cuda_plugin_ci.yml b/.github/workflows/linux_cuda_plugin_ci.yml index d2491f59812ab..a9197b3732dd8 100644 --- a/.github/workflows/linux_cuda_plugin_ci.yml +++ b/.github/workflows/linux_cuda_plugin_ci.yml @@ -26,17 +26,17 @@ jobs: build_config: Release architecture: x64 dockerfile_path: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda - docker_build_args: '--build-arg BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc14:20251017.1' - docker_image_repo: onnxruntimecuda12manylinuxbuild + docker_build_args: '--build-arg BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda13_x64_almalinux8_gcc14:20251107.1' + docker_image_repo: onnxruntimecuda13manylinuxbuild extra_build_flags: >- --use_binskim_compliant_compile_flags --build_wheel --parallel --nvcc_threads 4 --flash_nvcc_threads 4 - --cuda_version=12.8 - --cuda_home=/usr/local/cuda-12.8 - --cudnn_home=/usr/local/cuda-12.8 + --cuda_version=13.0 + --cuda_home=/usr/local/cuda-13.0 + --cudnn_home=/usr/local/cuda-13.0 --enable_cuda_profiling --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 --cmake_extra_defines onnxruntime_QUICK_BUILD=ON @@ -67,8 +67,8 @@ jobs: id: build_docker_image_step with: dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda - image-name: ghcr.io/microsoft/onnxruntime/onnxruntimecuda12manylinuxbuild - build-args: '--build-arg BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc14:20251017.1' + image-name: ghcr.io/microsoft/onnxruntime/onnxruntimecuda13manylinuxbuild + build-args: '--build-arg BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda13_x64_almalinux8_gcc14:20251107.1' push: true azure-container-registry-name: onnxruntimebuildcache env: @@ -100,6 +100,15 @@ jobs: echo "Warning: perms.txt not found in artifact." fi + # Verify the GPU is accessible inside Docker before running the full test suite. + # If the NVIDIA Container Toolkit fails to expose /dev/nvidia* devices, + # tests will fail with "CUDA failure 100" and waste 10+ minutes. + - name: Verify GPU access in Docker + run: | + docker run --rm --gpus all \ + "${{ steps.build_docker_image_step.outputs.full-image-name }}" \ + nvidia-smi + # --- Install the ORT wheel and run CUDA plugin EP tests --- - name: Run CUDA Plugin EP Python Tests run: | @@ -111,6 +120,11 @@ jobs: bash -c " set -ex export PATH=/opt/python/cp312-cp312/bin:\$PATH + # Ensure libcudart.so.13 is findable regardless of host-runner NVIDIA Container Toolkit configuration. + # The CUDA runtime library lives in the container image at /usr/local/cuda-13.0/lib64, but the + # LD_LIBRARY_PATH may not include this path when the runner's NVIDIA toolkit only mounts driver + # libraries at /usr/local/nvidia/lib64. + export LD_LIBRARY_PATH=/usr/local/cuda-13.0/lib64:\${LD_LIBRARY_PATH:-} # Install the ORT wheel python -m pip install /build/Release/Release/dist/onnxruntime*.whl diff --git a/.github/workflows/windows_cuda.yml b/.github/workflows/windows_cuda.yml index 53c7031c3c095..dcc314084e4e2 100644 --- a/.github/workflows/windows_cuda.yml +++ b/.github/workflows/windows_cuda.yml @@ -157,6 +157,7 @@ jobs: runs-on: [ "self-hosted", "1ES.Pool=onnxruntime-github-Win2022-GPU-A10", + "1ES.ImageOverride=onnxruntime-Win-CPU-VS2022-Latest-NVMe-x64-test", "JobId=windows-cuda-test-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" ] steps: @@ -222,6 +223,13 @@ jobs: with: whl-directory: ${{ runner.temp }}\build\RelWithDebInfo\RelWithDebInfo\dist + # Verify the GPU is accessible before running the full test suite. + # If the NVIDIA driver is not available, tests will fail with + # "CUDA failure 100" and waste significant time. + - name: Verify GPU access + shell: pwsh + run: nvidia-smi + - name: Run Tests working-directory: ${{ runner.temp }} run: | diff --git a/.github/workflows/windows_cuda_plugin.yml b/.github/workflows/windows_cuda_plugin.yml index f9acdbd76a12d..6b6b7f7158df3 100644 --- a/.github/workflows/windows_cuda_plugin.yml +++ b/.github/workflows/windows_cuda_plugin.yml @@ -127,6 +127,7 @@ jobs: runs-on: [ "self-hosted", "1ES.Pool=onnxruntime-github-Win2022-GPU-A10", + "1ES.ImageOverride=onnxruntime-Win-CPU-VS2022-Latest-NVMe-x64-test", "JobId=windows-cuda-plugin-test-${{ github.run_id }}-${{ github.run_number }}-${{ github.run_attempt }}" ] steps: @@ -187,6 +188,13 @@ jobs: with: whl-directory: ${{ runner.temp }}\build\Release\Release\dist + # Verify the GPU is accessible before running the full test suite. + # If the NVIDIA driver is not available, tests will fail with + # "CUDA failure 100" and waste significant time. + - name: Verify GPU access + shell: pwsh + run: nvidia-smi + - name: Run CUDA Plugin EP Python Tests working-directory: ${{ github.workspace }}\onnxruntime\test\python\transformers shell: pwsh diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index cbd4a38ae18f0..de1d7559a1572 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -242,6 +242,7 @@ if (onnxruntime_USE_CUDA AND NOT WIN32) ) include(cutlass) target_include_directories(onnxruntime_pybind11_state PRIVATE ${cutlass_SOURCE_DIR}/include) + target_link_libraries(onnxruntime_pybind11_state PRIVATE CUDA::cudart) endif() if (onnxruntime_USE_CUDA AND WIN32) target_compile_definitions(onnxruntime_pybind11_state PRIVATE ORT_NO_CUDA_IN_PYBIND) From 6a517f55ea4b0cccb1b5e370d01f8b03a5b52da9 Mon Sep 17 00:00:00 2001 From: kpkbandi Date: Fri, 29 May 2026 12:40:59 -0700 Subject: [PATCH 09/10] Add explicit CUDA version suffix to GPU release artifacts (#28691) ### Description For ORT 1.27, GPU release artifacts (zip/tgz) now include an explicit CUDA major version suffix to distinguish between CUDA 12 and CUDA 13 builds. **Before:** `onnxruntime-linux-x64-gpu-1.27.0.tgz`, `onnxruntime-win-x64-gpu-1.27.0.zip` **After:** `onnxruntime-linux-x64-gpu_cuda12-1.27.0.tgz`, `onnxruntime-win-x64-gpu_cuda13-1.27.0.tgz`, etc. ### Motivation and Context --------- Co-authored-by: Kusuma Padma Kavya Bandi --- .../nuget-linux-cuda-packaging-stage.yml | 14 ++++---- .../stages/nuget-win-cuda-packaging-stage.yml | 12 ++++--- .../linux/extract_and_bundle_gpu_package.sh | 36 +++++++++++-------- .../github/windows/extract_zip_files_gpu.ps1 | 15 +++++--- tools/nuget/validate_package.py | 10 ++---- 5 files changed, 50 insertions(+), 37 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml index 2d7895510afeb..bd5290e8f792c 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml @@ -183,23 +183,23 @@ stages: displayName: 'Shell Script' inputs: scriptPath: 'onnxruntime/tools/ci_build/github/linux/extract_and_bundle_gpu_package.sh' - args: '-a $(Build.BinariesDirectory)/tgz-artifacts' + args: '-a $(Build.BinariesDirectory)/tgz-artifacts -c $(CUDA_VERSION_MAJOR)' workingDirectory: '$(Build.BinariesDirectory)/tgz-artifacts' - task: ArchiveFiles@2 inputs: - rootFolderOrFile: '$(Build.BinariesDirectory)/tgz-artifacts/onnxruntime-linux-x64-gpu' + rootFolderOrFile: '$(Build.BinariesDirectory)/tgz-artifacts/onnxruntime-linux-x64-gpu_cuda$(CUDA_VERSION_MAJOR)' includeRootFolder: false archiveType: 'tar' # Options: zip, 7z, tar, wim tarCompression: 'gz' - archiveFile: '$(Build.ArtifactStagingDirectory)/onnxruntime-linux-x64-gpu-$(OnnxRuntimeVersion).tgz' + archiveFile: '$(Build.ArtifactStagingDirectory)/onnxruntime-linux-x64-gpu_cuda$(CUDA_VERSION_MAJOR)-$(OnnxRuntimeVersion).tgz' replaceExistingArchive: true - template: ../templates/validate-package.yml parameters: PackageType: 'tarball' PackagePath: '$(Build.ArtifactStagingDirectory)' - PackageName: 'onnxruntime-linux-x64-gpu-$(OnnxRuntimeVersion).tgz' + PackageName: 'onnxruntime-linux-x64-gpu_cuda$(CUDA_VERSION_MAJOR)-$(OnnxRuntimeVersion).tgz' ScriptPath: '$(Build.SourcesDirectory)/onnxruntime/tools/nuget/validate_package.py' PlatformsSupported: 'linux-x64' VerifyNugetSigning: false @@ -214,10 +214,12 @@ stages: script: | docker run -e SYSTEM_COLLECTIONURI --gpus all -e CFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" -e CXXFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" -e NVIDIA_VISIBLE_DEVICES=all --rm --volume /data/models:/data/models --volume $(Build.SourcesDirectory):/src_dir \ --volume $(Build.ArtifactStagingDirectory):/artifact_src -e NIGHTLY_BUILD onnxruntimecuda${{ variables.CUDA_VERSION_MAJOR }}xtrt86build \ - /src_dir/onnxruntime-inference-examples/c_cxx/squeezenet/run_capi_application.sh -o /src_dir/onnxruntime -p /artifact_src/onnxruntime-linux-x64-gpu-$(OnnxRuntimeVersion).tgz -w /src_dir/onnxruntime-inference-examples/c_cxx/squeezenet + /src_dir/onnxruntime-inference-examples/c_cxx/squeezenet/run_capi_application.sh -o /src_dir/onnxruntime -p /artifact_src/onnxruntime-linux-x64-gpu_cuda${CUDA_VERSION_MAJOR}-$(OnnxRuntimeVersion).tgz -w /src_dir/onnxruntime-inference-examples/c_cxx/squeezenet workingDirectory: '$(Build.ArtifactStagingDirectory)' + env: + CUDA_VERSION_MAJOR: $(CUDA_VERSION_MAJOR) - task: 1ES.PublishPipelineArtifact@1 inputs: targetPath: '$(Build.ArtifactStagingDirectory)' - artifactName: 'onnxruntime-linux-x64-gpu' + artifactName: 'onnxruntime-linux-x64-gpu_cuda$(CUDA_VERSION_MAJOR)' diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml index b072e22818eec..0e73dff34aa6a 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml @@ -122,6 +122,10 @@ stages: variables: CUDA_MODULE_LOADINGL: 'LAZY' GRADLE_OPTS: '-Dorg.gradle.daemon=false' + ${{ if eq(parameters.CudaVersion, '13.0') }}: + CUDA_VERSION_MAJOR: '13' + ${{ if eq(parameters.CudaVersion, '12.8') }}: + CUDA_VERSION_MAJOR: '12' steps: - checkout: self # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/onnxruntime - checkout: onnxruntime-inference-examples # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/onnxruntime-inference-examples @@ -181,14 +185,14 @@ stages: displayName: 'Copy zip file to: $(Build.ArtifactStagingDirectory)' inputs: SourceFolder: '$(Build.BinariesDirectory)\zip-artifacts' - Contents: 'onnxruntime-win-x64-gpu-*.zip' + Contents: 'onnxruntime-win-x64-gpu_cuda*-*.zip' TargetFolder: '$(Build.ArtifactStagingDirectory)' - template: ../templates/validate-package.yml parameters: PackageType: 'zip' PackagePath: '$(Build.ArtifactStagingDirectory)' - PackageName: 'onnxruntime-win-x64-gpu-$(OnnxRuntimeVersion).zip' + PackageName: 'onnxruntime-win-x64-gpu_cuda$(CUDA_VERSION_MAJOR)-$(OnnxRuntimeVersion).zip' ScriptPath: '$(Build.SourcesDirectory)\onnxruntime\tools\nuget\validate_package.py' PlatformsSupported: 'win-x64' VerifyNugetSigning: false @@ -200,11 +204,11 @@ stages: condition: and(succeeded(), ne(${{parameters.CudaVersion}}, '13.0')) inputs: filename: $(Build.SourcesDirectory)\onnxruntime-inference-examples\c_cxx\squeezenet\run_capi_application.bat - arguments: $(Build.SourcesDirectory)\onnxruntime $(Build.ArtifactStagingDirectory)\onnxruntime-win-x64-gpu-$(OnnxRuntimeVersion).zip $(Build.SourcesDirectory)\onnxruntime-inference-examples\c_cxx\squeezenet + arguments: $(Build.SourcesDirectory)\onnxruntime $(Build.ArtifactStagingDirectory)\onnxruntime-win-x64-gpu_cuda$(CUDA_VERSION_MAJOR)-$(OnnxRuntimeVersion).zip $(Build.SourcesDirectory)\onnxruntime-inference-examples\c_cxx\squeezenet workingFolder: '$(Build.ArtifactStagingDirectory)' - task: 1ES.PublishPipelineArtifact@1 displayName: 'Publish Pipeline Combined GPU Package Artifact' inputs: - artifactName: 'onnxruntime-win-x64-gpu' + artifactName: 'onnxruntime-win-x64-gpu_cuda$(CUDA_VERSION_MAJOR)' targetPath: '$(Build.ArtifactStagingDirectory)' diff --git a/tools/ci_build/github/linux/extract_and_bundle_gpu_package.sh b/tools/ci_build/github/linux/extract_and_bundle_gpu_package.sh index 04ac0e35a6d78..998e71c20539c 100755 --- a/tools/ci_build/github/linux/extract_and_bundle_gpu_package.sh +++ b/tools/ci_build/github/linux/extract_and_bundle_gpu_package.sh @@ -1,30 +1,36 @@ #!/bin/bash set -e -o -x -while getopts a: parameter_Option +while getopts a:c: parameter_Option do case "${parameter_Option}" in a) ARTIFACT_DIR=${OPTARG};; +c) CUDA_MAJOR=${OPTARG};; +*) echo "Unknown option"; exit 1;; esac done -EXIT_CODE=1 +if [ -z "$CUDA_MAJOR" ]; then + echo "Error: CUDA major version (-c) is required" + exit 1 +fi uname -a -cd $ARTIFACT_DIR +cd "$ARTIFACT_DIR" -mkdir -p $ARTIFACT_DIR/onnxruntime-linux-x64-tensorrt -tar zxvf $ARTIFACT_DIR/onnxruntime-linux-x64-tensorrt-*.tgz -C onnxruntime-linux-x64-tensorrt -rm $ARTIFACT_DIR/onnxruntime-linux-x64-tensorrt-*.tgz +mkdir -p "$ARTIFACT_DIR"/onnxruntime-linux-x64-tensorrt +tar zxvf "$ARTIFACT_DIR"/onnxruntime-linux-x64-tensorrt-*.tgz -C onnxruntime-linux-x64-tensorrt +rm "$ARTIFACT_DIR"/onnxruntime-linux-x64-tensorrt-*.tgz -# Rename cuda directory to gpu directory -mkdir -p $ARTIFACT_DIR/onnxruntime-linux-x64-gpu -tar zxvf $ARTIFACT_DIR/onnxruntime-linux-x64-cuda-*.tgz -C onnxruntime-linux-x64-gpu -VERSION=`ls $ARTIFACT_DIR/onnxruntime-linux-x64-gpu | sed 's/onnxruntime-linux-x64-cuda-//'` -mv $ARTIFACT_DIR/onnxruntime-linux-x64-gpu/* $ARTIFACT_DIR/onnxruntime-linux-x64-gpu/onnxruntime-linux-x64-gpu-$VERSION -rm $ARTIFACT_DIR/onnxruntime-linux-x64-cuda-*.tgz +# Rename cuda directory to gpu_cuda{MAJOR} directory +GPU_DIR_NAME="onnxruntime-linux-x64-gpu_cuda${CUDA_MAJOR}" +mkdir -p "$ARTIFACT_DIR"/"$GPU_DIR_NAME" +tar zxvf "$ARTIFACT_DIR"/onnxruntime-linux-x64-cuda-*.tgz -C "$GPU_DIR_NAME" +VERSION=$(find "$ARTIFACT_DIR"/"$GPU_DIR_NAME" -maxdepth 1 -mindepth 1 -printf '%f\n' | sed 's/onnxruntime-linux-x64-cuda-//') +mv "$ARTIFACT_DIR"/"$GPU_DIR_NAME"/* "$ARTIFACT_DIR"/"$GPU_DIR_NAME"/"${GPU_DIR_NAME}-${VERSION}" +rm "$ARTIFACT_DIR"/onnxruntime-linux-x64-cuda-*.tgz -cp onnxruntime-linux-x64-tensorrt/*/lib/libonnxruntime.so* onnxruntime-linux-x64-gpu/*/lib -cp onnxruntime-linux-x64-tensorrt/*/lib/libonnxruntime_providers_tensorrt.so onnxruntime-linux-x64-gpu/*/lib -cp onnxruntime-linux-x64-tensorrt/*/lib/libonnxruntime_providers_shared.so onnxruntime-linux-x64-gpu/*/lib +cp onnxruntime-linux-x64-tensorrt/*/lib/libonnxruntime.so* "$GPU_DIR_NAME"/*/lib +cp onnxruntime-linux-x64-tensorrt/*/lib/libonnxruntime_providers_tensorrt.so "$GPU_DIR_NAME"/*/lib +cp onnxruntime-linux-x64-tensorrt/*/lib/libonnxruntime_providers_shared.so "$GPU_DIR_NAME"/*/lib diff --git a/tools/ci_build/github/windows/extract_zip_files_gpu.ps1 b/tools/ci_build/github/windows/extract_zip_files_gpu.ps1 index 6671fecfbe072..0e082bbdde531 100644 --- a/tools/ci_build/github/windows/extract_zip_files_gpu.ps1 +++ b/tools/ci_build/github/windows/extract_zip_files_gpu.ps1 @@ -1,8 +1,14 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +$CudaMajor = $Env:CUDA_VERSION_MAJOR +if (-not $CudaMajor) { + Write-Error "CUDA_VERSION_MAJOR environment variable is required" + exit 1 +} + # extract *-cuda-*.zip and *-tensorrt-*.zip -Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\zip-artifacts -Filter *.zip | +Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\zip-artifacts -Filter *.zip | Foreach-Object { $cmd = "7z.exe x $($_.FullName) -y -o$Env:BUILD_BINARIESDIRECTORY\zip-artifacts" Write-Output $cmd @@ -13,13 +19,14 @@ Foreach-Object { Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\zip-artifacts | Where-Object { $_.Name -match 'onnxruntime-win-x64-tensorrt-\d{1,}\.\d{1,}\.\d{1,}$' } | Rename-Item -NewName $Env:BUILD_BINARIESDIRECTORY\zip-artifacts\onnxruntime-win-x64-tensorrt Remove-Item $Env:BUILD_BINARIESDIRECTORY\zip-artifacts\*.zip -# Rename cuda directory to gpu directory and re-compress it for later use in bundle_dlls_gpu.bat +# Rename cuda directory to gpu_cuda{MAJOR} directory and re-compress it for later use in bundle_dlls_gpu.bat Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\zip-artifacts -Filter *cuda* | Foreach-Object { $($_.FullName) -match '.*onnxruntime-win-x64-cuda-(.*)' $version=$matches[1] - Rename-Item -Path $($_.FullName) -NewName onnxruntime-win-x64-gpu-$version - $cmd = "7z.exe a $Env:BUILD_BINARIESDIRECTORY\zip-artifacts\onnxruntime-win-x64-gpu-$version.zip $Env:BUILD_BINARIESDIRECTORY\zip-artifacts\onnxruntime-win-x64-gpu-$version" + $gpuName = "onnxruntime-win-x64-gpu_cuda${CudaMajor}-$version" + Rename-Item -Path $($_.FullName) -NewName $gpuName + $cmd = "7z.exe a $Env:BUILD_BINARIESDIRECTORY\zip-artifacts\${gpuName}.zip $Env:BUILD_BINARIESDIRECTORY\zip-artifacts\${gpuName}" Write-Output $cmd Invoke-Expression -Command $cmd } diff --git a/tools/nuget/validate_package.py b/tools/nuget/validate_package.py index 59e88ea15e7c6..44951c9c3194f 100644 --- a/tools/nuget/validate_package.py +++ b/tools/nuget/validate_package.py @@ -232,10 +232,7 @@ def validate_tarball(args): raise Exception("No packages / more than one packages found in the given path.") package_name = args.package_name - if "-gpu-" in package_name.lower(): - is_gpu_package = True - else: - is_gpu_package = False + is_gpu_package = "-gpu_cuda" in package_name.lower() package_folder = re.search("(.*)[.].*", package_name).group(1) @@ -266,10 +263,7 @@ def validate_zip(args): raise Exception("No packages / more than one packages found in the given path.") package_name = args.package_name - if "-gpu-" in package_name.lower(): - is_gpu_package = True - else: - is_gpu_package = False + is_gpu_package = "-gpu_cuda" in package_name.lower() package_folder = re.search("(.*)[.].*", package_name).group(1) From d165fba0abb86fd546ff91c79e8aa4943e6df249 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 29 May 2026 12:48:54 -0700 Subject: [PATCH 10/10] CUDA Plugin EP: NHWC Cleanup & Hardening (#28612) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Unifies the NHWC-eligible op allowlist between the bundled CUDA EP and the CUDA plugin EP into a single shared header, adds kernel-miss diagnostics, and expands NHWC test coverage from 4 ops to 11. ## Motivation The bundled EP (`cuda_execution_provider.cc`) and the plugin EP (`plugin/cuda_ep.cc`) independently maintained their own copies of the NHWC allowlist. This created a maintenance hazard where ops could be added to one but not the other, leading to silent divergence. Additionally, there was no runtime diagnostic when the framework rewrote a node to the NHWC domain but the plugin EP lacked a matching kernel — failures were silent fallbacks to CPU. ## Key Changes ### Shared NHWC Allowlist (`cuda_nhwc_ops.h`) | Item | Detail | |------|--------| | New file | `onnxruntime/core/providers/cuda/cuda_nhwc_ops.h` | | Contents | `IsNhwcEligibleOnnxOp()`, `IsNhwcEligibleMsOp()`, `IsNhwcEligible()` inline functions | | Ops covered | AveragePool, BatchNormalization, Conv, ConvTranspose, DepthToSpace, GlobalAveragePool, GlobalMaxPool, GridSample, LRN, MaxPool, SpaceToDepth (+ MS-domain GridSample) | ### Bundled EP Refactor (`cuda_execution_provider.cc`) - Removed the static `std::unordered_set cuda_nhwc_onnx_ops` and the inline domain check logic. - Replaced with a single call to `cuda::IsNhwcEligible(node_domain, node_op_type)`. ### Plugin EP Refactor & Diagnostics (`plugin/cuda_ep.cc`) - `ShouldConvertDataLayoutForOpImpl`: Replaced ~20 lines of static set + domain checks with a single `cuda::IsNhwcEligible()` call. - `GetCapabilityImpl`: Added a WARNING-level diagnostic in the `else` branch (kernel not found). When a node in the `com.ms.internal.nhwc` domain has no registered kernel, the log emits the op type, domain, version, and node name — making future NHWC registration gaps immediately visible at session creation. ### Expanded NHWC Test Coverage (`test_cuda_plugin_ep.py`) - Added `_assert_nhwc_domain_assigned()` helper that verifies NHWC layout transformation occurred by checking for framework-inserted Transpose nodes in the EP's assignment info. - Added `_run_nhwc_model_test()` helper combining domain assertion + numerical validation. - Updated 4 existing NHWC tests (Conv, BatchNormalization, MaxPool, AveragePool) to include structural assertions. - Added 7 new NHWC test methods: - `test_nhwc_conv_transpose` - `test_nhwc_global_max_pool` - `test_nhwc_global_average_pool` - `test_nhwc_depth_to_space` - `test_nhwc_space_to_depth` - `test_nhwc_lrn` - `test_nhwc_grid_sample` ## Testing Notes Run the full CUDA plugin EP test suite with NHWC enabled: ```bash bash .env/cuda13_plugin.sh --build --install --test_plugin ``` Or run only the NHWC tests directly: ```bash cd onnxruntime/test/python/transformers ORT_TEST_CUDA_PLUGIN_EP=1 python -m unittest \ test_cuda_plugin_ep.TestCudaPluginEP.test_nhwc_conv \ test_cuda_plugin_ep.TestCudaPluginEP.test_nhwc_batch_normalization \ test_cuda_plugin_ep.TestCudaPluginEP.test_nhwc_maxpool \ test_cuda_plugin_ep.TestCudaPluginEP.test_nhwc_avgpool \ test_cuda_plugin_ep.TestCudaPluginEP.test_nhwc_conv_transpose \ test_cuda_plugin_ep.TestCudaPluginEP.test_nhwc_global_max_pool \ test_cuda_plugin_ep.TestCudaPluginEP.test_nhwc_global_average_pool \ test_cuda_plugin_ep.TestCudaPluginEP.test_nhwc_depth_to_space \ test_cuda_plugin_ep.TestCudaPluginEP.test_nhwc_space_to_depth \ test_cuda_plugin_ep.TestCudaPluginEP.test_nhwc_lrn \ test_cuda_plugin_ep.TestCudaPluginEP.test_nhwc_grid_sample ``` All 86 tests in the suite pass (11 NHWC + 75 existing), with no regressions. --- .../providers/cuda/cuda_execution_provider.cc | 19 +- .../core/providers/cuda/cuda_nhwc_ops.h | 48 +++ .../core/providers/cuda/plugin/cuda_ep.cc | 46 +-- .../transformers/test_cuda_plugin_ep.py | 291 +++++++++++++++++- 4 files changed, 355 insertions(+), 49 deletions(-) create mode 100644 onnxruntime/core/providers/cuda/cuda_nhwc_ops.h diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index d9b5760848678..6fd53220a0180 100755 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -12,6 +12,7 @@ #include "core/platform/env_var_utils.h" #include "core/providers/cuda/cuda_execution_provider.h" #include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cuda_nhwc_ops.h" #include "core/providers/cuda/cuda_allocator.h" #include "core/providers/cuda/cuda_fwd.h" #include "core/providers/cuda/gpu_data_transfer.h" @@ -383,23 +384,7 @@ std::optional CUDAExecutionProvider::ShouldConvertDataLayoutForOp([[maybe_ return std::nullopt; } - // TODO(mtavenrath) generate list from registered kernels using nhwc domain - static const std::unordered_set cuda_nhwc_onnx_ops{ - "BatchNormalization", - "Conv", - "ConvTranspose", - "GlobalMaxPool", - "MaxPool", - "GlobalAveragePool", - "AveragePool", - "GridSample", - "DepthToSpace", - "SpaceToDepth", - "LRN", - }; - - return (node_domain == kOnnxDomain && cuda_nhwc_onnx_ops.find(node_op_type) != cuda_nhwc_onnx_ops.end()) || - (node_domain == kMSDomain && node_op_type == "GridSample"); + return cuda::IsNhwcEligible(node_domain, node_op_type); #else // defined(ENABLE_CUDA_NHWC_OPS) ORT_UNUSED_PARAMETER(node_domain); diff --git a/onnxruntime/core/providers/cuda/cuda_nhwc_ops.h b/onnxruntime/core/providers/cuda/cuda_nhwc_ops.h new file mode 100644 index 0000000000000..e4fe232e2362e --- /dev/null +++ b/onnxruntime/core/providers/cuda/cuda_nhwc_ops.h @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +namespace onnxruntime { +namespace cuda { + +// Unified allowlist of ops eligible for NHWC layout conversion in both the +// bundled CUDA EP and the CUDA plugin EP. Maintaining a single source of truth +// prevents silent divergence between the two implementations. + +inline bool IsNhwcEligibleOnnxOp(std::string_view op_type) { + // Alphabetical order for easy maintenance. + return op_type == "AveragePool" || + op_type == "BatchNormalization" || + op_type == "Conv" || + op_type == "ConvTranspose" || + op_type == "DepthToSpace" || + op_type == "GlobalAveragePool" || + op_type == "GlobalMaxPool" || + op_type == "GridSample" || + op_type == "LRN" || + op_type == "MaxPool" || + op_type == "SpaceToDepth"; +} + +inline bool IsNhwcEligibleMsOp(std::string_view op_type) { + return op_type == "GridSample"; +} + +// Returns true if the given (domain, op_type) pair is eligible for NHWC +// conversion. |domain| should be kOnnxDomain ("") or kMSDomain +// ("com.microsoft"). +inline bool IsNhwcEligible(std::string_view domain, std::string_view op_type) { + if (domain.empty()) { + return IsNhwcEligibleOnnxOp(op_type); + } + if (domain == "com.microsoft") { + return IsNhwcEligibleMsOp(op_type); + } + return false; +} + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc index 6a1a1b8698b4d..1212f8ed77170 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc @@ -20,6 +20,7 @@ #include #include "core/graph/constants.h" +#include "core/providers/cuda/cuda_nhwc_ops.h" namespace onnxruntime { namespace cuda_plugin { @@ -214,7 +215,7 @@ OrtStatus* ORT_API_CALL CudaEp::GetCapabilityImpl( tentative_nodes.reserve(all_nodes.size()); for (const auto& node : all_nodes) { - std::string ep_name = node.GetEpName(); + const std::string& ep_name = node.GetEpName(); if (!ep_name.empty()) { if (ep_name == ep->name_) { candidate_nodes.push_back(node); @@ -229,6 +230,18 @@ OrtStatus* ORT_API_CALL CudaEp::GetCapabilityImpl( if (kernel_def != nullptr) { candidate_nodes.push_back(node); tentative_nodes.push_back(node); + } else { + // Emit a diagnostic when an NHWC-domain node has no matching kernel. + // This helps identify gaps between the layout conversion allowlist and + // the actually-registered NHWC kernels in the plugin build. + const std::string& node_domain = node.GetDomain(); + if (node_domain == kMSInternalNHWCDomain) { + ORT_CXX_LOGF(Ort::Logger(&ep->logger_), ORT_LOGGING_LEVEL_WARNING, + "NHWC kernel miss: op=%s domain=%s version=%d node=%s - " + "no matching kernel registered in the CUDA plugin EP.", + node.GetOperatorType().c_str(), node_domain.c_str(), + node.GetSinceVersion(), node.GetName().c_str()); + } } } @@ -308,36 +321,11 @@ OrtStatus* ORT_API_CALL CudaEp::ShouldConvertDataLayoutForOpImpl( return nullptr; } - // ONNX domain ops that have NHWC kernel registrations. - static const std::unordered_set cuda_nhwc_onnx_ops{ - "BatchNormalization", - "Conv", - "ConvTranspose", - "GlobalMaxPool", - "MaxPool", - "GlobalAveragePool", - "AveragePool", - "GridSample", - "DepthToSpace", - "SpaceToDepth", - "LRN", - }; - - // Check ONNX domain (empty string) or MS domain (com.microsoft) - bool is_onnx_domain = (safe_domain[0] == '\0'); - bool is_ms_domain = (std::strcmp(safe_domain, "com.microsoft") == 0); - - if (is_onnx_domain && cuda_nhwc_onnx_ops.count(safe_op_type) > 0) { + if (cuda::IsNhwcEligible(safe_domain, safe_op_type)) { *should_convert = 1; // Convert - return nullptr; - } - - if (is_ms_domain && std::strcmp(safe_op_type, "GridSample") == 0) { - *should_convert = 1; // Convert - return nullptr; + } else { + *should_convert = 0; // Explicitly decline conversion for unsupported NHWC ops. } - - *should_convert = 0; // Explicitly decline conversion for unsupported NHWC ops. return nullptr; #endif } diff --git a/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py b/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py index c03545fc31435..99e669f73eb72 100644 --- a/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py +++ b/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py @@ -310,7 +310,13 @@ def make_bias_dropout_model(): def run_operator_test( - target_device, model_creator, inputs, expected_fn, ep_name=CUDA_PLUGIN_EP_NAME, session_config=None + target_device, + model_creator, + inputs, + expected_fn, + ep_name=CUDA_PLUGIN_EP_NAME, + session_config=None, + nhwc_ops=None, ): with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as tmp: model_path = tmp.name @@ -329,6 +335,10 @@ def run_operator_test( ) return False + # Structural assertion: verify NHWC domain assignment when requested + if nhwc_ops: + _assert_nhwc_domain_assigned(sess, ep_name, nhwc_ops) + print( f"(Session created with {active_providers}; assigned nodes: " f"{', '.join(_format_assigned_node(node) for node in assigned_nodes)})", @@ -407,6 +417,101 @@ def _expected_conv(inputs): _NHWC_CONFIG = {"ep.cuda.prefer_nhwc_layout": "1"} +def _assert_nhwc_domain_assigned(session, ep_name, expected_ops): + """Assert that NHWC layout transformation occurred for the expected ops. + + The framework's NHWC layout transformer rewrites eligible ops to the internal NHWC domain + and wraps them with Transpose nodes. We verify NHWC transformation by checking: + 1. If the assignment API surfaces NHWC-domain nodes, verify expected ops are present. + 2. Otherwise, fall back to checking that Transpose nodes were assigned (their presence + indicates the layout transformer ran and the NHWC kernel was found). + + Args: + session: An InferenceSession with graph assignment info enabled. + ep_name: Name of the EP to check (e.g., CUDA_PLUGIN_EP_NAME). + expected_ops: Set or list of op_type strings expected to have NHWC transformation. + + Returns: + True if evidence of NHWC transformation is found. Raises AssertionError otherwise. + """ + assigned_nodes, _ = _get_assigned_nodes(session, ep_name) + + # Check for NHWC-domain nodes directly (preferred when the API surfaces them). + nhwc_domain = "com.ms.internal.nhwc" + nhwc_ops_found = {n.op_type for n in assigned_nodes if n.domain == nhwc_domain} + if nhwc_ops_found: + missing = set(expected_ops) - nhwc_ops_found + if missing: + raise AssertionError( + f"Expected NHWC-domain nodes for {sorted(missing)} but only found " + f"{sorted(nhwc_ops_found)} in {ep_name} NHWC assignments." + ) + return True + + # Fallback: the NHWC transformation inserts Transpose nodes around the target op. + transpose_count = sum(1 for n in assigned_nodes if n.op_type == "Transpose") + if transpose_count == 0: + all_ops = [f"{n.domain or 'ai.onnx'}::{n.op_type}" for n in assigned_nodes] + raise AssertionError( + f"Expected NHWC layout transformation for {sorted(expected_ops)} but no Transpose " + f"nodes were found in {ep_name} assignments. Assigned ops: {all_ops}. " + f"This indicates the NHWC kernel was not found for the target op(s)." + ) + return True + + +def _run_nhwc_model_test(target_device, op_name, model, feed_dict, expected_fn, nhwc_ops=None, rtol=1e-3, atol=1e-3): + """Run an NHWC test: verify domain assignment and numerical correctness. + + Args: + target_device: EP device to test on. + op_name: Op type name (for display and default NHWC assertion). + model: ONNX model proto. + feed_dict: Input feed dictionary. + expected_fn: Function(feed_dict) -> expected output(s). + nhwc_ops: Set of op_types expected in NHWC domain (defaults to {op_name}). + rtol: Relative tolerance for output comparison. + atol: Absolute tolerance for output comparison. + + Returns: + TEST_PASS or TEST_FAIL string. + """ + if nhwc_ops is None: + nhwc_ops = {op_name} + with tempfile.NamedTemporaryFile(suffix=f"_{op_name}_nhwc.onnx", delete=False) as tmp: + model_path = tmp.name + try: + save(model, model_path) + sess_options = _create_session_options(_NHWC_CONFIG) + sess_options.add_provider_for_devices([target_device], {}) + sess = onnxrt.InferenceSession(model_path, sess_options=sess_options) + assigned_nodes, assignment_info = _get_assigned_nodes(sess, CUDA_PLUGIN_EP_NAME) + if not assigned_nodes: + print( + f"{TEST_FAIL} ({CUDA_PLUGIN_EP_NAME} was assigned no nodes; " + f"assignments={_format_assignment_summary(assignment_info)})" + ) + return TEST_FAIL + + # Structural assertion: verify NHWC domain assignment + _assert_nhwc_domain_assigned(sess, CUDA_PLUGIN_EP_NAME, nhwc_ops) + + res = sess.run(None, feed_dict) + expected = expected_fn(feed_dict) + if isinstance(expected, (list, tuple)): + for r, e in zip(res, expected, strict=True): + np.testing.assert_allclose(r, e, rtol=rtol, atol=atol) + else: + np.testing.assert_allclose(res[0], expected, rtol=rtol, atol=atol) + return TEST_PASS + except Exception as e: + print(f"{TEST_FAIL} ({e})") + return TEST_FAIL + finally: + if os.path.exists(model_path): + os.remove(model_path) + + def _expected_batchnorm(inputs): return inputs["X"] / np.sqrt(1.0 + 1e-5) @@ -589,7 +694,12 @@ def test_nhwc_conv(self): "W": np.random.rand(3, 2, 3, 3).astype(np.float32), } result = run_operator_test( - target_device, create_conv_model, inputs, _expected_conv, session_config=_NHWC_CONFIG + target_device, + create_conv_model, + inputs, + _expected_conv, + session_config=_NHWC_CONFIG, + nhwc_ops={"Conv"}, ) self.assertTrue(result, "Conv (NHWC) plugin test failed") @@ -597,7 +707,12 @@ def test_nhwc_batch_normalization(self): target_device = get_cuda_plugin_device() inputs = {"X": np.random.rand(1, 3, 4, 4).astype(np.float32)} result = run_operator_test( - target_device, create_batch_norm_model, inputs, _expected_batchnorm, session_config=_NHWC_CONFIG + target_device, + create_batch_norm_model, + inputs, + _expected_batchnorm, + session_config=_NHWC_CONFIG, + nhwc_ops={"BatchNormalization"}, ) self.assertTrue(result, "BatchNormalization (NHWC) plugin test failed") @@ -610,6 +725,7 @@ def test_nhwc_maxpool(self): inputs, lambda feed: F.max_pool2d(torch.from_numpy(feed["X"]), kernel_size=2, stride=2).numpy(), session_config=_NHWC_CONFIG, + nhwc_ops={"MaxPool"}, ) self.assertTrue(result, "MaxPool (NHWC) plugin test failed") @@ -622,9 +738,178 @@ def test_nhwc_avgpool(self): inputs, lambda feed: F.avg_pool2d(torch.from_numpy(feed["X"]), kernel_size=2, stride=2).numpy(), session_config=_NHWC_CONFIG, + nhwc_ops={"AveragePool"}, ) self.assertTrue(result, "AveragePool (NHWC) plugin test failed") + def test_nhwc_conv_transpose(self): + target_device = get_cuda_plugin_device() + # ConvTranspose: input [1,2,4,4], weight [2,3,3,3] -> output [1,3,6,6] with stride=2, padding=1, output_padding=1 + f_dtype = TensorProto.FLOAT + node = helper.make_node( + "ConvTranspose", + ["X", "W"], + ["Y"], + strides=[2, 2], + pads=[1, 1, 1, 1], + output_padding=[1, 1], + group=1, + ) + graph = helper.make_graph( + [node], + "test-ConvTranspose", + [ + helper.make_tensor_value_info("X", f_dtype, [1, 2, 4, 4]), + helper.make_tensor_value_info("W", f_dtype, [2, 3, 3, 3]), + ], + [helper.make_tensor_value_info("Y", f_dtype, [1, 3, 6, 6])], + ) + opset = OperatorSetIdProto() + opset.version = 11 + model = helper.make_model(graph, opset_imports=[opset]) + x = np.random.rand(1, 2, 4, 4).astype(np.float32) + w = np.random.rand(2, 3, 3, 3).astype(np.float32) + + def expected_fn(feed): + return F.conv_transpose2d( + torch.from_numpy(feed["X"]), + torch.from_numpy(feed["W"]), + stride=2, + padding=1, + output_padding=1, + ).numpy() + + result = _run_nhwc_model_test(target_device, "ConvTranspose", model, {"X": x, "W": w}, expected_fn) + self.assertEqual(result, TEST_PASS, "ConvTranspose (NHWC) plugin test failed") + + def test_nhwc_global_max_pool(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + model = _make_simple_model( + "GlobalMaxPool", + [("X", f_dtype, [1, 3, 4, 4])], + [("Y", f_dtype, [1, 3, 1, 1])], + opset=12, + ) + x = np.random.rand(1, 3, 4, 4).astype(np.float32) + + def expected_fn(feed): + t = torch.from_numpy(feed["X"]) + return F.adaptive_max_pool2d(t, output_size=1).numpy() + + result = _run_nhwc_model_test(target_device, "GlobalMaxPool", model, {"X": x}, expected_fn) + self.assertEqual(result, TEST_PASS, "GlobalMaxPool (NHWC) plugin test failed") + + def test_nhwc_global_average_pool(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + model = _make_simple_model( + "GlobalAveragePool", + [("X", f_dtype, [1, 3, 4, 4])], + [("Y", f_dtype, [1, 3, 1, 1])], + opset=12, + ) + x = np.random.rand(1, 3, 4, 4).astype(np.float32) + + def expected_fn(feed): + t = torch.from_numpy(feed["X"]) + return F.adaptive_avg_pool2d(t, output_size=1).numpy() + + result = _run_nhwc_model_test(target_device, "GlobalAveragePool", model, {"X": x}, expected_fn) + self.assertEqual(result, TEST_PASS, "GlobalAveragePool (NHWC) plugin test failed") + + def test_nhwc_depth_to_space(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + # DepthToSpace: [1,4,2,2] -> [1,1,4,4] with blocksize=2 + model = _make_simple_model( + "DepthToSpace", + [("X", f_dtype, [1, 4, 2, 2])], + [("Y", f_dtype, [1, 1, 4, 4])], + attrs={"blocksize": 2, "mode": "DCR"}, + opset=13, + ) + x = np.random.rand(1, 4, 2, 2).astype(np.float32) + + def expected_fn(feed): + # DCR mode: depth, column, row + t = feed["X"] # [1, 4, 2, 2] + b = 2 + n, c, h, w = t.shape + t = t.reshape(n, b, b, c // (b * b), h, w) + t = t.transpose(0, 3, 4, 1, 5, 2) # [n, c/b^2, h, b, w, b] + return t.reshape(n, c // (b * b), h * b, w * b) + + result = _run_nhwc_model_test(target_device, "DepthToSpace", model, {"X": x}, expected_fn) + self.assertEqual(result, TEST_PASS, "DepthToSpace (NHWC) plugin test failed") + + def test_nhwc_space_to_depth(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + # SpaceToDepth: [1,1,4,4] -> [1,4,2,2] with blocksize=2 + model = _make_simple_model( + "SpaceToDepth", + [("X", f_dtype, [1, 1, 4, 4])], + [("Y", f_dtype, [1, 4, 2, 2])], + attrs={"blocksize": 2}, + opset=13, + ) + x = np.random.rand(1, 1, 4, 4).astype(np.float32) + + def expected_fn(feed): + t = feed["X"] # [1, 1, 4, 4] + b = 2 + n, c, h, w = t.shape + t = t.reshape(n, c, h // b, b, w // b, b) + t = t.transpose(0, 3, 5, 1, 2, 4) # [n, b, b, c, h/b, w/b] + return t.reshape(n, c * b * b, h // b, w // b) + + result = _run_nhwc_model_test(target_device, "SpaceToDepth", model, {"X": x}, expected_fn) + self.assertEqual(result, TEST_PASS, "SpaceToDepth (NHWC) plugin test failed") + + def test_nhwc_lrn(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + # LRN: [1,3,4,4] with size=3, alpha=0.0001, beta=0.75, bias=1.0 + model = _make_simple_model( + "LRN", + [("X", f_dtype, [1, 3, 4, 4])], + [("Y", f_dtype, [1, 3, 4, 4])], + attrs={"size": 3, "alpha": 0.0001, "beta": 0.75, "bias": 1.0}, + opset=13, + ) + x = np.random.rand(1, 3, 4, 4).astype(np.float32) + + def expected_fn(feed): + t = torch.from_numpy(feed["X"]) + return F.local_response_norm(t, size=3, alpha=0.0001, beta=0.75, k=1.0).numpy() + + result = _run_nhwc_model_test(target_device, "LRN", model, {"X": x}, expected_fn) + self.assertEqual(result, TEST_PASS, "LRN (NHWC) plugin test failed") + + def test_nhwc_grid_sample(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + # GridSample: X [1,1,4,4], grid [1,3,3,2] -> Y [1,1,3,3] + model = _make_simple_model( + "GridSample", + [("X", f_dtype, [1, 1, 4, 4]), ("grid", f_dtype, [1, 3, 3, 2])], + [("Y", f_dtype, [1, 1, 3, 3])], + attrs={"mode": "linear", "padding_mode": "zeros", "align_corners": 0}, + opset=20, + ) + x = np.random.rand(1, 1, 4, 4).astype(np.float32) + # Grid values in [-1, 1] + grid = np.random.rand(1, 3, 3, 2).astype(np.float32) * 2 - 1 + + def expected_fn(feed): + t = torch.from_numpy(feed["X"]) + g = torch.from_numpy(feed["grid"]) + return F.grid_sample(t, g, mode="bilinear", padding_mode="zeros", align_corners=False).numpy() + + result = _run_nhwc_model_test(target_device, "GridSample", model, {"X": x, "grid": grid}, expected_fn) + self.assertEqual(result, TEST_PASS, "GridSample (NHWC) plugin test failed") + # ---- Standard op tests ---- def test_op_reshape(self):