From 6bfeff31e17fdbc21e71d6e384710ac66285f919 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Sat, 20 Dec 2025 00:54:50 +0000 Subject: [PATCH 1/9] Plumbing correct bias dims from TE to cudnn Signed-off-by: Kshitij Lakhani --- .../fused_attn_f16_arbitrary_seqlen.cu | 34 +++++++++++++------ .../common/fused_attn/fused_attn_fp8.cu | 20 +++++++---- transformer_engine/common/fused_attn/utils.h | 6 ++-- 3 files changed, 41 insertions(+), 19 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index d3746fc042..354c376002 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -52,7 +52,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t num_pages_k, int64_t num_pages_v, int64_t page_size_k, int64_t page_size_v, int64_t max_pages_per_seq_k, - int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training, + int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, int64_t bias_sq, int64_t bias_skv, bool is_training, bool return_max_logit, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK, @@ -120,6 +120,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( max_pages_per_seq_v, bias_b, bias_h, + bias_sq, + bias_skv, scaling_factor, is_training, dropout_probability, @@ -263,8 +265,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( if (is_bias) { bias = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("bias") - .set_dim({bias_b, bias_h, s_q, s_kv}) - .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1})); + .set_dim({bias_b, bias_h, bias_sq, bias_skv}) + .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1})); sdpa_options.set_bias(bias); } @@ -539,7 +541,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( void fused_attn_arbitrary_seqlen_bwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, - int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h, + int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h, int64_t bias_sq, int64_t bias_skv, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool deterministic, void *devPtrQ, @@ -612,6 +614,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( 0, bias_b, bias_h, + bias_sq, + bias_skv, scaling_factor, true, dropout_probability, @@ -794,12 +798,12 @@ void fused_attn_arbitrary_seqlen_bwd_impl( if (is_bias) { bias = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("bias") - .set_dim({bias_b, bias_h, s_q, s_kv}) - .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1})); + .set_dim({bias_b, bias_h, bias_sq, bias_skv}) + .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1})); dBias = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("dBias") - .set_dim({bias_b, bias_h, s_q, s_kv}) - .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1})); + .set_dim({bias_b, bias_h, bias_sq, 
bias_skv}) + .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1})); sdpa_backward_options.set_bias(bias); // shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s] // are not supported for dbias calculation but they are @@ -1064,10 +1068,14 @@ void fused_attn_arbitrary_seqlen_fwd( void *devPtrBias = nullptr; size_t bias_b = 0; size_t bias_h = 0; + size_t bias_sq = 0; + size_t bias_skv = 0; if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { devPtrBias = input_Bias->data.dptr; bias_b = input_Bias->data.shape[0]; bias_h = input_Bias->data.shape[1]; + bias_sq = input_Bias->data.shape[2]; + bias_skv = input_Bias->data.shape[3]; } void *devPtrSoftmaxOffset = nullptr; if (softmax_type != NVTE_VANILLA_SOFTMAX) { @@ -1133,7 +1141,7 @@ void fused_attn_arbitrary_seqlen_fwd( if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_bias->data.dptr = nullptr; - output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv}; + output_bias->data.shape = {bias_b, bias_h, bias_sq, bias_skv}; output_bias->data.dtype = QKV_type; } @@ -1178,7 +1186,7 @@ void fused_attn_arbitrary_seqlen_fwd( fused_attn_arbitrary_seqlen_fwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, - page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, + page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, bias_skv, is_training, return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, @@ -1224,11 +1232,15 @@ void fused_attn_arbitrary_seqlen_bwd( void *devPtrdBias = nullptr; size_t bias_b = 0; size_t bias_h = 0; + size_t bias_sq = 0; + size_t bias_skv = 0; if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { devPtrBias = input_Bias->data.dptr; devPtrdBias = output_dBias->data.dptr; bias_b = output_dBias->data.shape[0]; bias_h = output_dBias->data.shape[1]; + bias_sq = input_Bias->data.shape[2]; + bias_skv = input_Bias->data.shape[3]; } size_t max_batch_size = 0; @@ -1271,7 +1283,7 @@ void fused_attn_arbitrary_seqlen_bwd( fused_attn_arbitrary_seqlen_bwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, - max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout, + max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, bias_sq, bias_skv, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 3630041ccf..32ef6f86ee 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1671,6 +1671,8 @@ void fused_attn_fp8_fwd_impl_v1( bool is_dropout = (is_training && dropout_probability != 0.0f); auto bias_b = b; auto bias_h = h; + auto bias_sq = s_q; + auto bias_skv = s_kv; 
NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); bool is_current_scaling = (o_tensor_type == cudnn_frontend::DataType_t::HALF || @@ -1697,6 +1699,8 @@ void fused_attn_fp8_fwd_impl_v1( 0, bias_b, bias_h, + bias_sq, + bias_skv, scaling_factor, is_training, dropout_probability, @@ -1817,8 +1821,8 @@ void fused_attn_fp8_fwd_impl_v1( // if (is_bias) { // bias = mha_graph->tensor(fe::graph::Tensor_attributes() // .set_name("bias") - // .set_dim({bias_b, bias_h, s_q, s_kv}) - // .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1})); + // .set_dim({bias_b, bias_h, bias_sq, bias_skv}) + // .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1})); // sdpa_options.set_bias(bias); // } @@ -1998,6 +2002,8 @@ void fused_attn_fp8_bwd_impl_v1( bool is_dropout = (dropout_probability != 0.0f); auto bias_b = b; auto bias_h = h; + auto bias_sq = s_q; + auto bias_skv = s_kv; NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); bool is_current_scaling = (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF || @@ -2026,6 +2032,8 @@ void fused_attn_fp8_bwd_impl_v1( 0, bias_b, bias_h, + bias_sq, + bias_skv, scaling_factor, true, dropout_probability, @@ -2192,12 +2200,12 @@ void fused_attn_fp8_bwd_impl_v1( // if (is_bias) { // bias = mha_graph->tensor(fe::graph::Tensor_attributes() // .set_name("bias") - // .set_dim({bias_b, bias_h, s_q, s_kv}) - // .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1})); + // .set_dim({bias_b, bias_h, bias_sq, bias_skv}) + // .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1})); // dBias = mha_graph->tensor(fe::graph::Tensor_attributes() // .set_name("dBias") - // .set_dim({bias_b, bias_h, s_q, s_kv}) - // .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1})); + // .set_dim({bias_b, bias_h, bias_sq, bias_skv}) + // .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1})); // sdpa_backward_options.set_bias(bias); // // shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s] // // are not supported for dbias calculation but they are diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 7d23bb5c55..4e372a56d0 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -101,6 +101,8 @@ struct FADescriptor_v1 { std::int64_t max_pages_per_seq_v; std::int64_t bias_b; std::int64_t bias_h; + std::int64_t bias_sq; + std::int64_t bias_skv; float attnScale; bool isTraining; float dropoutProbability; @@ -119,13 +121,13 @@ struct FADescriptor_v1 { bool operator<(const FADescriptor_v1 &rhs) const { return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k, - page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, + page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, bias_skv, attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type, window_size_left, window_size_right, deterministic, bias_type, qkv_tensor_type, o_tensor_type, do_tensor_type, dqkv_tensor_type, generate_max_sum_exp) < std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k, rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k, - rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining, + 
rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.bias_sq, rhs.bias_skv, rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type, rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type, rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type, From 48641f18b977b8cd62c73f2b235b28c33da4a12a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 22 Dec 2025 18:22:38 +0000 Subject: [PATCH 2/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../fused_attn_f16_arbitrary_seqlen.cu | 81 ++++++++++--------- transformer_engine/common/fused_attn/utils.h | 19 ++--- 2 files changed, 53 insertions(+), 47 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 354c376002..86530a8d01 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -52,15 +52,16 @@ void fused_attn_arbitrary_seqlen_fwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t num_pages_k, int64_t num_pages_v, int64_t page_size_k, int64_t page_size_v, int64_t max_pages_per_seq_k, - int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, int64_t bias_sq, int64_t bias_skv, bool is_training, - bool return_max_logit, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK, - void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrS1, void *devPtrS2, - void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, - void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV, - void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, - void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { + int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, int64_t bias_sq, int64_t bias_skv, + bool is_training, bool return_max_logit, float scaling_factor, float dropout_probability, + NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, + void *devPtrS1, void *devPtrS2, void *devPtrO, void *devPtrDropoutSeed, + void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, + void *devPtrPageTableK, void *devPtrPageTableV, void *devPtrSeqOffsetsQ, + void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, + size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); @@ -263,10 +264,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl( sdpa_options.set_alibi_mask(is_alibi); if (is_bias) { - bias = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("bias") - .set_dim({bias_b, bias_h, bias_sq, bias_skv}) - .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1})); + 
bias = mha_graph->tensor( + fe::graph::Tensor_attributes() + .set_name("bias") + .set_dim({bias_b, bias_h, bias_sq, bias_skv}) + .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1})); sdpa_options.set_bias(bias); } @@ -541,16 +543,17 @@ void fused_attn_arbitrary_seqlen_fwd_impl( void fused_attn_arbitrary_seqlen_bwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, - int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h, int64_t bias_sq, int64_t bias_skv, - float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, void *devPtrQ, - void *devPtrKTranspose, void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, - void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrdQ, void *devPtrdK, void *devPtrdV, - void *devPtrdO, void *devPtrdBias, void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed, - void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, - void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, - void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { + int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h, + int64_t bias_sq, int64_t bias_skv, float scaling_factor, float dropout_probability, + NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool deterministic, void *devPtrQ, void *devPtrKTranspose, void *devPtrVTranspose, + void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, void *devPtrSoftmaxOffset, + void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias, + void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed, void *devPtrDropoutOffset, + void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, + void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, + size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); @@ -796,14 +799,16 @@ void fused_attn_arbitrary_seqlen_bwd_impl( sdpa_backward_options.set_alibi_mask(is_alibi); if (is_bias) { - bias = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("bias") - .set_dim({bias_b, bias_h, bias_sq, bias_skv}) - .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1})); - dBias = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("dBias") - .set_dim({bias_b, bias_h, bias_sq, bias_skv}) - .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1})); + bias = mha_graph->tensor( + fe::graph::Tensor_attributes() + .set_name("bias") + .set_dim({bias_b, bias_h, bias_sq, bias_skv}) + .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1})); + dBias = mha_graph->tensor( + fe::graph::Tensor_attributes() + .set_name("dBias") + .set_dim({bias_b, bias_h, bias_sq, bias_skv}) + .set_stride({bias_h * bias_sq * bias_skv, bias_sq * bias_skv, bias_skv, 1})); sdpa_backward_options.set_bias(bias); // shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s] // are not supported for dbias calculation but they are @@ -1186,9 +1191,9 @@ void fused_attn_arbitrary_seqlen_fwd( fused_attn_arbitrary_seqlen_fwd_impl( 
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, - page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, bias_skv, is_training, - return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, - window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, + page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, bias_skv, + is_training, return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, + softmax_type, window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, @@ -1283,10 +1288,10 @@ void fused_attn_arbitrary_seqlen_bwd( fused_attn_arbitrary_seqlen_bwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, - max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, bias_sq, bias_skv, attn_scale, p_dropout, - qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, - deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, - devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, + max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, bias_sq, bias_skv, attn_scale, + p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, + window_size_right, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, + devPtrBias, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 4e372a56d0..2e4dafd973 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -121,17 +121,18 @@ struct FADescriptor_v1 { bool operator<(const FADescriptor_v1 &rhs) const { return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k, - page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, bias_skv, - attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type, - window_size_left, window_size_right, deterministic, bias_type, qkv_tensor_type, - o_tensor_type, do_tensor_type, dqkv_tensor_type, generate_max_sum_exp) < + page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, + bias_skv, attnScale, isTraining, dropoutProbability, layout, mask_type, + softmax_type, window_size_left, window_size_right, deterministic, bias_type, + qkv_tensor_type, o_tensor_type, do_tensor_type, dqkv_tensor_type, + generate_max_sum_exp) < std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k, rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k, - rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.bias_sq, rhs.bias_skv, rhs.attnScale, rhs.isTraining, - rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type, - 
rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type, - rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type, - rhs.dqkv_tensor_type, rhs.generate_max_sum_exp); + rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.bias_sq, rhs.bias_skv, + rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout, + rhs.mask_type, rhs.softmax_type, rhs.window_size_left, rhs.window_size_right, + rhs.deterministic, rhs.bias_type, rhs.qkv_tensor_type, rhs.o_tensor_type, + rhs.do_tensor_type, rhs.dqkv_tensor_type, rhs.generate_max_sum_exp); } }; From a9645dabb8a5947567ab22926581b2a719573964 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Fri, 9 Jan 2026 00:27:48 +0000 Subject: [PATCH 3/9] Make changes for cp bias code Signed-off-by: Kshitij Lakhani --- tests/pytorch/attention/test_attention.py | 5 ++--- .../pytorch/attention/dot_product_attention/utils.py | 7 +------ 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 9111d3511c..1060a966a5 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -515,7 +515,7 @@ def test_dpa_mask(dtype, model_configs, model): model_configs_bias = { # test: ModelConfig(b, sq, hq, dqk) - "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias"), + "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="111s"), "bias_1_1": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_bias_type="post_scale_bias"), "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias"), "bias_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="post_scale_bias"), @@ -1131,11 +1131,10 @@ def _run_dot_product_attention( bias = None if config.attn_bias_type == "post_scale_bias": shape = "_".join(config.bias_shape) + shape = shape.replace("_1_s", "_1_skv") shape = shape.replace("_s_s", "_sq_skv") tensor_shape = [dim_to_num[j] for j in shape.split("_")] bias = torch.randn(tensor_shape, dtype=dtype, device="cuda") - if config.bias_shape != "1hss": - bias.requires_grad = False # Create RNG _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index cb74a15e77..451db2ccd6 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -958,12 +958,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt and fu_core_attention_bias_type == "post_scale_bias" and fu_core_attention_bias_shape != "1hss" ): - if fu_core_attention_bias_requires_grad: - # remove this line when cuDNN adds bwd support for - # [1, 1, s, s], [b, 1, s, s] and [b, h, s, s] - logger.debug("Disabling FusedAttention for dBias in [1, H, S, S] shape") - use_fused_attention = False - else: + if not fu_core_attention_bias_requires_grad: # max512 backend will only support [1, h, s, s] os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1" From 3996cee469aa612c17c63db16f18deadb655e4ac Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Fri, 9 Jan 2026 02:19:48 +0000 Subject: [PATCH 4/9] Add dbias and dbias_ to run_dpa_with_cp test Signed-off-by: Kshitij Lakhani --- .../attention/run_attention_with_cp.py | 174 +++++++++++++----- 1 file changed, 126 insertions(+), 48 deletions(-) diff --git 
a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 3efb516b57..49ad88be1d 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -307,6 +307,7 @@ def run_dpa_with_cp( if config.attn_bias_type not in ["no_bias", "alibi"]: attn_bias_shape = (1, 1, config.max_seqlen_q, config.max_seqlen_kv) bias = torch.randn(*attn_bias_shape, dtype=dtypes[dtype]).cuda() + bias.requires_grad = True else: bias = None @@ -338,7 +339,7 @@ def run_dpa_with_cp( out.backward(dout_fp8) else: out.backward(dout) - dq, dk, dv = q.grad, k.grad, v.grad + dq, dk, dv, dbias = q.grad, k.grad, v.grad, bias.grad d_softmax_offset = None if config.softmax_type != "vanilla": d_softmax_offset = core_attn.softmax_offset.grad @@ -394,6 +395,7 @@ def run_dpa_with_cp( ) bias_ = bias_.index_select(2, seq_idx) bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1]) + bias_.requires_grad = True # set up environment core_attn.set_context_parallel_group( cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group, @@ -433,23 +435,23 @@ def run_dpa_with_cp( out_.backward(dout_fp8_) else: out_.backward(dout_) - dq_, dk_, dv_ = q_.grad, k_.grad, v_.grad + dq_, dk_, dv_, dbias_ = q_.grad, k_.grad, v_.grad, bias_.grad d_softmax_offset_ = None if config.softmax_type != "vanilla": d_softmax_offset_ = core_attn.softmax_offset.grad.clone() # get outputs - tensors = [out, dq, dk, dv, out_, dq_, dk_, dv_] + tensors = [out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_] if fp8_mha: tensors_to_deq = [out, out_] if not fp8_bwd else tensors for i, tensor in enumerate(tensors_to_deq): tensors_to_deq[i] = tensor.dequantize() if not fp8_bwd: - tensors[0], tensors[4] = tensors_to_deq + tensors[0], tensors[5] = tensors_to_deq for tensor in tensors: assert torch.all(~torch.isnan(tensor)) assert torch.all(~torch.isinf(tensor)) - out, dq, dk, dv, out_, dq_, dk_, dv_ = tensors + out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors ############ compare results between CP and no-CP ############ if qkv_format == "bshd" or qkv_format == "sbhd": @@ -467,6 +469,26 @@ def run_dpa_with_cp( x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2, *x.shape[(seq_dim + 1) :]) for x in [dq_, dk_, dv_, out_] ] + if dbias is not None and dbias_ is not None: + dbias = dbias.view( + dbias.shape[0], + dbias.shape[1], + 2 * world_size, + dbias.shape[2] // (2 * world_size), + dbias.shape[3] + ) + # bias has fixed axis (2) as dbias shape: (1, 1, max_seqlen_q, max_seqlen_kv) + dbias = dbias.index_select(2, seq_idx) + # Flatten + dbias = dbias.view(dbias.shape[0], dbias.shape[1], -1, dbias.shape[-1]) + dbias_ = dbias_.view( + dbias_.shape[0], + dbias_.shape[1], + 2, + dbias_.shape[2] // 2, + dbias_.shape[3] + ) + elif qkv_format == "thd": dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]] dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]] @@ -509,9 +531,9 @@ def run_dpa_with_cp( ) atol, rtol, rmse_tol = get_tols(config, dtype) - tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_, max_logit_] - tensors_no_cp = [out, dq, dk, dv, d_softmax_offset, max_logit] - names = ["out", "dq", "dk", "dv", "d_softmax_offset", "max_logit"] + tensors_cp = [out_, dq_, dk_, dv_, dbias_, d_softmax_offset_, max_logit_] + tensors_no_cp = [out, dq, dk, dv, dbias, d_softmax_offset, max_logit] + names = ["out", "dq", "dk", "dv", "dbias", "d_softmax_offset", "max_logit"] names_cp = [x + "_cp" for x in names] names_no_cp = [x + 
"_no_cp" for x in names] is_fp8 = dtype == "fp8" @@ -519,47 +541,103 @@ def run_dpa_with_cp( if t is not None: if "softmax_offset" not in names[i] and "max_logit" not in names[i]: if qkv_format == "bshd": - compare_and_assert( - t[:, 0], - tensors_cp[i][:, 0], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, - ) - compare_and_assert( - t[:, 1], - tensors_cp[i][:, 1], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, - ) + # Compare the two sequence chunks separately + # Compare dbias + if names[i] == "dbias": + # After reshaping: (1, 1, 2, seq_q//2, seq_kv) + # Compare along dimension 2 (the split sequence dimension) + compare_and_assert( + t[:, :, 0], # First sequence chunk + tensors_cp[i][:, :, 0], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + compare_and_assert( + t[:, :, 1], # Second sequence chunk + tensors_cp[i][:, :, 1], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + # Compare Q/K/V/out + else: + # Compare along dimension 1 (the split sequence dimension) + compare_and_assert( + t[:, 0], + tensors_cp[i][:, 0], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + compare_and_assert( + t[:, 1], + tensors_cp[i][:, 1], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) elif qkv_format == "sbhd": - compare_and_assert( - t[0], - tensors_cp[i][0], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, - ) - compare_and_assert( - t[1], - tensors_cp[i][1], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, - ) + # Compare the two sequence chunks separately + # Compare dbias (same as BSHD) + if names[i] == "dbias": + # After reshaping: (1, 1, 2, seq_q//2, seq_kv) + # Compare along dimension 2 (the split sequence dimension) + compare_and_assert( + t[:, :, 0], # First sequence chunk + tensors_cp[i][:, :, 0], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + compare_and_assert( + t[:, :, 1], # Second sequence chunk + tensors_cp[i][:, :, 1], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + # Compare Q/K/V/out + else: + # Compare along dimension 0 (the split sequence dimension) + compare_and_assert( + t[0], + tensors_cp[i][0], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + compare_and_assert( + t[1], + tensors_cp[i][1], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) elif qkv_format == "thd": compare_and_assert( t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8 From 094c40ac7d3c648b0deed0c22c8ee868db5b856f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 9 Jan 2026 02:22:04 +0000 Subject: [PATCH 5/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/attention/run_attention_with_cp.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 49ad88be1d..2ae5937a14 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -475,18 +475,14 @@ def run_dpa_with_cp( dbias.shape[1], 2 * world_size, dbias.shape[2] // (2 * world_size), - dbias.shape[3] + dbias.shape[3], ) # bias has fixed axis (2) as dbias shape: (1, 1, max_seqlen_q, max_seqlen_kv) dbias = 
dbias.index_select(2, seq_idx) # Flatten dbias = dbias.view(dbias.shape[0], dbias.shape[1], -1, dbias.shape[-1]) dbias_ = dbias_.view( - dbias_.shape[0], - dbias_.shape[1], - 2, - dbias_.shape[2] // 2, - dbias_.shape[3] + dbias_.shape[0], dbias_.shape[1], 2, dbias_.shape[2] // 2, dbias_.shape[3] ) elif qkv_format == "thd": From a085f9e32216ca5240b923629e6f374764f3093e Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Fri, 9 Jan 2026 02:27:06 +0000 Subject: [PATCH 6/9] Fix: Use output_dBias instead of input_dBias to extract the shape Signed-off-by: Kshitij Lakhani --- .../common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 86530a8d01..0d8dbe7191 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -1244,8 +1244,8 @@ void fused_attn_arbitrary_seqlen_bwd( devPtrdBias = output_dBias->data.dptr; bias_b = output_dBias->data.shape[0]; bias_h = output_dBias->data.shape[1]; - bias_sq = input_Bias->data.shape[2]; - bias_skv = input_Bias->data.shape[3]; + bias_sq = output_dBias->data.shape[2]; + bias_skv = output_dBias->data.shape[3]; } size_t max_batch_size = 0; From de3011e4a8a0256fb6bfd8efc43c33159a37ceed Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Wed, 21 Jan 2026 19:31:44 +0000 Subject: [PATCH 7/9] Add guards for bias/bias_/dbias/dbias_ being None Signed-off-by: Kshitij Lakhani --- tests/pytorch/attention/run_attention_with_cp.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 2ae5937a14..baf48bc407 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -339,7 +339,7 @@ def run_dpa_with_cp( out.backward(dout_fp8) else: out.backward(dout) - dq, dk, dv, dbias = q.grad, k.grad, v.grad, bias.grad + dq, dk, dv, dbias = q.grad, k.grad, v.grad, bias.grad if bias is not None else None d_softmax_offset = None if config.softmax_type != "vanilla": d_softmax_offset = core_attn.softmax_offset.grad @@ -435,7 +435,7 @@ def run_dpa_with_cp( out_.backward(dout_fp8_) else: out_.backward(dout_) - dq_, dk_, dv_, dbias_ = q_.grad, k_.grad, v_.grad, bias_.grad + dq_, dk_, dv_, dbias_ = q_.grad, k_.grad, v_.grad, bias_.grad if bias_ is not None else None d_softmax_offset_ = None if config.softmax_type != "vanilla": d_softmax_offset_ = core_attn.softmax_offset.grad.clone() @@ -445,12 +445,16 @@ def run_dpa_with_cp( if fp8_mha: tensors_to_deq = [out, out_] if not fp8_bwd else tensors for i, tensor in enumerate(tensors_to_deq): - tensors_to_deq[i] = tensor.dequantize() + # dbias/dbias_ could be None, so skip check for it + if tensor is not None: + tensors_to_deq[i] = tensor.dequantize() if not fp8_bwd: tensors[0], tensors[5] = tensors_to_deq for tensor in tensors: - assert torch.all(~torch.isnan(tensor)) - assert torch.all(~torch.isinf(tensor)) + # dbias/dbias_ could be None, so skip check for it + if tensor is not None: + assert torch.all(~torch.isnan(tensor)) + assert torch.all(~torch.isinf(tensor)) out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors ############ compare results between CP and no-CP ############ From 3d9f24b5dcba391380ed6c0005e7c7fad06f977c Mon Sep 17 
00:00:00 2001 From: Kshitij Lakhani Date: Thu, 22 Jan 2026 00:47:25 +0000 Subject: [PATCH 8/9] Add support for bias shape 111s in addition to the original 1hss, 11ss, b1ss and bhss Signed-off-by: Kshitij Lakhani --- tests/pytorch/attention/test_attention.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 1060a966a5..08e03d24e1 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1131,8 +1131,12 @@ def _run_dot_product_attention( bias = None if config.attn_bias_type == "post_scale_bias": shape = "_".join(config.bias_shape) - shape = shape.replace("_1_s", "_1_skv") + # For 1hss, 11ss, b1ss, bhss + shape_cache = shape shape = shape.replace("_s_s", "_sq_skv") + if shape==shape_cache: + # For 111s + shape = shape.replace("_1_s", "_1_skv") tensor_shape = [dim_to_num[j] for j in shape.split("_")] bias = torch.randn(tensor_shape, dtype=dtype, device="cuda") From 35b06db2a9e4d2cacc8d9af42c4064080928e5e4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 22 Jan 2026 00:48:27 +0000 Subject: [PATCH 9/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/attention/test_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 08e03d24e1..2ea5e97134 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1134,7 +1134,7 @@ def _run_dot_product_attention( # For 1hss, 11ss, b1ss, bhss shape_cache = shape shape = shape.replace("_s_s", "_sq_skv") - if shape==shape_cache: + if shape == shape_cache: # For 111s shape = shape.replace("_1_s", "_1_skv") tensor_shape = [dim_to_num[j] for j in shape.split("_")]
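
Note (illustrative, not part of the patch series): a minimal Python sketch of how `_run_dot_product_attention` expands a `bias_shape` string into concrete tensor dims once patches 8/9 are applied. The `dim_to_num` sizes below are assumed example values loosely based on the `bias_1_1` config, not read from the test itself.

# Sketch of the bias-shape string expansion after PATCH 8/9.
# Assumption: dim_to_num sizes are illustrative (b=2, h=16, sq=128, skv=256).
dim_to_num = {"1": 1, "b": 2, "h": 16, "sq": 128, "skv": 256}

def bias_tensor_shape(bias_shape):
    shape = "_".join(bias_shape)            # e.g. "111s" -> "1_1_1_s"
    shape_cache = shape
    # 1hss, 11ss, b1ss, bhss: trailing "s_s" is (max_seqlen_q, max_seqlen_kv)
    shape = shape.replace("_s_s", "_sq_skv")
    if shape == shape_cache:
        # 111s: the single trailing "s" is the KV sequence length
        shape = shape.replace("_1_s", "_1_skv")
    return [dim_to_num[j] for j in shape.split("_")]

for bs in ("1hss", "11ss", "b1ss", "bhss", "111s"):
    print(bs, "->", bias_tensor_shape(bs))
# 1hss -> [1, 16, 128, 256]
# 11ss -> [1, 1, 128, 256]
# b1ss -> [2, 1, 128, 256]
# bhss -> [2, 16, 128, 256]
# 111s -> [1, 1, 1, 256]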