diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 70e67490f5..a7694a326e 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -898,7 +898,7 @@ def gmm( ): tokamax_group_sizes = tokamax.RaggedDotGroupSizes( group_sizes, - representative_value=max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]), + representative_value_or_total_size=max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]), ) pad_length = self.config.wi_tile_fwd_batch_seq hs_shape = inputs.shape diff --git a/src/maxtext/models/deepseek_batchsplit.py b/src/maxtext/models/deepseek_batchsplit.py index 7cbbe17063..d585b4035e 100644 --- a/src/maxtext/models/deepseek_batchsplit.py +++ b/src/maxtext/models/deepseek_batchsplit.py @@ -807,7 +807,7 @@ def gmm( tokamax_group_sizes = tokamax.RaggedDotGroupSizes( group_sizes, - representative_value=max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]), + representative_value_or_total_size=max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]), ) if config.use_qwix_quantization: output = megablox.gmm(