Skip to content

Update GatherBlockQuantized to support 2-bits#28530

Open
HectorSVC wants to merge 4 commits into
microsoft:mainfrom
HectorSVC:hecli_gather_2bits
Open

Update GatherBlockQuantized to support 2-bits#28530
HectorSVC wants to merge 4 commits into
microsoft:mainfrom
HectorSVC:hecli_gather_2bits

Conversation

@HectorSVC
Copy link
Copy Markdown
Contributor

Description

Update GatherBlockQuantized to support 2-bits.
Updated op schema, implemented the CPU and WebGPU EP.
This helps to make the model smaller.

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR extends GatherBlockQuantized to support 2-bit uint8-packed data, updating schema/docs, CPU/WebGPU implementations, and tests.

Changes:

  • Adds 2-bit packing/dequantization support for CPU and WebGPU paths.
  • Updates contrib operator schema and generated docs to include 2-bit uint8 support.
  • Adds CPU and WebGPU tests for 2-bit uint8 no-zero-point cases.

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc Adds CPU 2-bit uint8 extraction and default zero-point handling.
onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc Adds WebGPU shader logic and shape handling for 2-bit packed uint8 data.
onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.h Expands WebGPU bits validation to include 2.
onnxruntime/core/graph/contrib_ops/contrib_defs.cc Updates schema documentation and zero-point shape inference for packed components.
docs/ContribOperators.md Updates public operator documentation for 2-bit support.
onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc Adds 2-bit packing helper logic and new CPU/WebGPU test cases.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc Outdated
Comment thread onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc Outdated
@guschmue guschmue added the ep:WebGPU ort-web webgpu provider label May 18, 2026
@HectorSVC HectorSVC requested a review from Copilot May 18, 2026 22:00
Copy link
Copy Markdown
Contributor

@tianleiwu tianleiwu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Review Summary

The core zero-point indexing issues from the prior review iteration have been properly fixed in both CPU and WebGPU paths. The 2-bit extraction logic and the scale_row / q_in_row decomposition for zero-point addressing are correct.

Two suggestions:

  1. Inconsistent is_int8 vs is_uint8 guards — The 2-bit code paths in ComputeInternal mix is_int8 (which covers both INT8 and UINT8) and is_uint8 inconsistently. Since bits==2 is only valid for uint8, using is_uint8 throughout would make intent clearer.
  2. Missing test coverage for 2-bit zero_points — Both new tests only exercise the default-zero-point path. The 2-bit zero-point unpacking logic (the most complex new code) has no dedicated test.

// as the input_shape uniform. The buffer remains the original uint8 storage with Flatten=4, and
// the shader does explicit byte+bit-position extraction.
TensorShape x_shape;
if (bits_ == 2 && is_int8) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: This uses is_int8 (which covers both ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 and UINT8) but 2-bit packing only applies to uint8 storage. Since the validation above already rejects INT8 + bits==2, this is functionally equivalent to is_uint8, but using is_uint8 here (and in the zero-points block at line 287) would be more self-documenting and defensive.

Suggested change
if (bits_ == 2 && is_int8) {
if (bits_ == 2 && is_uint8) {

/*block_size=*/16, /*bits=*/2, output, output_shape);
}
#endif // USE_WEBGPU

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: Consider adding a companion test with explicit 2-bit packed zero_points (4 per byte along the quantize axis). The zero-point row-boundary logic was the subject of the prior review iteration and is the most complex new code path, but currently lacks test coverage. A case where scale_qaxis_dim is not a multiple of 4 would be especially valuable for validating the packed byte addressing in both the CPU and WebGPU kernels.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ep:WebGPU ort-web webgpu provider

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants