diff --git a/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py
index 5f35e9ad10..d4bf1fd3a1 100644
--- a/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py
+++ b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py
@@ -130,6 +130,8 @@ def check_group_quantization_nvfp4_versus_reference(
     [
         # edge case, zero tokens for all
         (0, 512),
+        # edge case, hidden dimension not a multiple of 128
+        (1024, 320),
         # full tile cases
         (256, 1024),
         (1024, 256),
diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp
index f8f793f036..89cd90f347 100644
--- a/transformer_engine/pytorch/csrc/extensions/cast.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp
@@ -9,6 +9,7 @@
 #include
 #include
 #include
+#include <mutex>
 #include
 #include
 #include
@@ -1355,9 +1356,19 @@ std::vector split_quantize(const at::Tensor &tensor,
     for (auto &quantizer : quantizer_cpp_list) {
       nvfp4_quantizers.push_back(static_cast(quantizer.get()));
     }
-    bool contiguous_data_and_scale;
+    bool contiguous_data_and_scale = false;
     std::tie(output_py_list, output_cpp_list, contiguous_data_and_scale) =
         bulk_allocate_nvfp4_tensors(split_shapes, quantizer_list, nvfp4_quantizers);
+    if (!input_shape.empty() && input_shape.back() % 128 != 0) {
+      static std::once_flag once_unfused_nvfp4_fallback_warning;
+      std::call_once(once_unfused_nvfp4_fallback_warning, []() {
+        NVTE_WARN(
+            "Unfused NVFP4 quantization fallback is triggered because the input tensor inner "
+            "dimension is not a multiple of 128, disabling NVFP4 grouped kernel fusion. "
+            "NVFP4 might bring performance regressions for this input tensor shape.");
+      });
+      quantization_method = QuantizationMethod::UNFUSED;
+    }
     if (!contiguous_data_and_scale) {
       // Avoid fused quantize kernel if data is not contiguous
       quantization_method = QuantizationMethod::UNFUSED;
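
For context, the change in cast.cpp follows a "warn once, then degrade gracefully" pattern: a function-local `static std::once_flag` ensures the fallback warning is printed at most once per process even when `split_quantize` is invoked repeatedly or from multiple threads, while the dispatch silently switches to the unfused path afterwards. The sketch below is a minimal, self-contained illustration of that pattern under the same `% 128` condition; `warn`, `quantize_fused`, and `quantize_unfused` are hypothetical stand-ins, not Transformer Engine APIs.

```cpp
#include <cstdint>
#include <cstdio>
#include <mutex>
#include <vector>

// Hypothetical stand-in for NVTE_WARN: prints a message to stderr.
static void warn(const char *msg) { std::fprintf(stderr, "WARNING: %s\n", msg); }

// Hypothetical quantization paths; only the dispatch logic mirrors the patch.
static void quantize_fused(const std::vector<int64_t> & /*shape*/) {}
static void quantize_unfused(const std::vector<int64_t> & /*shape*/) {}

// Fall back to the unfused path when the innermost dimension is not a
// multiple of 128, warning exactly once per process via std::call_once.
void quantize(const std::vector<int64_t> &shape) {
  bool use_fused = true;
  if (!shape.empty() && shape.back() % 128 != 0) {
    static std::once_flag once_fallback_warning;
    std::call_once(once_fallback_warning, []() {
      warn("inner dimension is not a multiple of 128; using unfused NVFP4 quantization");
    });
    use_fused = false;
  }
  if (use_fused) {
    quantize_fused(shape);
  } else {
    quantize_unfused(shape);
  }
}

int main() {
  quantize({1024, 320});  // warns once and takes the unfused path
  quantize({1024, 320});  // no second warning thanks to std::call_once
  quantize({256, 1024});  // multiple of 128: fused path, no warning
  return 0;
}
```

The new `(1024, 320)` test case exercises exactly this branch, since 320 is not divisible by 128.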