Update GatherBlockQuantized to support 2-bits#28530
Conversation
There was a problem hiding this comment.
Pull request overview
This PR extends GatherBlockQuantized to support 2-bit uint8-packed data, updating schema/docs, CPU/WebGPU implementations, and tests.
Changes:
- Adds 2-bit packing/dequantization support for CPU and WebGPU paths.
- Updates contrib operator schema and generated docs to include 2-bit uint8 support.
- Adds CPU and WebGPU tests for 2-bit uint8 no-zero-point cases.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description |
|---|---|
onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc |
Adds CPU 2-bit uint8 extraction and default zero-point handling. |
onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc |
Adds WebGPU shader logic and shape handling for 2-bit packed uint8 data. |
onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.h |
Expands WebGPU bits validation to include 2. |
onnxruntime/core/graph/contrib_ops/contrib_defs.cc |
Updates schema documentation and zero-point shape inference for packed components. |
docs/ContribOperators.md |
Updates public operator documentation for 2-bit support. |
onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc |
Adds 2-bit packing helper logic and new CPU/WebGPU test cases. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
tianleiwu
left a comment
There was a problem hiding this comment.
Review Summary
The core zero-point indexing issues from the prior review iteration have been properly fixed in both CPU and WebGPU paths. The 2-bit extraction logic and the scale_row / q_in_row decomposition for zero-point addressing are correct.
Two suggestions:
- Inconsistent
is_int8vsis_uint8guards — The 2-bit code paths inComputeInternalmixis_int8(which covers both INT8 and UINT8) andis_uint8inconsistently. Since bits==2 is only valid for uint8, usingis_uint8throughout would make intent clearer. - Missing test coverage for 2-bit zero_points — Both new tests only exercise the default-zero-point path. The 2-bit zero-point unpacking logic (the most complex new code) has no dedicated test.
| // as the input_shape uniform. The buffer remains the original uint8 storage with Flatten=4, and | ||
| // the shader does explicit byte+bit-position extraction. | ||
| TensorShape x_shape; | ||
| if (bits_ == 2 && is_int8) { |
There was a problem hiding this comment.
Suggestion: This uses is_int8 (which covers both ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 and UINT8) but 2-bit packing only applies to uint8 storage. Since the validation above already rejects INT8 + bits==2, this is functionally equivalent to is_uint8, but using is_uint8 here (and in the zero-points block at line 287) would be more self-documenting and defensive.
| if (bits_ == 2 && is_int8) { | |
| if (bits_ == 2 && is_uint8) { |
| /*block_size=*/16, /*bits=*/2, output, output_shape); | ||
| } | ||
| #endif // USE_WEBGPU | ||
|
|
There was a problem hiding this comment.
Suggestion: Consider adding a companion test with explicit 2-bit packed zero_points (4 per byte along the quantize axis). The zero-point row-boundary logic was the subject of the prior review iteration and is the most complex new code path, but currently lacks test coverage. A case where scale_qaxis_dim is not a multiple of 4 would be especially valuable for validating the packed byte addressing in both the CPU and WebGPU kernels.
Description
Update GatherBlockQuantized to support 2-bits.
Updated op schema, implemented the CPU and WebGPU EP.
This helps to make the model smaller.