From 3837b9791d9b453bb59fa7315412fee4254b8eeb Mon Sep 17 00:00:00 2001 From: maxtext authors Date: Wed, 11 Mar 2026 15:19:04 -0700 Subject: [PATCH] /s/representative_value/representative_value_or_total_size. PiperOrigin-RevId: 882223023 --- src/maxtext/layers/moe.py | 2 +- src/maxtext/models/deepseek_batchsplit.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 70e67490f5..a7694a326e 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -898,7 +898,7 @@ def gmm( ): tokamax_group_sizes = tokamax.RaggedDotGroupSizes( group_sizes, - representative_value=max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]), + representative_value_or_total_size=max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]), ) pad_length = self.config.wi_tile_fwd_batch_seq hs_shape = inputs.shape diff --git a/src/maxtext/models/deepseek_batchsplit.py b/src/maxtext/models/deepseek_batchsplit.py index 7cbbe17063..d585b4035e 100644 --- a/src/maxtext/models/deepseek_batchsplit.py +++ b/src/maxtext/models/deepseek_batchsplit.py @@ -807,7 +807,7 @@ def gmm( tokamax_group_sizes = tokamax.RaggedDotGroupSizes( group_sizes, - representative_value=max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]), + representative_value_or_total_size=max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]), ) if config.use_qwix_quantization: output = megablox.gmm(