
Feat/cp nvshmem enhanced #2737

Open
Knight-of-Thunder wants to merge 24 commits into NVIDIA:main from ETOgaosion:feat/cp_nvshmem_enhanced

Conversation


@Knight-of-Thunder Knight-of-Thunder commented Mar 5, 2026

Description

To overlap computation with communication, we create a dedicated CUDA stream for communication.
To make the NVSHMEM APIs easier to use and keep the code cleaner, we replace the original C++ code with the NVSHMEM Python bindings.
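As a stream-free illustration of the intended schedule (our sketch, not code from this PR), each step's compute can overlap the prefetch of the next KV block:

```python
# Our sketch (not code from this PR): a stream-free trace of the
# double-buffered schedule. While step i computes on one buffer, the
# dedicated communication stream prefetches the next KV block; on GPU,
# torch.cuda.Stream and torch.cuda.Event play the roles modeled here
# by an explicit event log.

def schedule(n_steps):
    log = []
    for i in range(n_steps):
        if i + 1 < n_steps:
            log.append(f"comm: prefetch KV block {i + 1}")  # on comm stream
        log.append(f"compute: attention on block {i}")      # on compute stream
    return log

print(schedule(3))
# each prefetch is issued one step before the compute that consumes it
```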

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

ptrendx and others added 21 commits August 18, 2025 16:24
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
…e global_mesh_resource().fsdp_resource (NVIDIA#2088)

* Enforce global MeshResource is set

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Use global_mesh_resource().fsdp_resource in gemm primitive

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Update tests

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Update gemm.py

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Update test_layer.py

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

---------

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
…A#2092)

Avoid garbage collection when capturing a CUDA Graph

Signed-off-by: Tim Moon <tmoon@nvidia.com>
* Fix incorrect version checks for atomic GEMM

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix typo

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* added cp strategy arg to DPA api

Signed-off-by: Md Fahim Faysal Khan <mdfahimfaysa@nvidia.com>

* converted DPA cp_strategy to string

Signed-off-by: Md Fahim Faysal Khan <mdfahimfaysa@nvidia.com>

---------

Signed-off-by: Md Fahim Faysal Khan <mdfahimfaysa@nvidia.com>
* Return dummy wgrad tensors when requested by Mcore

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Apply suggestions from code review

Co-authored-by: Jan Bielak <janekb04@icloud.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Jan Bielak <janekb04@icloud.com>
* added shardy warning

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>


---------

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Revert "[Common] PDL for Quantization Kernels (NVIDIA#2001)"

This reverts commit bfab8c6.

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
* Bump cuDNN FE to 1.14.0

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Change submodule hash

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Pick up a cuDNN FE fix

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* New model configs in tests

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Exclude cuDNN backend for some configs

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

---------

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Revert "[Common] PDL for Blockwise Quantization (NVIDIA#2066)"

This reverts commit ebca615.

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
…#2083)

* code drop

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Pick up cuBLASMp during build

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Change lib order to fix link error

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Context creation, incomplete...

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Test fixure

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* A sanity AgGemm test, failing...

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix axes

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Take care of uneven distribution

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Use MPI to get position of local matrices

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Refactor

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Refactor & fixes

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Gemm-RS

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Gemm-AR, not working...

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fixes

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Setting all-reduce epilogue for gemm-ar

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Use supported shapes for GEMM-AR

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak tolerance

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* First shot at fp8

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Use TensorHolder in tests

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* More test configs

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Support comm_sm_count

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Parametrize dtypes for A, B and D separately

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak scaling

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Amax ptr

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Flags parity with cublas_gemm, saving...

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Cleanup

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Bias tests

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix bias test

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Aux, saving...

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* aux_ld

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* A fix

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Use test::Tensor

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Set scale inv

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Remove unsupported test configs

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak tests

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Replace libcal with NCCL

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Add NVTX markers to API functions

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak GemmAr tests

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* More test config

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix merge fallout

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Remove MPI dependency, comment API, add algo parameter

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix nvshmem dependency

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix nvshmem build

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Excluse CommGemm tests from L0_cppunittest

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Add cpp_distributed sh file for CI

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Adapt tp TensorAllocator

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Skip GemmAr test on unsupported HW

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Oversibscribe is needed on some clusters

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix incomplete libcal removal

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Move CI tests to L1

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Rename context to include NVTE prefix

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Remove leftover code

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* NVTE_WITH_CUBLASMP off by default

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* More detailed NVTE_CHECK diag

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Comment API

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Include stdbool header for legacy C compilers

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Remove now unused argument

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Abstract away cuBLASMp algo behind our own enum

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* More detailed shape diag messages

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update transformer_engine/common/include/transformer_engine/comm_gemm.h

Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com>

* Add license

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

---------

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com>
Co-authored-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
…kv caching (NVIDIA#2121)

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
* disable determinism for sm100+ and cudnn<9.14

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix remaining CI failures

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revert some changes

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* revert more changes

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove sm100 from determinism table

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
dev-base initialize: with version change and log verify

greptile-apps bot commented Mar 5, 2026

Greptile Summary

This PR introduces two new Python modules (context_parallel_nvshmem.py and context_parallel_nvshmem_enhanced.py) that replace raw C++ NVSHMEM calls with NVSHMEM pybindings to overlap computation and communication in Context Parallel attention. It also adds a new comm_gemm C++ backend that wraps cuBLASMp for fused AllGather/GEMM, GEMM/ReduceScatter, and GEMM/AllReduce operations, plus several quality-of-life fixes across the linear ops and CUDA graph support.

Key issues found:

  • nvshmem_kv NameError in backward (context_parallel_nvshmem_enhanced.py ~line 2089): nvshmem_kv is allocated and freed in the forward pass and never saved to ctx. The backward method references it directly, resulting in NameError on every backward call.
  • Duplicate function definition (context_parallel_nvshmem_enhanced.py ~lines 105–148): torchrun_uid_init_bcast_object_no_reinit is defined twice; the second copy silently shadows the first and contains Chinese-language development comments that indicate this is an unclean artifact from development.
  • nvshmem_malloc return not checked (comm_gemm.cpp line ~431): The symmetric heap allocation for the cuBLASMp workspace is used without a null-check, risking a null-pointer dereference if the heap is exhausted.
  • Debug print in setup.py: print("CMAKE_FLAGS:", ...) will be emitted on every NVTE_WITH_CUBLASMP=1 build.
  • Numerous previously-flagged issues remain: inverted error message in nvshmem_comm.cpp, missing cross-stream synchronization for NVSHMEM gets, unconditional re-initialization of NVSHMEM on every forward pass, unused tensor_get_buffer import, and cp_global_ranks NameError in the enhanced backward path.

Confidence Score: 1/5

  • Not safe to merge — multiple NameError crashes in the forward and backward passes of both new files.
  • The two primary new files (context_parallel_nvshmem_enhanced.py and context_parallel_nvshmem.py) contain several confirmed runtime crashes: a NameError for nvshmem_kv in the backward pass, a NameError for cp_global_ranks in the backward pass (previously flagged), a duplicate function definition with Chinese development comments indicating the code is not production-ready, missing NVSHMEM re-initialization guards, and a null-pointer risk in the C++ workspace allocation. The non-NVSHMEM changes (graph GC workaround, Megatron-LM wgrad fusion, comm_gemm backend) appear solid, but the core new feature is not ready.
  • transformer_engine/pytorch/attention/dot_product_attention/context_parallel_nvshmem_enhanced.py and context_parallel_nvshmem.py require the most attention — they contain the majority of the critical bugs. transformer_engine/common/comm_gemm/comm_gemm.cpp needs the nvshmem_malloc null-check.

Important Files Changed

  • transformer_engine/pytorch/attention/dot_product_attention/context_parallel_nvshmem_enhanced.py: New 3223-line file implementing enhanced NVSHMEM context parallelism. Contains multiple critical bugs: duplicate function definition, nvshmem_kv used in backward but freed in forward without ctx save, and cp_global_ranks NameError in backward.
  • transformer_engine/pytorch/attention/dot_product_attention/context_parallel_nvshmem.py: New 4096-line file implementing NVSHMEM-based context parallelism. Contains unused import (tensor_get_buffer) and unconditional NVSHMEM re-initialization on every forward pass.
  • transformer_engine/pytorch/csrc/extensions/nvshmem_comm.cpp: Adds nvshmem_get_on_current_stream C++ binding. The #else error message is inverted — it fires when NVTE_ENABLE_NVSHMEM is NOT defined but states the opposite condition.
  • transformer_engine/common/comm_gemm/comm_gemm.cpp: New 519-line file implementing cuBLASMp-based comm+GEMM fusion (AG-GEMM, GEMM-RS, GEMM-AR). Workspace is allocated via nvshmem_malloc without checking for null return. Otherwise well-structured.
  • setup.py: Adds cuBLASMp/NVSHMEM cmake flags. Contains a debug print statement that should be removed before merging.
  • transformer_engine/pytorch/graph.py: Adds _graph_context_wrapper to temporarily disable GC during CUDA graph capture as a workaround for a PyTorch bug. Clean, well-documented change.
  • transformer_engine/pytorch/ops/basic/basic_linear.py: Improves Megatron-LM wgrad fusion to set grad_added_to_main_grad flag and return a dummy wgrad tensor when needed. Minor refactoring to use weight_param local variable.

Sequence Diagram

sequenceDiagram
    participant FWD as Forward Pass
    participant CTX as ctx (saved state)
    participant BWD as Backward Pass
    participant NVSHMEM as NVSHMEM Heap
    participant P2P as P2P Fallback

    FWD->>NVSHMEM: nvshmem.tensor() → nvshmem_kv
    FWD->>NVSHMEM: nvshmem.tensor() → nvshmem_q

    loop cp_size iterations
        FWD->>NVSHMEM: nvshmem_get_on_stream(p2p_comm_buffers[i+1], nvshmem_kv, owner) [communicate_stream]
        Note over FWD: ⚠ No sync between communicate_stream and flash_attn_streams
        FWD->>FWD: flash attention compute on flash_attn_streams[i%2]
    end

    FWD->>CTX: ctx.cp_global_ranks = cp_global_ranks
    FWD->>NVSHMEM: nvshmem.free_tensor(nvshmem_kv)
    Note over NVSHMEM: ⚠ nvshmem_kv freed but NOT saved to ctx

    BWD->>CTX: read ctx.cp_global_ranks ✓
    BWD->>BWD: if nvshmem_kv is not None → NameError ✗
    Note over BWD: ⚠ nvshmem_kv undefined in backward scope

    alt NVSHMEM path
        BWD->>NVSHMEM: tex.nvshmem_get_on_current_stream(recv, nvshmem_kv, owner)
    else P2P fallback
        BWD->>P2P: flash_attn_p2p_communicate(...)
    end

Last reviewed commit: 1c88f8f

Comment on lines +699 to +720
def _store_fa_nvshmem(out_tensor, softmax_tensor):
    # allocate symmetric tensors lazily and copy
    if not causal:
        return
    try:
        nvshmem_fa_out = [
            tex.create_nvshmem_tensor(list(out_tensor.shape), out_tensor.dtype)
            for _ in range(cp_size)
        ]
        nvshmem_fa_softmax_lse = [
            tex.create_nvshmem_tensor(list(softmax_tensor.shape), softmax_tensor.dtype)
            for _ in range(cp_size)
        ]
    except Exception:
        nvshmem_fa_out = [None for _ in range(cp_size)]
        nvshmem_fa_softmax_lse = [None for _ in range(cp_size)]
    # if nvshmem_fa_out is not None:
    #     for idx in range(cp_size):
    #         if nvshmem_fa_out[idx] is not None:
    #             nvshmem_fa_out[idx].copy_(out_tensor)
    #             nvshmem_fa_softmax_lse[idx].copy_(softmax_tensor)

# create two streams to resolve wave quantization issue of Flash Attn in each step
flash_attn_streams = [torch.cuda.current_stream(), cp_stream]
# synchronize fwd results correction across steps
fwd_results_correction_done = torch.cuda.Event()

Function signature mismatch causes TypeError at runtime

The function _store_fa_nvshmem is defined with 2 parameters but called with 3 positional arguments at lines 932 and 2999:

# Definition (line 699): 2 params
def _store_fa_nvshmem(out_tensor, softmax_tensor):
    ...

# Call sites: 3 args
_store_fa_nvshmem(i, out_per_step[i], softmax_lse_per_step[i])  # lines 932, 2999

This will raise TypeError: _store_fa_nvshmem() takes 2 positional arguments but 3 were given at runtime.

The commented-out version at line 682 had the correct 3-parameter signature. Restore it with the idx parameter:

Suggested change
-def _store_fa_nvshmem(out_tensor, softmax_tensor):
-    # allocate symmetric tensors lazily and copy
-    if not causal:
-        return
-    try:
-        nvshmem_fa_out = [tex.create_nvshmem_tensor(list(out_tensor.shape), out_tensor.dtype) for _ in range(cp_size)]
-        nvshmem_fa_softmax_lse = [tex.create_nvshmem_tensor(
-            list(softmax_tensor.shape), softmax_tensor.dtype
-        ) for _ in range(cp_size)]
-    except Exception:
-        nvshmem_fa_out = [None for _ in range(cp_size)]
-        nvshmem_fa_softmax_lse = [None for _ in range(cp_size)]
-    # if nvshmem_fa_out is not None:
-    #     for idx in range(cp_size):
-    #         if nvshmem_fa_out[idx] is not None:
-    #             nvshmem_fa_out[idx].copy_(out_tensor)
-    #             nvshmem_fa_softmax_lse[idx].copy_(softmax_tensor)
-# create two streams to resolve wave quantization issue of Flash Attn in each step
-flash_attn_streams = [torch.cuda.current_stream(), cp_stream]
-# synchronize fwd results correction across steps
-fwd_results_correction_done = torch.cuda.Event()
+def _store_fa_nvshmem(idx, out_tensor, softmax_tensor):
+    # allocate symmetric tensors lazily and copy
+    if not causal:
+        return
+    try:
+        nvshmem_fa_out[idx] = tex.create_nvshmem_tensor(list(out_tensor.shape), out_tensor.dtype)
+        nvshmem_fa_softmax_lse[idx] = tex.create_nvshmem_tensor(
+            list(softmax_tensor.shape), softmax_tensor.dtype
+        )
+    except Exception:
+        nvshmem_fa_out[idx] = None
+        nvshmem_fa_softmax_lse[idx] = None
+    if nvshmem_fa_out[idx] is not None:
+        nvshmem_fa_out[idx].copy_(out_tensor)
+        nvshmem_fa_softmax_lse[idx].copy_(softmax_tensor)

Comment on lines +1590 to +1610
def backward(ctx, dout):
    # pylint: disable=missing-function-docstring
    nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.backward")
    cp_size_a2a = ctx.cp_size_a2a
    rank_a2a = ctx.rank_a2a

    cp_size = get_distributed_world_size(ctx.cp_group)
    rank = get_distributed_rank(ctx.cp_group)
    send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
    recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
    device_compute_capability = get_device_compute_capability()
    batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (
        device_compute_capability < (10, 0) and cp_size == 2
    )

    q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, *other_tensors = (
        restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
    )
    cu_seqlens_q_per_step = other_tensors[:cp_size]
    cu_seqlens_kv_per_step = other_tensors[cp_size : cp_size * 2]
    rng_states = other_tensors[cp_size * 2 : cp_size * 3]

Undefined variable cp_global_ranks in backward

The backward function accesses cp_global_ranks at lines 1805 and 1835 as a bare local variable, but this variable is never assigned locally. The code only saves and retrieves ctx.cp_global_ranks. At runtime, this will raise NameError: name 'cp_global_ranks' is not defined.

# Line 1598-1599: correct usage with ctx prefix
send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]

# Line 1805, 1835: incorrect usage without ctx prefix
owner_global = cp_global_ranks[owner_idx * cp_size_a2a + rank_a2a]  # NameError

Add this assignment near the top of the backward method alongside the other context unpacking:

Suggested change
 def backward(ctx, dout):
     # pylint: disable=missing-function-docstring
     nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.backward")
     cp_size_a2a = ctx.cp_size_a2a
     rank_a2a = ctx.rank_a2a
     cp_size = get_distributed_world_size(ctx.cp_group)
     rank = get_distributed_rank(ctx.cp_group)
     send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
     recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
     device_compute_capability = get_device_compute_capability()
     batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (
         device_compute_capability < (10, 0) and cp_size == 2
     )
     q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, *other_tensors = (
         restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
     )
     cu_seqlens_q_per_step = other_tensors[:cp_size]
     cu_seqlens_kv_per_step = other_tensors[cp_size : cp_size * 2]
     rng_states = other_tensors[cp_size * 2 : cp_size * 3]
+    cp_size_a2a = ctx.cp_size_a2a
+    rank_a2a = ctx.rank_a2a
+    cp_global_ranks = ctx.cp_global_ranks

Comment on lines +98 to +102
#else
  NVTE_ERROR(
      "Internal TE error: nvshmem_get_on_current_stream cannot be initialized with valid PyTorch ",
      "distributed process groups when TE is compiled with NVTE_ENABLE_NVSHMEM=1!");
#endif

Inverted error message in #else branch

The #else branch executes when NVTE_ENABLE_NVSHMEM is not defined, but the error message incorrectly states the function "cannot be initialized… when TE is compiled with NVTE_ENABLE_NVSHMEM=1" — the exact opposite condition.

Suggested change
 #else
-  NVTE_ERROR(
-      "Internal TE error: nvshmem_get_on_current_stream cannot be initialized with valid PyTorch ",
-      "distributed process groups when TE is compiled with NVTE_ENABLE_NVSHMEM=1!");
+  NVTE_ERROR(
+      "Internal TE error: nvshmem_get_on_current_stream requires TE to be compiled with ",
+      "NVTE_ENABLE_NVSHMEM=1!");
 #endif

Comment on lines +752 to +776
with torch.cuda.stream(flash_attn_streams[i % 2]):
    # wait until KV is received
    for req in send_recv_reqs[(i + 1) % 2]:
        req.wait()

    if i < (cp_size - 1):
        p2p_comm_buffers[i + 1] = torch.empty_like(p2p_comm_buffers[i])
        if nvshmem_kv is not None:
            # Use NVSHMEM get: compute owner of the (i+1)-th step KV block
            owner_idx = (rank - (i + 1)) % cp_size
            # Map owner idx to global rank (accounting for a2a groups)
            owner_global = cp_global_ranks[owner_idx * cp_size_a2a + rank_a2a]
            # nvshmem_get: dst (local buffer), src (symmetric address), peer=owner_global
            tex.nvshmem_get_on_current_stream(p2p_comm_buffers[i + 1], nvshmem_kv, int(owner_global))
        else:
            # fallback to P2P if NVSHMEM not available
            send_recv_reqs[i % 2] = flash_attn_p2p_communicate(
                rank,
                p2p_comm_buffers[i],
                send_dst,
                p2p_comm_buffers[i + 1],
                recv_src,
                cp_group,
                batch_p2p_comm,
            )

Missing cross-stream synchronization for NVSHMEM gets

In the forward loop, the NVSHMEM get for p2p_comm_buffers[i+1] is issued on stream flash_attn_streams[i % 2] (line 765). In the next iteration, the computation consumes p2p_comm_buffers[i+1] on stream flash_attn_streams[(i+1) % 2] — a different stream — with no event-based synchronization.

The existing P2P fallback updates send_recv_reqs with requests that are waited on (line 754-755), but when the NVSHMEM path is taken, send_recv_reqs[(i) % 2] remains empty, so the wait is a no-op:

# Line 765: NVSHMEM get on stream[i % 2]
tex.nvshmem_get_on_current_stream(p2p_comm_buffers[i + 1], nvshmem_kv, int(owner_global))
# send_recv_reqs[(i) % 2] is never populated with a request

# Next iteration (i+1):
# Line 754-755: wait is a no-op because send_recv_reqs[(i+1) % 2] is empty
for req in send_recv_reqs[(i + 1) % 2]:
    req.wait()
# Line 779-782: computation accesses p2p_comm_buffers[i+1] on stream[(i+1) % 2]

Without synchronization, the computation kernel on stream [(i+1) % 2] may observe stale data. Add CUDA event-based synchronization after the NVSHMEM get and before subsequent iterations use the buffer.
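One way to close this gap is sketched as pseudocode below (flash_attn_streams, p2p_comm_buffers, nvshmem_kv, owner_global, and tex.nvshmem_get_on_current_stream come from the diff; kv_ready is a hypothetical event added for illustration, not part of the PR):

```python
# Pseudocode sketch, not a drop-in patch: record a CUDA event on the stream
# that issued the NVSHMEM get, then make the consuming stream wait on it
# before the next iteration's compute kernel reads the buffer.
kv_ready = torch.cuda.Event()

with torch.cuda.stream(flash_attn_streams[i % 2]):
    tex.nvshmem_get_on_current_stream(p2p_comm_buffers[i + 1], nvshmem_kv, int(owner_global))
    kv_ready.record()  # marks completion of the get on this stream

flash_attn_streams[(i + 1) % 2].wait_event(kv_ready)  # gate the consumer stream
```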

Comment on lines +429 to +465
def torchrun_uid_init_bcast_object_no_reinit(cp_group=None):
    local_rank = torch.cuda.current_device()
    dev = Device(local_rank)
    dev.set_current()

    if cp_group is None:
        rank_id = dist.get_rank()
        num_ranks = dist.get_world_size()
    else:
        rank_id = dist.get_rank(group=cp_group)
        num_ranks = dist.get_world_size(group=cp_group)

    uniqueid = nvshmem.get_unique_id(empty=True)

    if rank_id == 0:
        uniqueid = nvshmem.get_unique_id()
        broadcast_objects = [uniqueid]
    else:
        broadcast_objects = [None]

    dist.broadcast_object_list(
        broadcast_objects,
        src=0,
        group=cp_group
    )

    dist.barrier(group=cp_group)

    nvshmem.init(
        device=dev,
        uid=broadcast_objects[0],
        rank=rank_id,
        nranks=num_ranks,
        initializer_method="uid"
    )

    return True

NVSHMEM initialized unconditionally on every forward pass

Despite its name torchrun_uid_init_bcast_object_no_reinit, this function calls nvshmem.init(...) unconditionally (line 457) with no guard against re-initialization. The function is invoked inside the forward method (line 730) on every forward pass. Calling nvshmem.init() multiple times may cause undefined behavior or errors depending on the NVSHMEM pybinding version.

Move NVSHMEM initialization to module-load time or wrap the call in a one-time initialization guard:

_nvshmem_initialized = False

def _maybe_init_nvshmem(cp_group):
    global _nvshmem_initialized
    if _nvshmem_initialized:
        return
    torchrun_uid_init_bcast_object_no_reinit(cp_group)
    _nvshmem_initialized = True

Then call _maybe_init_nvshmem(cp_group) instead of torchrun_uid_init_bcast_object_no_reinit(cp_group) at line 730.

import torch.distributed as dist
from cuda.core.experimental import Device
from cuda.core.experimental import Stream
from nvshmem.core.interop.torch import tensor_get_buffer

Unused import tensor_get_buffer

tensor_get_buffer is imported at line 14 but is not referenced anywhere else in the module. This creates an unnecessary dependency and will trigger linting warnings.

Remove this unused import:

Suggested change
from nvshmem.core.interop.torch import tensor_get_buffer
# Removed: from nvshmem.core.interop.torch import tensor_get_buffer

@sbhavani sbhavani added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label Mar 5, 2026
Comment on lines +105 to +148
def torchrun_uid_init_bcast_object_no_reinit(cp_group=None):
    local_rank = torch.cuda.current_device()
    dev = Device(local_rank)
    dev.set_current()

    if cp_group is None:
        rank_id = dist.get_rank()
        num_ranks = dist.get_world_size()
    else:
        rank_id = dist.get_rank(group=cp_group)
        num_ranks = dist.get_world_size(group=cp_group)

    uniqueid = nvshmem.get_unique_id(empty=True)

    if rank_id == 0:
        uniqueid = nvshmem.get_unique_id()
        broadcast_objects = [uniqueid]
    else:
        broadcast_objects = [None]

    dist.broadcast_object_list(
        broadcast_objects,
        src=0,
        group=cp_group
    )

    dist.barrier(group=cp_group)

    nvshmem.init(
        device=dev,
        uid=broadcast_objects[0],
        rank=rank_id,
        nranks=num_ranks,
        initializer_method="uid"
    )

    return True

def torchrun_uid_init_bcast_object_no_reinit(cp_group=None):
    # 设备已经在外面 set 过最好 ("best if the device has already been set outside")
    local_rank = torch.cuda.current_device()
    dev = Device(local_rank)
    dev.set_current()


Duplicate function definition silently discards first implementation

torchrun_uid_init_bcast_object_no_reinit is defined twice in this file (once at ~line 105 and again at ~line 143). In Python, the second definition completely replaces the first at module load time without any error or warning. The second definition also contains Chinese-language inline comments (# 设备已经在外面 set 过最好, "best if the device has already been set outside", and # 不要再 init_process_group !!!, "do not call init_process_group again!!!"), strongly suggesting it is a copy-paste development artifact that was never cleaned up.

Remove the first definition and consolidate into a single, clean function — or if the two definitions are intentionally different, unify them.

Comment on lines +2089 to +2130
        if nvshmem_kv is not None:
            # owner of the next KV block
            owner_idx = (rank - (i + 1)) % cp_size
            owner_global = cp_global_ranks[owner_idx * cp_size_a2a + rank_a2a]
            tex.nvshmem_get_on_current_stream(recv_tensor[0], nvshmem_kv, int(owner_global))
            send_recv_reqs = []
        else:
            send_recv_reqs = flash_attn_p2p_communicate(
                rank,
                send_tensor[0],
                send_dst,
                recv_tensor[0],
                recv_src,
                ctx.cp_group,
                batch_p2p_comm,
            )
    else:
        dkv_a2a_req = torch.distributed.all_to_all_single(
            dkv_fp8,
            dkv_fp8_,
            group=ctx.cp_group,
            async_op=True,
        )
        send_recv_reqs = [dkv_a2a_req]
else:
    if i == 0:
        send_tensor = send_tensor[0]
        recv_tensor = recv_tensor[0]
    if i == (cp_size - 1):
        send_tensor = send_tensor[1]
        recv_tensor = recv_tensor[1]
    if nvshmem_kv is not None:
        owner_idx = (rank - (i + 1)) % cp_size
        owner_global = cp_global_ranks[owner_idx * cp_size_a2a + rank_a2a]
        tex.nvshmem_get_on_current_stream(recv_tensor, nvshmem_kv, int(owner_global))
        send_recv_reqs = []
    else:
        send_recv_reqs = flash_attn_p2p_communicate(
            rank, send_tensor, send_dst, recv_tensor, recv_src, ctx.cp_group, batch_p2p_comm
        )

kv = p2p_comm_buffers[i % 2][0]

nvshmem_kv used in backward but never defined there — NameError at runtime

In the backward method, nvshmem_kv is referenced at multiple points (if nvshmem_kv is not None:, tex.nvshmem_get_on_current_stream(recv_tensor[0], nvshmem_kv, ...)), but it is never assigned within the backward function.

nvshmem_kv was allocated in the forward pass and then explicitly freed at the end of forward via nvshmem.free_tensor(nvshmem_kv). It is not saved to ctx at all. Accessing an unresolved name will raise NameError: name 'nvshmem_kv' is not defined on the first backward call.

The fix depends on the intended design:

  • If NVSHMEM KV buffers should be re-allocated in the backward pass, add the allocation at the start of backward.
  • If the backward should fall back to P2P when no NVSHMEM buffer is available, initialise it to None at the top of backward:
nvshmem_kv = None  # backward allocates its own buffer or falls back to P2P
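For reference, the usual way to carry such a handle from forward to backward is to stash it on ctx; below is a minimal sketch (our illustration, with a plain CPU tensor standing in for the NVSHMEM symmetric buffer):

```python
import torch

class ScaleBy(torch.autograd.Function):
    """Toy autograd function showing the ctx save/restore pattern."""

    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale          # non-tensor state lives as a ctx attribute
        ctx.save_for_backward(x)   # tensors go through save_for_backward
        return x * scale

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors   # restored here, so no NameError in backward
        return grad_out * ctx.scale, None

x = torch.ones(3, requires_grad=True)
ScaleBy.apply(x, 2.0).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])
```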

nvshmem_dir = os.getenv("NVSHMEM_HOME") or metadata.distribution(
    "nvidia-nvshmem-cu12"
).locate_file("nvidia/nvshmem")
cmake_flags.append(f"-DNVSHMEM_DIR={nvshmem_dir}")

Debug print left in build script

This print("CMAKE_FLAGS:", cmake_flags[-2:]) is a debug statement that will be emitted on every build where NVTE_WITH_CUBLASMP=1. It should be removed before merging to keep build output clean.

Suggested change
cmake_flags.append(f"-DNVSHMEM_DIR={nvshmem_dir}")

Comment on lines +431 to +434
NVTE_CHECK_CUBLASMP(
    std::apply(cublasMpMatmul,
               std::tuple_cat(args, std::tuple{ctx->workspace, ctx->workspace_size,
                                               workspace_host.data(), workspace_host.size()})));

nvshmem_malloc return value not checked — silent null-pointer dereference

nvshmem_malloc can return nullptr on allocation failure (e.g. symmetric heap exhausted), but the result is stored without any null check. ctx->workspace is then passed directly to cublasMpMatmul, which will dereference the null pointer and produce undefined behaviour or a hard crash.

Add a null check after allocation:

  if (ctx->workspace_size < wrksp_size_device) {
    nvshmem_free(ctx->workspace);
    ctx->workspace = nvshmem_malloc(wrksp_size_device);
    NVTE_CHECK(ctx->workspace != nullptr,
               "nvshmem_malloc failed to allocate workspace of size ", wrksp_size_device);
    ctx->workspace_size = wrksp_size_device;
  }

