
feat: implement Group.split() for JACCL and Ring backends#3245

Open
machiabeli wants to merge 1 commit into ml-explore:main from machiabeli:feat/group-split-jaccl-ring

Conversation

@machiabeli

Summary

Implements Group.split() (MPI_Comm_split semantics) for the JACCL mesh, JACCL ring, and TCP ring distributed backends. This enables mixed parallelism strategies — such as combining tensor parallelism with pipeline parallelism — on Apple Silicon clusters connected via Thunderbolt 5 RDMA.

Addresses #3205

Problem

`Group.split()` was unimplemented for all Apple Silicon distributed backends (each threw `std::runtime_error("Group split not supported")`), blocking any use of sub-group collectives. This prevented:

  • Mixed tensor + pipeline parallelism
  • Expert parallelism for MoE models
  • Any workflow requiring communicator subdivision (standard in MPI/NCCL)
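To make the first use case concrete, here is a small sketch of how ranks might derive `split()` colors for mixed tensor + pipeline parallelism. The function name and shapes are illustrative, not part of the MLX API; each rank would then call something like `group.split(tp_color)` and `group.split(pp_color)` to obtain the two sub-groups.

```python
# Hypothetical sketch: deriving split() colors for mixed tensor + pipeline
# parallelism (names are illustrative, not the MLX API).

def mixed_parallel_colors(rank, world_size, tp_size):
    """Assign each rank a tensor-parallel and a pipeline-parallel color.

    Ranks sharing tp_color form one tensor-parallel sub-group; ranks
    sharing pp_color form one pipeline-parallel sub-group.
    """
    assert world_size % tp_size == 0
    tp_color = rank // tp_size   # consecutive ranks share a TP group
    pp_color = rank % tp_size    # strided ranks share a PP group
    return tp_color, pp_color

# 8 ranks, tensor-parallel groups of 4 -> two TP groups, four PP groups
colors = [mixed_parallel_colors(r, 8, 4) for r in range(8)]
print(colors)
```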

Key Design Decisions

1. TCPGroup Fallback for JACCL Sub-Groups

Apple's Thunderbolt 5 RDMA driver does not support multiple ibv_context instances on the same physical RDMA device simultaneously. If a sub-group tries to open new RDMA connections on devices already held by the parent group, the driver deadlocks.

Solution: Sub-groups created via split() on JACCL backends use TCPGroup — a new group implementation that performs all collective operations over TCP (via SideChannel). This is slower than RDMA but:

  • Avoids the ibv_context concurrency deadlock entirely
  • Keeps RDMA reserved for the parent group's high-bandwidth tensor parallelism
  • Matches how MPI implementations handle transport fallback (e.g., MPICH's ch3:sock)
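The transport decision above can be modeled as follows. This is a pure-Python illustration of the policy, not the actual C++ code, and the backend names are placeholders:

```python
# Illustrative model of the sub-group transport decision (not the real
# C++ implementation; backend names are placeholders).

def pick_subgroup_backend(sub_group_size, parent_uses_rdma):
    if sub_group_size == 1:
        # A rank alone in its color group needs no transport (LocalGroup).
        return "local"
    if parent_uses_rdma:
        # Opening a second ibv_context on the same device deadlocks the
        # Thunderbolt 5 RDMA driver, so fall back to TCP (TCPGroup).
        return "tcp"
    # A pure-TCP ring parent can create sub-rings directly.
    return "ring"

print(pick_subgroup_backend(1, True))   # local
print(pick_subgroup_backend(4, True))   # tcp
print(pick_subgroup_backend(2, False))  # ring
```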

2. Strict Collective Synchronization

All SideChannel::all_gather calls in split() execute before any rank makes a branching decision (e.g., returning a LocalGroup for size-1 sub-groups). This matches MPI_Comm_split's collective semantics — the MPICH implementation internally uses MPI_Allgather with the same constraint.
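The gather-then-branch ordering can be sketched as a pure-Python model of MPI_Comm_split's semantics: every rank first learns all `(color, key)` pairs (the all-gather step), and only afterwards computes its sub-group membership and new rank, sorted by `(key, old_rank)` within each color. This models the semantics only, not the SideChannel implementation:

```python
# Pure-Python model of MPI_Comm_split rank assignment (semantics only,
# not the SideChannel implementation).

def comm_split(colors_and_keys):
    """colors_and_keys[rank] = (color, key), as all-gathered by every rank."""
    world = list(enumerate(colors_and_keys))  # (old_rank, (color, key))
    result = {}
    for old_rank, (color, key) in world:
        # Membership is decided only after the full gather is available.
        members = sorted((k, r) for r, (c, k) in world if c == color)
        new_rank = members.index((key, old_rank))
        result[old_rank] = (color, new_rank, len(members))
    return result

# Even/odd split of 4 ranks with key = rank (preserves relative order):
print(comm_split([(r % 2, r) for r in range(4)]))
```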

3. LocalGroup for Size-1 Sub-Groups

Ranks that end up alone in their color group get a lightweight no-op implementation (identity for reductions, memcpy for gather). No RDMA or TCP connections opened.
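The behavior of such a no-op group amounts to the following (an illustrative Python sketch; the real `LocalGroup` is C++):

```python
# Minimal sketch of a no-op group for size-1 sub-groups (illustrative;
# the real LocalGroup is implemented in C++).

import copy

class LocalGroup:
    def rank(self):
        return 0

    def size(self):
        return 1

    def all_sum(self, x):
        return x               # a reduction over one rank is the identity

    def all_gather(self, x):
        return [copy.copy(x)]  # gathering from one rank is a copy

g = LocalGroup()
print(g.all_sum(41))         # 41
print(g.all_gather([1, 2]))  # [[1, 2]]
```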

4. Coordinator Address Derivation

Sub-group coordinators reuse the parent's coordinator host with port offsets based on the color parameter: `parent_port + 1000 + color`. This ensures no port collisions between concurrent sub-groups.

Changes

| File | What |
| --- | --- |
| `mlx/distributed/jaccl/mesh.h` | Store `device_names_`, `coordinator_host_`, `coordinator_port_` for `split()` |
| `mlx/distributed/jaccl/mesh.cpp` | `MeshGroup::split()` + `LocalGroup` class |
| `mlx/distributed/jaccl/ring.h` | Store `all_devices_`, coordinator info for `split()` |
| `mlx/distributed/jaccl/ring.cpp` | `RingGroup::split()` + `TCPGroup` (full collective ops over TCP) + `RingLocalGroup` |
| `mlx/distributed/ring/ring.cpp` | `RingGroup::split()` using ring all-gather + port-offset sub-ring creation |
| `tests/test_jaccl_split.py` | 5 tests: same-color, two-group, three-group, key ordering, send/recv |

Testing

Tested on a 4-node Apple Silicon cluster:

  • 3× Mac Studio M3 Ultra (512GB, 256GB, 256GB)
  • 1× Mac Studio M4 Max (128GB)
  • Connected via Thunderbolt 5 with RDMA (jaccl-ring backend)

```
mlx.launch --hostfile /tmp/hostfile-jaccl-ring-4node.json tests/test_jaccl_split.py
```

All 5 tests pass: same-color split, even/odd two-group split, three-group split with size-1 sub-groups, key-reversed ordering, and send/recv on sub-groups (gracefully skipped for TCPGroup which doesn't support point-to-point).

Hardware Context

This was developed and tested against Apple's first-generation Thunderbolt 5 RDMA driver (macOS 26.x, infiniband/verbs.h from Xcode SDK). The ibv_context concurrency limitation is specific to this driver — standard InfiniBand hardware (Mellanox ConnectX) supports multiple contexts via SR-IOV. The TCPGroup fallback is designed to be replaced with direct RDMA sub-group connections if/when Apple adds multi-context support.

Adds `Group.split()` (MPI_Comm_split semantics) to the JACCL mesh, JACCL
ring, and TCP ring distributed backends — enabling mixed parallelism
strategies (e.g. tensor parallelism + pipeline parallelism) on Apple
Silicon clusters.

Key design decisions:

1. **TCPGroup fallback for JACCL sub-groups**: Apple's Thunderbolt 5 RDMA
   driver does not support multiple `ibv_context` instances on the same
   physical device simultaneously. Sub-groups derived from an RDMA parent
   would deadlock if they tried to open new RDMA connections. Solution:
   sub-groups use TCP (via SideChannel) for collective operations, keeping
   RDMA reserved for the parent group's high-bandwidth tensor parallelism.

2. **SideChannel synchronization**: All ranks must call the same sequence
   of `SideChannel::all_gather` operations before any rank branches or
   returns early. This matches MPI_Comm_split's collective semantics and
   prevents deadlocks from asymmetric participation.

3. **LocalGroup for size-1 sub-groups**: Ranks that end up alone in their
   color group get a lightweight no-op implementation that avoids both
   RDMA and TCP overhead.

4. **Coordinator derivation**: Sub-group coordinators are derived from the
   parent's coordinator address with port offsets based on the color
   parameter, ensuring no port collisions between sub-groups.

Backends modified:
- `jaccl/mesh.cpp`: MeshGroup::split() with RDMA-aware sub-group creation
- `jaccl/ring.cpp`: RingGroup::split() + TCPGroup (full collective ops
  over TCP) + RingLocalGroup
- `ring/ring.cpp`: RingGroup::split() using ring all-gather for info
  exchange and port-offset addressing for sub-ring creation

Tested on a 4-node Apple Silicon cluster (3x M3 Ultra + 1x M4 Max)
connected via Thunderbolt 5 RDMA using the jaccl-ring backend.

Closes ml-explore#3205