From 4bc91e00c44bf46f02749fc6209fba083514c6dc Mon Sep 17 00:00:00 2001 From: NuojCheng Date: Thu, 12 Mar 2026 04:00:24 +0000 Subject: [PATCH] update tokamax group sizes for pipeline --- src/maxtext/layers/moe.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 3a54a2a900..56d98204ea 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -896,10 +896,14 @@ def sparse_matmul( def gmm( inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes, input_buffer_count, combine_scopes ): - tokamax_group_sizes = tokamax.RaggedDotGroupSizes( - group_sizes, - max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]), - ) + # TODO (b/491979205) pipeline fsdp ag per repeat fails tokamax gmm + if self.config.using_pipeline_parallelism and self.config.pipeline_fsdp_ag_per_repeat: + tokamax_group_sizes = group_sizes + else: + tokamax_group_sizes = tokamax.RaggedDotGroupSizes( + group_sizes, + max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]), + ) pad_length = self.config.wi_tile_fwd_batch_seq hs_shape = inputs.shape # pad length is the 1st dimension of tiling size in gmm call