Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions src/maxtext/layers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,10 +896,14 @@ def sparse_matmul(
def gmm(
inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes, input_buffer_count, combine_scopes
):
tokamax_group_sizes = tokamax.RaggedDotGroupSizes(
group_sizes,
max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]),
)
# TODO (b/491979205) pipeline fsdp ag per repeat fails tokamax gmm
if self.config.using_pipeline_parallelism and self.config.pipeline_fsdp_ag_per_repeat:
tokamax_group_sizes = group_sizes
else:
tokamax_group_sizes = tokamax.RaggedDotGroupSizes(
group_sizes,
max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]),
)
pad_length = self.config.wi_tile_fwd_batch_seq
hs_shape = inputs.shape
# pad length is the 1st dimension of tiling size in gmm call
Expand Down
Loading