From fa2e2cb7aa161a1c65304ba7c35007c769659f48 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Fri, 20 Feb 2026 18:29:43 +0000 Subject: [PATCH 01/12] Enable sm120 support for fused attn if cuDNN is 9.18.1+ Signed-off-by: Kshitij Lakhani --- .../pytorch/attention/dot_product_attention/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 567fd17c34..9646fed07e 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -610,6 +610,7 @@ def get_attention_backend( qkv_layout, ) use_fused_attention = False + #TODO: KL check if this condition is now supported or not ? if ( device_compute_capability == (12, 0) and (head_dim_qk > 128 or head_dim_qk % 8 != 0) @@ -690,11 +691,11 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt "padding between sequences, i.e. 
[a, a, PAD, b, b, b, PAD, c, PAD]" ) use_flash_attention = False - if device_compute_capability == (12, 0): + if device_compute_capability == (12, 0) and cudnn_version < (9, 18, 1): if use_fused_attention: logger.debug( "Disabling FusedAttention as qkv_format = thd is" - " not supported for compute capability = sm120" + " not supported for compute capability = sm120 and cuDNN version < 9.18.1" ) use_fused_attention = False From bea8bbbdf061cf3075cb645c740b98333326bd6a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 20 Feb 2026 18:42:10 +0000 Subject: [PATCH 02/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/attention/dot_product_attention/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 9646fed07e..82d1d1b2a6 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -610,7 +610,7 @@ def get_attention_backend( qkv_layout, ) use_fused_attention = False - #TODO: KL check if this condition is now supported or not ? + # TODO: KL check if this condition is now supported or not ? 
if ( device_compute_capability == (12, 0) and (head_dim_qk > 128 or head_dim_qk % 8 != 0) From b2f5864b19ab5236aa4ffebb24fb081dbe187ab8 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Mon, 2 Mar 2026 15:25:48 -0800 Subject: [PATCH 03/12] Force intermediate tensors such as S, Sum_Exp, and Max to be BHS1 shape instead of TH1 for sm120 Signed-off-by: Kshitij Lakhani --- .../common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu | 9 ++++++--- transformer_engine/common/transformer_engine.cpp | 1 + 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index eb2ebcff39..b5a12df803 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -1102,6 +1102,9 @@ void fused_attn_arbitrary_seqlen_fwd( devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; } + const int device_id = cuda::current_device(); + const int sm_arch_ = cuda::sm_arch(device_id); + void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr; @@ -1128,7 +1131,7 @@ void fused_attn_arbitrary_seqlen_fwd( if (return_max_logit) { Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_Max->data.dptr = nullptr; - if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) && !(sm_arch_ >= 120)) { output_Max->data.shape = {num_tokens_q, num_attn_heads, 1}; } else { output_Max->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; @@ -1136,7 +1139,7 @@ void fused_attn_arbitrary_seqlen_fwd( output_Max->data.dtype = DType::kFloat32; Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_Sum_Exp->data.dptr 
= nullptr; - if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) && !(sm_arch_ >= 120)) { output_Sum_Exp->data.shape = {num_tokens_q, num_attn_heads, 1}; } else { output_Sum_Exp->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; @@ -1145,7 +1148,7 @@ void fused_attn_arbitrary_seqlen_fwd( } else { Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_S->data.dptr = nullptr; - if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) && !(sm_arch_ >= 120)) { output_S->data.shape = {num_tokens_q, num_attn_heads, 1}; } else { output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index cd02074fbd..763b3a1673 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -1130,6 +1130,7 @@ int nvte_is_non_tn_fp8_gemm_supported() { static std::vector cache(num_devices, -1); static std::vector flags(num_devices); int device_id = transformer_engine::cuda::current_device(); + // TODO: KL check if this condition is now supported or not ? std::call_once(flags[device_id], [&]() { int deviceComputeCapability = transformer_engine::cuda::sm_arch(device_id); // Note: this is temporary restriction and should be lifted in the future. 
From 076420d69f7f815844547cf958a318dd4e72a02c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 2 Mar 2026 23:32:31 +0000 Subject: [PATCH 04/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index b5a12df803..7c55654228 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -1131,7 +1131,8 @@ void fused_attn_arbitrary_seqlen_fwd( if (return_max_logit) { Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_Max->data.dptr = nullptr; - if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) && !(sm_arch_ >= 120)) { + if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) && + !(sm_arch_ >= 120)) { output_Max->data.shape = {num_tokens_q, num_attn_heads, 1}; } else { output_Max->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; @@ -1139,7 +1140,8 @@ void fused_attn_arbitrary_seqlen_fwd( output_Max->data.dtype = DType::kFloat32; Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_Sum_Exp->data.dptr = nullptr; - if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) && !(sm_arch_ >= 120)) { + if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) && + !(sm_arch_ >= 120)) { output_Sum_Exp->data.shape = {num_tokens_q, num_attn_heads, 1}; } else { output_Sum_Exp->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; @@ -1148,7 +1150,8 @@ void fused_attn_arbitrary_seqlen_fwd( } else { Tensor *output_S = 
convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_S->data.dptr = nullptr; - if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) && !(sm_arch_ >= 120)) { + if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) && + !(sm_arch_ >= 120)) { output_S->data.shape = {num_tokens_q, num_attn_heads, 1}; } else { output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; From 8753fc2b8c47f6d0001923aaf61e7a55f3f6fd1c Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Wed, 4 Mar 2026 09:12:54 -0800 Subject: [PATCH 05/12] Add support for sm120 correct batch, seq dims Signed-off-by: Kshitij Lakhani --- .../fused_attn_f16_arbitrary_seqlen.cu | 35 ++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 7c55654228..a6375458ef 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -85,6 +85,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD); bool is_ragged_kv = (kv_format == NVTE_QKV_Format::NVTE_THD); const auto cudnn_runtime_version = cudnnGetVersion(); + const int device_id = cuda::current_device(); + const int sm_arch_ = cuda::sm_arch(device_id); NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD); @@ -92,6 +94,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( NVTE_CHECK(is_padding, "Paged attention requires padding mask!"); } + /* // keep original batch size because cu_seqlens are created with [b+1] shape int64_t actual_b = b; if ((is_ragged_q || is_ragged_kv) && cudnn_runtime_version >= 90600) { @@ -102,6 +105,22 @@ void fused_attn_arbitrary_seqlen_fwd_impl( s_q = 
is_ragged_q ? max_t_q : s_q; s_kv = is_ragged_kv ? max_t_kv : s_kv; } +*/ + // keep original batch size because cu_seqlens are created with [b+1] shape + int64_t actual_b = b; + if ((is_ragged_q || is_ragged_kv) && cudnn_runtime_version >= 90600) { + NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!"); + // On SM 120+ (Blackwell), cuDNN support check treats layouts with stride[0] > dim[1]*dim[2]*dim[3] + // as interleaved and rejects them. Use BHSD-like dimensions/strides with max_seqlen at plan build + // so the check passes; ragged offset still provides variable-length boundaries. + if (sm_arch_ < 120) { + // replace batch size and maximum sequence lengths with maximum token counts + // for query and key/value so the graph is static within each quantization bucket + b = max_b; + s_q = is_ragged_q ? max_t_q : s_q; + s_kv = is_ragged_kv ? max_t_kv : s_kv; + } + } const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32; bool generate_stats = !return_max_logit; @@ -594,6 +613,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( NVTE_CHECK(is_padding, "Paged attention requires padding mask!"); } + /* // keep original batch size because cu_seqlens are created with [b+1] shape int64_t actual_b = b; if ((is_ragged_q || is_ragged_kv) && cudnn_runtime_version >= 90600) { @@ -604,7 +624,20 @@ void fused_attn_arbitrary_seqlen_bwd_impl( s_q = is_ragged_q ? max_t_q : s_q; s_kv = is_ragged_kv ? max_t_kv : s_kv; } - +*/ + // keep original batch size because cu_seqlens are created with [b+1] shape + int64_t actual_b = b; + if ((is_ragged_q || is_ragged_kv) && cudnn_runtime_version >= 90600) { + NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!"); + // On SM 120+ (Blackwell), cuDNN support check requires BHSD-like strides with max_seqlen (see fwd). 
+ if (sm_arch_ < 120) { + // replace batch size and maximum sequence lengths with maximum token counts + // for query and key/value so the graph is static within each quantization bucket + b = max_b; + s_q = is_ragged_q ? max_t_q : s_q; + s_kv = is_ragged_kv ? max_t_kv : s_kv; + } + } // We choose between 32-bit and 64-bit offsets depending on need. // This allows us to support older cuDNN runtimes gracefully. const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32; From 0336d2ae6d267367c82eab647dd743e5d5249675 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Wed, 11 Mar 2026 21:12:03 +0000 Subject: [PATCH 06/12] Add support for sm120 BHS1 style max logit even QKV are THD to avoid incorrect max logit calculation (includes padded tokens in max calculation) Signed-off-by: Kshitij Lakhani --- .../pytorch/cpp_extensions/fused_attn.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 101e5b2525..7d462db2ec 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -339,13 +339,22 @@ def fused_attn_fwd( if return_max_logit: qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0] - # thd: output_tensors: out [tq, h, d], Max [tq, h, 1], Sum_Exp [tq, h, 1] - # bshd: output_tensors: out [b, sq, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1] - # sbhd: output_tensors: out [sq, b, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1] + # thd (newer cuDNN runtimes, non-sm120): output_tensors: out [tq, h, d], Max [tq, h, 1], Sum_Exp [tq, h, 1] + # thd (older cuDNN runtimes or sm120): output_tensors: out [tq, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1] + # bshd: output_tensors: out [b, sq, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1] + # sbhd: output_tensors: out [sq, b, h, d], Max [b, h, sq, 
1], Sum_Exp [b, h, sq, 1] stats = output_tensors[1] + torch.log(output_tensors[2]) - amax_dims = (0, 2) if qkv_format == "thd" else (0, 2, 3) + max_tensor = output_tensors[1] + if qkv_format == "thd" and max_tensor.ndim == 4: + # For THD on older cuDNN runtimes or THD on sm120, stats can be [b, h, sq, 1] with padded + # sequence positions. Exclude those padded positions when computing max_logit. + seqlens_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).to(device=max_tensor.device) + sq_idx = torch.arange(max_tensor.shape[2], device=max_tensor.device).view(1, 1, -1, 1) + valid = sq_idx < seqlens_q.view(-1, 1, 1, 1) + max_tensor = max_tensor.masked_fill(~valid, float("-inf")) + amax_dims = (0, 2) if max_tensor.ndim == 3 else (0, 2, 3) # Max -> max_logit [h] - max_logit = torch.amax(output_tensors[1], dim=amax_dims).to(dtype=output_tensors[0].dtype) + max_logit = torch.amax(max_tensor, dim=amax_dims).to(dtype=output_tensors[0].dtype) aux_ctx_tensors = [stats] aux_ctx_tensors.extend(output_tensors[3:]) return output_tensors[0], aux_ctx_tensors, max_logit From d24eb358ebbbb578cd5cbf8547b82c19396f42ab Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Wed, 11 Mar 2026 21:22:38 +0000 Subject: [PATCH 07/12] Disable fused and flash attn for sm120 filter:kv cache Signed-off-by: Kshitij Lakhani --- .../attention/dot_product_attention/utils.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 82d1d1b2a6..9478fb999a 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -554,11 +554,15 @@ def get_attention_backend( # | FP8 | non-paged/paged | sm90 | thd | >= 1 # Unfused | FP32/FP16/BF16 | non-paged/paged | all | bshd,sbhd,thd | >= 1 if inference_params is not None: - # Temporarily disabling fused attention 
for kv caching for sm89 irrespective of cuDNN version - # until the cuDNN bug is resolved - if device_compute_capability == (8, 9): - logger.debug("Disabling FusedAttention for KV caching for sm89") + # Temporarily disabling fused attention for kv caching for sm89/sm120 irrespective of + # cuDNN version until the cuDNN bug is resolved. + if device_compute_capability in ((8, 9), (12, 0)): + logger.debug("Disabling FusedAttention for KV caching for sm89/sm120") use_fused_attention = False + # Temporarily disable FlashAttention for KV caching on sm120 + if device_compute_capability == (12, 0): + logger.debug("Disabling FlashAttention for KV caching for sm120") + use_flash_attention = False if context_parallel: logger.debug("Disabling all backends for KV caching with context parallelism") use_flash_attention = False @@ -610,7 +614,6 @@ def get_attention_backend( qkv_layout, ) use_fused_attention = False - # TODO: KL check if this condition is now supported or not ? if ( device_compute_capability == (12, 0) and (head_dim_qk > 128 or head_dim_qk % 8 != 0) From e2e89d47907407371659909d92ced64976c2aeb2 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Wed, 11 Mar 2026 21:25:16 +0000 Subject: [PATCH 08/12] For CP P2P attn, set softmax_lse_in_packed_format to False if sm120+ Signed-off-by: Kshitij Lakhani --- .../attention/dot_product_attention/context_parallel.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index bd6b626b64..d8db7fb6e2 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -1491,7 +1491,10 @@ def forward( softmax_lse_in_packed_format = False if qkv_format == "thd": if use_fused_attention: - softmax_lse_in_packed_format = get_cudnn_version() >= (9, 6, 
0) + softmax_lse_in_packed_format = ( + get_cudnn_version() >= (9, 6, 0) + and get_device_compute_capability() < (12, 0) + ) else: softmax_lse_in_packed_format = fa_utils.v2_6_0_plus or use_flash_attn_3 From 5a8ecb96d07816fdbbbcf06fa6d36e5878ecfe86 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Wed, 11 Mar 2026 21:31:26 +0000 Subject: [PATCH 09/12] Assert in TE if T3HD/TH3D layout is used on sm120 before cuDNN F16 sdpa arbitrary kernel call Signed-off-by: Kshitij Lakhani --- .../common/fused_attn/fused_attn.cpp | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index abdce7fdac..08898612b5 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -545,6 +545,15 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) + // T3HD and TH3D are not supported by cuDNN on SM120 (Blackwell); assert before hitting the path. + const int device_id_fwd = cuda::current_device(); + const int sm_arch_fwd = cuda::sm_arch(device_id_fwd); + if (sm_arch_fwd >= 120 && + (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD || qkv_layout == NVTE_QKV_Layout::NVTE_TH3D)) { + NVTE_ERROR( + "T3HD and TH3D QKV layouts are not supported by cuDNN on SM120 (Blackwell). 
" + "Use thd_thd_thd or other THD layouts instead."); + } fused_attn_arbitrary_seqlen_fwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, @@ -644,6 +653,15 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) + // T3HD and TH3D are not supported by cuDNN on SM120 (Blackwell); assert before hitting the path. + const int device_id_bwd = cuda::current_device(); + const int sm_arch_bwd = cuda::sm_arch(device_id_bwd); + if (sm_arch_bwd >= 120 && + (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD || qkv_layout == NVTE_QKV_Layout::NVTE_TH3D)) { + NVTE_ERROR( + "T3HD and TH3D QKV layouts are not supported by cuDNN on SM120 (Blackwell). " + "Use thd_thd_thd or other THD layouts instead."); + } size_t i = 0; Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); From 7ca7564900f5135dc82088b52d683a80fbd366bb Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Wed, 11 Mar 2026 21:48:00 +0000 Subject: [PATCH 10/12] Modify is_ragged_q && cudnn_runtime_version >= 90600 check to also include a check for sm120 Signed-off-by: Kshitij Lakhani --- .../fused_attn_f16_arbitrary_seqlen.cu | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index a6375458ef..8fd814ca70 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -87,6 +87,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( const auto cudnn_runtime_version = 
cudnnGetVersion(); const int device_id = cuda::current_device(); const int sm_arch_ = cuda::sm_arch(device_id); + bool use_ragged_stats = is_ragged_q && cudnn_runtime_version >= 90600 && !(sm_arch_ >= 120); NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD); @@ -355,7 +356,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( } std::shared_ptr Max, Sum_Exp; - if (is_ragged_q && cudnn_runtime_version >= 90600) { + if (use_ragged_stats) { offset_stats = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("offset_stats") @@ -372,7 +373,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( .set_name("Sum_Exp") .set_dim({b, h, s_q, 1}) .set_data_type(fe::DataType_t::FLOAT)); - if (is_ragged_q && cudnn_runtime_version >= 90600) { + if (use_ragged_stats) { Max->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats); Sum_Exp->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats); } else { @@ -400,7 +401,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( if (!return_max_logit) { Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1}); - if (is_ragged_q && cudnn_runtime_version >= 90600) { + if (use_ragged_stats) { Stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats); } else { Stats->set_stride({h * s_q, s_q, 1, 1}); @@ -426,7 +427,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( is_ragged_q ? std::make_tuple(offset_q, offset_o) : std::make_tuple(nullptr, nullptr); auto offset_kv_tuple = is_ragged_kv ? std::make_tuple(offset_k, offset_v) : std::make_tuple(nullptr, nullptr); - auto offset_s_tuple = (is_ragged_q && cudnn_runtime_version >= 90600) + auto offset_s_tuple = use_ragged_stats ? std::make_tuple(offset_stats) : std::make_tuple(nullptr); auto dropout_tuple = is_dropout ? 
std::make_tuple(dropout_seed, dropout_offset) @@ -462,7 +463,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( size_t seqlen_offsets_workspace_size = 0; if (is_ragged_q || is_ragged_kv) { size_t count = 2 * (static_cast(is_ragged_q) + static_cast(is_ragged_kv)); - if (is_ragged_q && cudnn_runtime_version >= 90600) { + if (use_ragged_stats) { seqlen_offsets_workspace_size = (count + 1) * num_bytes_per_ragged_offset; } else { seqlen_offsets_workspace_size = count * num_bytes_per_ragged_offset; @@ -529,7 +530,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( devOffsetsV = static_cast(devOffsetsK) + num_bytes_per_ragged_offset; } void *devOffsetsS = nullptr; - if (is_ragged_q && cudnn_runtime_version >= 90600) { + if (use_ragged_stats) { devOffsetsS = static_cast(devOffsets) + (static_cast(is_ragged_q) + static_cast(is_ragged_kv)) * 2 * num_bytes_per_ragged_offset; @@ -548,7 +549,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( variant_pack[offset_k] = devOffsetsK; variant_pack[offset_v] = devOffsetsV; } - if (is_ragged_q && cudnn_runtime_version >= 90600) { + if (use_ragged_stats) { variant_pack[offset_stats] = devOffsetsS; } } @@ -606,6 +607,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( const auto cudnn_runtime_version = cudnnGetVersion(); const int device_id = cuda::current_device(); const int sm_arch_ = cuda::sm_arch(device_id); + bool use_ragged_stats = is_ragged_q && cudnn_runtime_version >= 90600 && !(sm_arch_ >= 120); NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD); @@ -798,7 +800,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( .set_name("stats") .set_dim({b, h, s_q, 1}) .set_data_type(fe::DataType_t::FLOAT)); - if (is_ragged_q && cudnn_runtime_version >= 90600) { + if (use_ragged_stats) { offset_stats = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("offset_stats") @@ -824,7 +826,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( 
.set_causal_mask_bottom_right(is_bottom_right) .set_attn_scale(attn_scale); - if (is_ragged_q && cudnn_runtime_version >= 90600) { + if (use_ragged_stats) { sdpa_backward_options.set_max_total_seq_len_q(s_q); } if (is_ragged_kv && cudnn_runtime_version >= 90600) { @@ -947,7 +949,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( is_ragged_q ? std::make_tuple(offset_q, offset_o) : std::make_tuple(nullptr, nullptr); auto offset_kv_tuple = is_ragged_kv ? std::make_tuple(offset_k, offset_v) : std::make_tuple(nullptr, nullptr); - auto offset_s_tuple = (is_ragged_q && cudnn_runtime_version >= 90600) + auto offset_s_tuple = use_ragged_stats ? std::make_tuple(offset_stats) : std::make_tuple(nullptr); auto dropout_tuple = is_dropout ? std::make_tuple(dropout_seed, dropout_offset) @@ -982,7 +984,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( size_t seqlen_offsets_workspace_size = 0; if (is_ragged_q || is_ragged_kv) { size_t count = 2 * (static_cast(is_ragged_q) + static_cast(is_ragged_kv)); - if (is_ragged_q && cudnn_runtime_version >= 90600) { + if (use_ragged_stats) { seqlen_offsets_workspace_size = (count + 1) * num_bytes_per_ragged_offset; } else { seqlen_offsets_workspace_size = count * num_bytes_per_ragged_offset; @@ -1052,7 +1054,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( devOffsetsV = static_cast(devOffsetsK) + num_bytes_per_ragged_offset; } void *devOffsetsS = nullptr; - if (is_ragged_q && cudnn_runtime_version >= 90600) { + if (use_ragged_stats) { devOffsetsS = static_cast(devOffsets) + (static_cast(is_ragged_q) + static_cast(is_ragged_kv)) * 2 * num_bytes_per_ragged_offset; @@ -1071,7 +1073,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( variant_pack[offset_k] = devOffsetsK; variant_pack[offset_v] = devOffsetsV; } - if (is_ragged_q && cudnn_runtime_version >= 90600) { + if (use_ragged_stats) { variant_pack[offset_stats] = devOffsetsS; } } From 3227016a2ae02859b455d9f49a81bc655e72313e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" 
<66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 11 Mar 2026 21:50:38 +0000 Subject: [PATCH 11/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../fused_attn/fused_attn_f16_arbitrary_seqlen.cu | 10 ++++------ .../dot_product_attention/context_parallel.py | 9 +++++---- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 8fd814ca70..dff0e8e287 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -427,9 +427,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( is_ragged_q ? std::make_tuple(offset_q, offset_o) : std::make_tuple(nullptr, nullptr); auto offset_kv_tuple = is_ragged_kv ? std::make_tuple(offset_k, offset_v) : std::make_tuple(nullptr, nullptr); - auto offset_s_tuple = use_ragged_stats - ? std::make_tuple(offset_stats) - : std::make_tuple(nullptr); + auto offset_s_tuple = + use_ragged_stats ? std::make_tuple(offset_stats) : std::make_tuple(nullptr); auto dropout_tuple = is_dropout ? std::make_tuple(dropout_seed, dropout_offset) : std::make_tuple(nullptr, nullptr); @@ -949,9 +948,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( is_ragged_q ? std::make_tuple(offset_q, offset_o) : std::make_tuple(nullptr, nullptr); auto offset_kv_tuple = is_ragged_kv ? std::make_tuple(offset_k, offset_v) : std::make_tuple(nullptr, nullptr); - auto offset_s_tuple = use_ragged_stats - ? std::make_tuple(offset_stats) - : std::make_tuple(nullptr); + auto offset_s_tuple = + use_ragged_stats ? std::make_tuple(offset_stats) : std::make_tuple(nullptr); auto dropout_tuple = is_dropout ? 
std::make_tuple(dropout_seed, dropout_offset) : std::make_tuple(nullptr, nullptr); diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index d8db7fb6e2..69681104ce 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -1491,10 +1491,11 @@ def forward( softmax_lse_in_packed_format = False if qkv_format == "thd": if use_fused_attention: - softmax_lse_in_packed_format = ( - get_cudnn_version() >= (9, 6, 0) - and get_device_compute_capability() < (12, 0) - ) + softmax_lse_in_packed_format = get_cudnn_version() >= ( + 9, + 6, + 0, + ) and get_device_compute_capability() < (12, 0) else: softmax_lse_in_packed_format = fa_utils.v2_6_0_plus or use_flash_attn_3 From 46c6e60c770b0886a78bd4d261d6b1b2d4e670ab Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Wed, 11 Mar 2026 21:58:35 +0000 Subject: [PATCH 12/12] nit: Code clean up Signed-off-by: Kshitij Lakhani --- .../common/fused_attn/fused_attn.cpp | 8 +++--- .../fused_attn_f16_arbitrary_seqlen.cu | 28 ++----------------- .../common/transformer_engine.cpp | 1 - 3 files changed, 6 insertions(+), 31 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 08898612b5..33f76de3f5 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -545,13 +545,13 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) - // T3HD and TH3D are not supported by cuDNN on SM120 (Blackwell); assert before hitting the path. 
+ // T3HD and TH3D are not supported by cuDNN on SM120; assert before hitting the path. const int device_id_fwd = cuda::current_device(); const int sm_arch_fwd = cuda::sm_arch(device_id_fwd); if (sm_arch_fwd >= 120 && (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD || qkv_layout == NVTE_QKV_Layout::NVTE_TH3D)) { NVTE_ERROR( - "T3HD and TH3D QKV layouts are not supported by cuDNN on SM120 (Blackwell). " + "T3HD and TH3D QKV layouts are not supported by cuDNN on SM120. " "Use thd_thd_thd or other THD layouts instead."); } fused_attn_arbitrary_seqlen_fwd( @@ -653,13 +653,13 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) - // T3HD and TH3D are not supported by cuDNN on SM120 (Blackwell); assert before hitting the path. + // T3HD and TH3D are not supported by cuDNN on SM120; assert before hitting the path. const int device_id_bwd = cuda::current_device(); const int sm_arch_bwd = cuda::sm_arch(device_id_bwd); if (sm_arch_bwd >= 120 && (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD || qkv_layout == NVTE_QKV_Layout::NVTE_TH3D)) { NVTE_ERROR( - "T3HD and TH3D QKV layouts are not supported by cuDNN on SM120 (Blackwell). " + "T3HD and TH3D QKV layouts are not supported by cuDNN on SM120. 
" "Use thd_thd_thd or other THD layouts instead."); } size_t i = 0; diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index dff0e8e287..044d3874ac 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -95,23 +95,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl( NVTE_CHECK(is_padding, "Paged attention requires padding mask!"); } - /* // keep original batch size because cu_seqlens are created with [b+1] shape int64_t actual_b = b; if ((is_ragged_q || is_ragged_kv) && cudnn_runtime_version >= 90600) { NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!"); - // replace batch size and maximum sequence lengths with maximum token counts - // for query and key/value so the graph is static within each quantization bucket - b = max_b; - s_q = is_ragged_q ? max_t_q : s_q; - s_kv = is_ragged_kv ? max_t_kv : s_kv; - } -*/ - // keep original batch size because cu_seqlens are created with [b+1] shape - int64_t actual_b = b; - if ((is_ragged_q || is_ragged_kv) && cudnn_runtime_version >= 90600) { - NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!"); - // On SM 120+ (Blackwell), cuDNN support check treats layouts with stride[0] > dim[1]*dim[2]*dim[3] + // On SM 120, cuDNN support check treats layouts with stride[0] > dim[1]*dim[2]*dim[3] // as interleaved and rejects them. Use BHSD-like dimensions/strides with max_seqlen at plan build // so the check passes; ragged offset still provides variable-length boundaries. 
if (sm_arch_ < 120) { @@ -614,23 +602,11 @@ void fused_attn_arbitrary_seqlen_bwd_impl( NVTE_CHECK(is_padding, "Paged attention requires padding mask!"); } - /* - // keep original batch size because cu_seqlens are created with [b+1] shape - int64_t actual_b = b; - if ((is_ragged_q || is_ragged_kv) && cudnn_runtime_version >= 90600) { - NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!"); - // replace batch size and maximum sequence lengths with maximum token counts - // for query and key/value so the graph is static within each quantization bucket - b = max_b; - s_q = is_ragged_q ? max_t_q : s_q; - s_kv = is_ragged_kv ? max_t_kv : s_kv; - } -*/ // keep original batch size because cu_seqlens are created with [b+1] shape int64_t actual_b = b; if ((is_ragged_q || is_ragged_kv) && cudnn_runtime_version >= 90600) { NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!"); - // On SM 120+ (Blackwell), cuDNN support check requires BHSD-like strides with max_seqlen (see fwd). + // On SM 120, cuDNN support check requires BHSD-like strides with max_seqlen (see fwd). if (sm_arch_ < 120) { // replace batch size and maximum sequence lengths with maximum token counts // for query and key/value so the graph is static within each quantization bucket diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 763b3a1673..cd02074fbd 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -1130,7 +1130,6 @@ int nvte_is_non_tn_fp8_gemm_supported() { static std::vector cache(num_devices, -1); static std::vector flags(num_devices); int device_id = transformer_engine::cuda::current_device(); - // TODO: KL check if this condition is now supported or not ? 
std::call_once(flags[device_id], [&]() { int deviceComputeCapability = transformer_engine::cuda::sm_arch(device_id); // Note: this is temporary restriction and should be lifted in the future.