From c90eb8e47758560309b50bded5b411e4387456a8 Mon Sep 17 00:00:00 2001 From: cj401-amd Date: Sat, 16 May 2026 04:24:19 +0800 Subject: [PATCH] update for tmem clean --- src/maxtext/configs/base.yml | 2 +- src/maxtext/configs/types.py | 4 +- src/maxtext/kernels/gather_reduce_sc.py | 3 +- src/maxtext/layers/attention_op.py | 25 +-- src/maxtext/layers/attentions.py | 1 + src/maxtext/layers/embeddings.py | 37 +++-- src/maxtext/layers/moe.py | 18 +-- src/maxtext/layers/normalizations.py | 7 +- src/maxtext/layers/pipeline.py | 71 +++------ src/maxtext/models/deepseek.py | 19 ++- src/maxtext/models/mixtral.py | 200 +++++++++++------------- src/maxtext/trainers/pre_train/train.py | 14 +- src/maxtext/utils/sharding.py | 13 +- 13 files changed, 194 insertions(+), 220 deletions(-) diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index ecf03133cc..c26bd8afaf 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -185,7 +185,7 @@ logits_dot_in_fp32: false # whether to use fp32 in logits_dense or shared_embed cast_logits_to_fp32: true # whether to cast the logits to fp32. the higher precision is generally beneficial, but it can vary slightly. float32_qk_product: false # in dot_product attention, whether to cast to fp32 the inputs to qk product float32_logits: false # in dot_product attention, whether to cast to fp32 the inputs to softmax -float32_weight_sum: true # whether to use full fp32 precision to sum expert weights for numerical stability +float32_weight_sum: false # whether to use fp32 for MoE expert weight summation; true adds ~2 GB f32 temp per device float32_gate_logits: false # whether to cast inputs to fp32 to compute MoE gate logits for numerical stability # multi-token prediction configs diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index ec848e6207..da18a620bc 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -712,8 +712,8 @@ class MoEGeneral(BaseModel): description="Enable top-k probability normalization for router weights (Qwen3-specific).", ) float32_weight_sum: bool = Field( - True, - description="Whether to use full fp32 precision to sum expert weights for numerical stability.", + False, + description="Whether to use fp32 for MoE expert weight summation; true adds ~2 GB f32 temp per device.", ) float32_gate_logits: bool = Field( False, diff --git a/src/maxtext/kernels/gather_reduce_sc.py b/src/maxtext/kernels/gather_reduce_sc.py index 8805be1430..393f554423 100644 --- a/src/maxtext/kernels/gather_reduce_sc.py +++ b/src/maxtext/kernels/gather_reduce_sc.py @@ -55,7 +55,8 @@ def __getitem__(self, shape): _BF16 = VectorTypeHelper(ir.BF16Type.get) -@jax.jit( +@functools.partial( + jax.jit, static_argnames=[ "reduce_group_size", "single_sc", diff --git a/src/maxtext/layers/attention_op.py b/src/maxtext/layers/attention_op.py index e72cfe9134..4bd712e6c3 100644 --- a/src/maxtext/layers/attention_op.py +++ b/src/maxtext/layers/attention_op.py @@ -1580,13 +1580,22 @@ def cudnn_flash_attention( dummy_attn_mask = None mask_type = "causal" else: - # Default case: no packing, no context parallelism - dummy_attn_mask = jnp.zeros( - (1, 1, 1, self.max_target_length, self.max_target_length), - dtype=jnp.uint8, - ) - attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode) - attn_mask = jnp.where((attn_mask >= DEFAULT_MASK_VALUE * 0.5), 0, 1).astype(jnp.uint8) + # Default case: no packing, no context parallelism. + # For synthetic data, segment IDs are always all-ones (one segment per sequence), so + # the segment mask is all-True and the combined mask reduces to pure causal masking. + # Use mask_type="causal" directly to avoid materializing f32/s32[seq,seq] tensors that + # XLA loop_broadcast_fusion hoists into the pipeline scan carry (+5 GiB temp memory). + if self.config.dataset_type == "synthetic": + attn_mask = None + dummy_attn_mask = None + mask_type = "causal" + else: + dummy_attn_mask = jnp.zeros( + (1, 1, 1, self.max_target_length, self.max_target_length), + dtype=jnp.uint8, + ) + attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode) + attn_mask = jnp.where((attn_mask >= DEFAULT_MASK_VALUE * 0.5), 0, 1).astype(jnp.uint8) dpa_layer = DotProductAttention( head_dim=head_dim, @@ -1599,12 +1608,10 @@ def cudnn_flash_attention( dtype=self.dtype, float32_logits=self.float32_logits, qkv_layout=qkv_layout, - scale_factor=1.0, transpose_batch_sequence=False, window_size=sliding_window_size, context_parallel_causal_load_balanced=self.config.context_parallel_load_balance, context_parallel_axis=self.config.context_sharding, - context_parallel_strategy=self.config.context_parallel_strategy, max_segments_per_seq=max_segments_per_seq, ) diff --git a/src/maxtext/layers/attentions.py b/src/maxtext/layers/attentions.py index e53de0973a..dc3c593eeb 100644 --- a/src/maxtext/layers/attentions.py +++ b/src/maxtext/layers/attentions.py @@ -561,6 +561,7 @@ def __init__( mesh=mesh, shard_mode=config.shard_mode, debug_sharding=config.debug_sharding, + skip_trivial_specs=True, ) def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> None: diff --git a/src/maxtext/layers/embeddings.py b/src/maxtext/layers/embeddings.py index 525fff1ed5..36a24d119f 100644 --- a/src/maxtext/layers/embeddings.py +++ b/src/maxtext/layers/embeddings.py @@ -22,6 +22,7 @@ import jax.numpy as jnp from jax.sharding import Mesh, NamedSharding +from flax import linen as nn from flax import nnx from maxtext.common.common_types import ShardMode, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN, Array, Config, DType @@ -156,30 +157,34 @@ def __call__(self, inputs: Array, model_mode: str = MODEL_MODE_TRAIN) -> Array: self.dtype, ) - output_axis_names = ( - ( - "activation_embed_and_logits_batch", - "prefill_activation_length", - "activation_embed", - ) - if model_mode == MODEL_MODE_PREFILL - else ( - "activation_embed_and_logits_batch", - "activation_length", - "activation_embed", - ) - ) - out_pspec = logical_to_mesh_axes(output_axis_names, self.mesh, rules=getattr(self.config, "logical_axis_rules", None)) + output_prefill_axis_names = ("activation_embed_and_logits_batch", "prefill_activation_length", "activation_embed") + output_default_axis_names = ("activation_embed_and_logits_batch", "activation_length", "activation_embed") + + if self.config.shard_mode == ShardMode.EXPLICIT: + output_axis_names = output_prefill_axis_names if model_mode == MODEL_MODE_PREFILL else output_default_axis_names + out_pspec = logical_to_mesh_axes(output_axis_names, self.mesh, rules=getattr(self.config, "logical_axis_rules", None)) + out_sharding = NamedSharding(self.mesh, out_pspec) + else: + out_sharding = None - out_sharding = NamedSharding(self.mesh, out_pspec) if self.config.shard_mode == ShardMode.EXPLICIT else None + one_hot_elements = 1 + for d in inputs.shape: + one_hot_elements *= d + one_hot_elements *= self.num_embeddings + one_hot_bytes = one_hot_elements * jnp.dtype(self.dtype).itemsize + use_iota = cfg.use_iota_embed and one_hot_bytes <= 2 * 1024**3 - if cfg.use_iota_embed: + if use_iota: iota = lax.iota(jnp.int32, self.num_embeddings) one_hot = jnp.array(inputs[..., jnp.newaxis] == iota, dtype=self.dtype) output = jnp.dot(one_hot, embedding, out_sharding=out_sharding) else: output = embedding.at[inputs].get(out_sharding=out_sharding) + if model_mode == MODEL_MODE_PREFILL: + output = nn.with_logical_constraint(output, output_prefill_axis_names) + else: + output = nn.with_logical_constraint(output, output_default_axis_names) return output def attend(self, query: Array, out_sharding: NamedSharding | None = None) -> Array: diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index e23c3eba9f..4c65efa954 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -705,7 +705,7 @@ def apply_ffn_activation(self, layer_w0, layer_w1): else: layer_act = self.activation_fn(layer_w0) intermediate_layer = jnp.multiply(layer_act, layer_w1) - return intermediate_layer.astype(self.dtype) + return intermediate_layer def permute(self, inputs, gate_logits, pre_bias_logits, use_custom_sort_vjp=True, rngs=None, roll_to_expert_id=None): """Permute tokens to group by expert to fit gmm call.""" @@ -1389,27 +1389,17 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): group_sizes=group_sizes, expert_assignments=selected_experts, ) + # Forward-only 3-tuple tiling. The 9-tuple form includes dlhs/drhs backward-pass + # tile values that are allocated but unused when megablox=False (JAX ragged_dot path). wi_tile_size = ( self.config.wi_tile_fwd_batch_seq, # m (LHS batch) - self.config.wi_tile_fwd_embed_dim, # k (contracting) + self.config.wi_tile_fwd_embed_dim, # k (contracting) self.config.wi_tile_fwd_mlp_dim, # n (RHS batch) - self.config.wi_tile_dlhs_batch_seq, # m (LHS batch) - self.config.wi_tile_dlhs_mlp_dim, # k (contracting) - self.config.wi_tile_dlhs_embed_dim, # n (RHS batch) - self.config.wi_tile_drhs_batch_seq, # Called m in megablox, but this is contracting - self.config.wi_tile_drhs_embed_dim, # Called k in megablox, but this is LHS batch dim - self.config.wi_tile_drhs_mlp_dim, # Called n in megablox, and indeed is RHS batch dim ) wo_tile_size = ( self.config.wo_tile_fwd_batch_seq, # m (LHS batch) self.config.wo_tile_fwd_mlp_dim, # k (contracting) self.config.wo_tile_fwd_embed_dim, # n (RHS batch) - self.config.wo_tile_dlhs_batch_seq, # m (LHS batch) - self.config.wo_tile_dlhs_embed_dim, # k (contracting) - self.config.wo_tile_dlhs_mlp_dim, # n (RHS) - self.config.wo_tile_drhs_batch_seq, # Called m in megablox, but this is contracting - self.config.wo_tile_drhs_mlp_dim, # Called k in megablox, but this is LHS batch dim - self.config.wo_tile_drhs_embed_dim, # Called n in megablox, and indeed is the RHS batch dim ) layer_w0 = gmm_fn( diff --git a/src/maxtext/layers/normalizations.py b/src/maxtext/layers/normalizations.py index bf91262bf1..483d234f47 100644 --- a/src/maxtext/layers/normalizations.py +++ b/src/maxtext/layers/normalizations.py @@ -88,8 +88,11 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) -> scale = jax.device_put(scale, max_utils.device_space()) scale = jnp.asarray(scale, self.dtype) - effective_scale = scale + self.scale_offset - return jnp.einsum("...k,k->...k", y, effective_scale, out_sharding=out_sharding) + effective_scale = scale + self.scale_offset if self.scale_offset != 0.0 else scale + y = y * effective_scale + if out_sharding is not None: + y = jax.lax.with_sharding_constraint(y, out_sharding) + return y class GlobalRMSNorm(RMSNorm): diff --git a/src/maxtext/layers/pipeline.py b/src/maxtext/layers/pipeline.py index 62ea52782b..e3b3acbbc8 100644 --- a/src/maxtext/layers/pipeline.py +++ b/src/maxtext/layers/pipeline.py @@ -118,6 +118,7 @@ def _maybe_shard_with_logical(self, inputs, logical_axes): rules=self.config.logical_axis_rules, debug_sharding=self.config.debug_sharding, extra_stack_level=1, + skip_trivial_specs=True, ) def _maybe_shard_with_name(self, inputs, sharding_name): @@ -139,7 +140,6 @@ def get_iteration_inputs(self, loop_iteration, state_io, circ_storage, shift): # Setup potential input from state_io, which has a rotating microbatch index (size of microbatches_per_stage) state_io_batch_idx = loop_iteration % self.microbatches_per_stage state_io_slice = state_io[:, state_io_batch_idx] - shift = self._maybe_shard_with_logical(shift, self.stages_in_logical) if self.use_circ_storage: # Setup potential input from circ_storage, which also has a rotating index for microbatch, @@ -154,7 +154,6 @@ def get_iteration_inputs(self, loop_iteration, state_io, circ_storage, shift): # state_io we instead grab from the last stage's output (possibly buffered when num_microbatches > num_stages, e.g. # from circ_storage). first_stage_in = jnp.where(loop_iteration < self.config.num_pipeline_microbatches, state_io_slice, circular_stage_in) - first_stage_in = self._maybe_shard_with_logical(first_stage_in, self.stages_in_logical) # Note that first_stage_in may correspond to bubble computation during the last few iterations. # However, these bubble computation results remain in the shift buffer (do not make it back to state_io) and are @@ -164,11 +163,7 @@ def get_iteration_inputs(self, loop_iteration, state_io, circ_storage, shift): def select_state_or_input(first_stage_in, shift): # Selects input for stage 0, shift for other stages - return jnp.where( - jax.lax.broadcasted_iota("int32", shift.shape, 0, out_sharding=self.stages_in_sharding) == 0, - first_stage_in, - shift, - ) + return jnp.where(jax.lax.broadcasted_iota("int32", shift.shape, 0) == 0, first_stage_in, shift) # Selects input (from stream_io) for stage 0, other stages get from shift (the rotated previous output) stages_in = select_state_or_input(first_stage_in, shift) @@ -180,7 +175,6 @@ def get_microbatch_and_repeat_ids(self, loop_iteration): non-circular""" # Stage 0 has processed one microbatch every loop_iter, but Stage 1 is 1 behind due to bubble, etc for other stages microbatches_processed = jnp.maximum(loop_iteration - self.forwarding_delay * jnp.arange(self.num_stages), 0) - microbatches_processed = self._maybe_shard_with_name(microbatches_processed, NamedSharding(self.mesh, P("stage"))) microbatch_ids = microbatches_processed % self.config.num_pipeline_microbatches repeat_ids = microbatches_processed // self.config.num_pipeline_microbatches return microbatch_ids, repeat_ids @@ -190,7 +184,7 @@ def get_pipeline_remat_policy(self): if self.config.remat_policy == "custom": return self.remat_policy - save_input_policy = jax.checkpoint_policies.save_only_these_names("iteration_input", "decoder_layer_input") + save_input_policy = jax.checkpoint_policies.save_only_these_names("iteration_input") if self.remat_policy is not None: remat_policy = jax.checkpoint_policies.save_from_both_policies(self.remat_policy, save_input_policy) else: @@ -247,16 +241,6 @@ def get_main_vmap_func_for_iterations(self): def func_to_vmap( body_instance, weights, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode ): - weights = meta.remove_axis( - weights, - 0, - { - nn.PARTITION_NAME: "layers", - "sub_weight_split_dims_mapping": (None,), - "is_initializing": self.is_initializing(), - "x_times": self.num_stages, - }, - ) return body_instance.apply(weights, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode) vmap_func = nn.vmap( @@ -490,9 +474,7 @@ def vmap_gather(self, xs, ids, ids_dim): """ def _gather_one(x, i): - idx = tuple(i if d == ids_dim else slice(None) for d in range(x.ndim)) - replicated_sharding = NamedSharding(self.mesh, P()) - return x.at[idx].get(out_sharding=replicated_sharding) + return jnp.squeeze(jax.lax.dynamic_slice_in_dim(x, i, 1, ids_dim), ids_dim) ids = self.shard_dim_by_stages(ids, 0, physical_partition_spec=None) outs = jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids) @@ -516,21 +498,16 @@ def get_new_loop_state(self, output, loop_state): loop_iteration = loop_state["loop_iteration"] old_prev_outputs = loop_state["prev_outputs"] - @jax.shard_map(mesh=self.mesh, in_specs=self.stages_in_spec, out_specs=self.stages_in_spec, check_vma=True) def _rotate_right(arr): - # we use +1 for right shifting - stage_size = jax.lax.axis_size("stage") - perm = [(i, (i + 1) % stage_size) for i in range(stage_size)] - arr = jax.lax.ppermute(arr, axis_name="stage", perm=perm) - return arr + # Use lax.slice to avoid generating a gather. + last = jax.lax.slice_in_dim(arr, self.num_stages - 1, self.num_stages, axis=0) + except_last = jax.lax.slice_in_dim(arr, 0, self.num_stages - 1, axis=0) + return jnp.concatenate([last, except_last], axis=0) - @jax.shard_map(mesh=self.mesh, in_specs=self.stages_in_spec, out_specs=self.stages_in_spec, check_vma=True) def _shift_right(arr): - stage_idx = jax.lax.axis_index("stage") - stage_size = jax.lax.axis_size("stage") - perm = [(i, (i + 1) % stage_size) for i in range(stage_size)] - arr = jax.lax.ppermute(arr, axis_name="stage", perm=perm) - return jnp.where(stage_idx == 0, jnp.zeros_like(arr), arr) + padding = [[1, 0]] + [[0, 0]] * (arr.ndim - 1) + # Use lax.slice to guarantee the gradient is a pad. + return jax.lax.slice(jnp.pad(arr, padding), [0] * arr.ndim, arr.shape) # Shift either rotates or shifts depending on if the last stage immediately must send to first or not # For non-circular pipelines, the last stage does not need to send to first @@ -574,29 +551,17 @@ def _rotate_right_and_update(circ_storage_mover_in, circ_storage_in): stream_buf_idx = loop_iteration % self.microbatches_per_stage stream_slice = old_state_io[:, stream_buf_idx] - def _rotate_left(arr, stage_size): - # we use -1 for left shifting - perm = [(i, (i - 1) % stage_size) for i in range(stage_size)] - return jax.lax.ppermute(arr, axis_name="stage", perm=perm) - - def _shift_left(arr, stage_size, output): - stage_idx = jax.lax.axis_index("stage") - arr = _rotate_left(arr, stage_size) - return jnp.where(stage_idx == stage_size - 1, output, arr) - - @jax.shard_map( - mesh=self.mesh, - in_specs=(self.state_io_spec, self.stages_in_spec, self.stages_in_spec, P()), - out_specs=self.state_io_spec, - ) - def _update_state_io(state_in, stream_slice, output, stream_buf_idx): + def _update_state_io(state_in, stream_slice, output): # Shift the current slice to the left, then fill the last stage with the final output. - stage_size = jax.lax.axis_size("stage") - stream_slice = _shift_left(stream_slice, stage_size, output) + padding = [[0, 1]] + [[0, 0]] * (stream_slice.ndim - 1) + stream_slice = jax.lax.slice_in_dim(jnp.pad(stream_slice, padding), 1, stream_slice.shape[0] + 1, axis=0) + stream_slice = jnp.where( + jax.lax.broadcasted_iota("int32", stream_slice.shape, 0) == self.num_stages - 1, output, stream_slice + ) stream_slice = jnp.expand_dims(stream_slice, 1) return jax.lax.dynamic_update_slice_in_dim(state_in, stream_slice, stream_buf_idx, axis=1) - new_state = _update_state_io(old_state_io, stream_slice, output, stream_buf_idx) + new_state = _update_state_io(old_state_io, stream_slice, output) new_loop_state = { "state_io": new_state, diff --git a/src/maxtext/models/deepseek.py b/src/maxtext/models/deepseek.py index 0980b78599..07f67b8561 100644 --- a/src/maxtext/models/deepseek.py +++ b/src/maxtext/models/deepseek.py @@ -43,6 +43,7 @@ from maxtext.utils import max_utils from maxtext.utils.sharding import create_sharding from maxtext.utils.sharding import maybe_shard_with_logical +from maxtext.utils.sharding import remove_size_one_mesh_axis import transformers @@ -419,7 +420,7 @@ def __init__( self.DeepSeekMoeBlock_0 = moe.RoutedAndSharedMoE( config=self.config, mesh=mesh, - kernel_init=initializers.nd_dense_init(self.config.dense_init_scale, "fan_in", "truncated_normal"), + kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"), kernel_axes=("embed", None), dtype=self.config.dtype, weight_dtype=self.config.weight_dtype, @@ -492,15 +493,14 @@ def __call__( return outputs, None # bf16 and fp8 code path for pure-JAX batch-split. - # fp8 code path supports both manual quantization and qwix - # quantization. - input_sharding = jax.typeof(inputs).sharding - activation_pspec = jax.sharding.PartitionSpec( - ("data", "fsdp", "expert"), - None, - None, + activation_pspec = remove_size_one_mesh_axis( + jax.sharding.PartitionSpec( + ("data", "fsdp", "fsdp_transpose", "expert", "context"), + None, + None, + ), + self.mesh, ) - inputs = jax.reshard(inputs, jax.sharding.NamedSharding(self.mesh, activation_pspec)) yarn_freqs = deepseek_batchsplit.initialize_yarn_freqs( decoder_positions, embedding_dims=self.config.qk_rope_head_dim, @@ -572,7 +572,6 @@ def extract_fn(x): in_specs=([activation_pspec] * self.config.batch_split_factor,), out_specs=activation_pspec, )(outputs) - outputs = jax.reshard(outputs, input_sharding) return outputs, None x = self.with_logical_constraint(inputs) diff --git a/src/maxtext/models/mixtral.py b/src/maxtext/models/mixtral.py index faf69273c6..03c1152197 100644 --- a/src/maxtext/models/mixtral.py +++ b/src/maxtext/models/mixtral.py @@ -18,111 +18,33 @@ from flax import linen as nn -from flax import nnx from jax.ad_checkpoint import checkpoint_name import jax.numpy as jnp from jax.sharding import Mesh from maxtext.common.common_types import Config -from maxtext.layers import initializers, nnx_wrappers +from maxtext.layers import initializers from maxtext.layers import moe from maxtext.layers import quantizations -from maxtext.layers.attentions import Attention -from maxtext.layers.linears import Dropout -from maxtext.layers.normalizations import RMSNorm +from maxtext.layers.attentions import attention_as_linen +from maxtext.layers.normalizations import rms_norm from maxtext.layers.quantizations import AqtQuantization as Quant from maxtext.utils import max_utils +from maxtext.utils.sharding import maybe_shard_with_logical # ----------------------------------------- # The Decoder Layer for Mixtral # ----------------------------------------- -class MixtralDecoderLayer(nnx.Module): +class MixtralDecoderLayer(nn.Module): """Transformer decoder layer that attends to the encoder.""" - @nn.compact - def __init__( - self, - config: Config, - mesh: Mesh, - model_mode: str, - quant: None | Quant = None, - *, - rngs: nnx.Rngs, - ): - self.config = config - self.mesh = mesh - self.model_mode = model_mode - self.quant = quant - self.rngs = rngs - - batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config, model_mode) - dummy_inputs_shape = (batch_size, seq_len, config.emb_dim) - - self.pre_self_attention_layer_norm = RMSNorm( - num_features=config.emb_dim, - dtype=config.dtype, - weight_dtype=config.weight_dtype, - kernel_axes=("norm",), - epsilon=config.normalization_layer_epsilon, - rngs=self.rngs, - ) - - self.self_attention = Attention( - config=config, - num_query_heads=config.num_query_heads, - num_kv_heads=config.num_kv_heads, - head_dim=config.head_dim, - max_target_length=config.max_target_length, - max_prefill_predict_length=config.max_prefill_predict_length, - attention_kernel=config.attention, - inputs_q_shape=dummy_inputs_shape, - inputs_kv_shape=dummy_inputs_shape, - mesh=mesh, - dtype=config.dtype, - weight_dtype=config.weight_dtype, - dropout_rate=config.dropout_rate, - float32_qk_product=config.float32_qk_product, - float32_logits=config.float32_logits, - quant=self.quant, - kv_quant=quantizations.configure_kv_quant(config), - prefill_cache_axis_order=tuple(map(int, config.prefill_cache_axis_order.split(","))), - ar_cache_axis_order=tuple(map(int, config.ar_cache_axis_order.split(","))), - compute_axis_order=tuple(map(int, config.compute_axis_order.split(","))), - reshape_q=config.reshape_q, - use_ragged_attention=config.use_ragged_attention, - ragged_block_size=config.ragged_block_size, - model_mode=model_mode, - rngs=self.rngs, - ) - - self.post_self_attention_layer_norm = RMSNorm( - num_features=config.emb_dim, - dtype=config.dtype, - weight_dtype=config.weight_dtype, - kernel_axes=("norm",), - epsilon=config.normalization_layer_epsilon, - rngs=self.rngs, - ) - - self.MoeBlock_0 = moe.RoutedMoE( - config=config, - num_experts=config.num_experts, - num_experts_per_tok=config.num_experts_per_tok, - mesh=mesh, - kernel_init=initializers.nd_dense_init(config.dense_init_scale, "fan_in", "truncated_normal"), - kernel_axes=("embed", None), - intermediate_dim=config.mlp_dim, - dtype=config.dtype, - weight_dtype=config.weight_dtype, - quant=self.quant, - rngs=self.rngs, - ) - - self.dropout = Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=rngs) - - self.activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed") + config: Config + mesh: Mesh + model_mode: str + quant: None | Quant = None + @nn.compact def __call__( self, inputs, @@ -139,13 +61,61 @@ def __call__( # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache)) if isinstance(inputs, tuple): inputs = inputs[0] - inputs = nn.with_logical_constraint(inputs, self.activation_axis_names) - inputs = checkpoint_name(inputs, "decoder_layer_input") - lnx = self.pre_self_attention_layer_norm(inputs) - lnx = nn.with_logical_constraint(lnx, self.activation_axis_names) + cfg = self.config + mesh = self.mesh + + activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed") - attention_lnx, kv_cache = self.self_attention( + def shard(x): + return maybe_shard_with_logical( + x, activation_axis_names, mesh=mesh, shard_mode=cfg.shard_mode, + rules=cfg.logical_axis_rules, skip_trivial_specs=True, + ) + + inputs = shard(inputs) + inputs = checkpoint_name(inputs, "decoder_layer_input") + + lnx = rms_norm( + num_features=cfg.emb_dim, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + name="pre_self_attention_layer_norm", + kernel_axes=("norm",), + epsilon=cfg.normalization_layer_epsilon, + )(inputs) + lnx = shard(lnx) + + batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(cfg, model_mode) + dummy_inputs_shape = (batch_size, seq_len, cfg.emb_dim) + + attention_lnx, kv_cache = attention_as_linen( + config=cfg, + num_query_heads=cfg.num_query_heads, + num_kv_heads=cfg.num_kv_heads, + head_dim=cfg.head_dim, + max_target_length=cfg.max_target_length, + max_prefill_predict_length=cfg.max_prefill_predict_length, + attention_kernel=cfg.attention, + inputs_q_shape=dummy_inputs_shape, + inputs_kv_shape=dummy_inputs_shape, + mesh=mesh, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + dropout_rate=cfg.dropout_rate, + float32_qk_product=cfg.float32_qk_product, + float32_logits=cfg.float32_logits, + quant=self.quant, + kv_quant=quantizations.configure_kv_quant(cfg), + prefill_cache_axis_order=tuple(map(int, cfg.prefill_cache_axis_order.split(","))), + ar_cache_axis_order=tuple(map(int, cfg.ar_cache_axis_order.split(","))), + compute_axis_order=tuple(map(int, cfg.compute_axis_order.split(","))), + reshape_q=cfg.reshape_q, + use_ragged_attention=cfg.use_ragged_attention, + ragged_block_size=cfg.ragged_block_size, + model_mode=model_mode, + name="self_attention", + )( lnx, lnx, decoder_positions, @@ -157,28 +127,47 @@ def __call__( attention_metadata=attention_metadata, ) - attention_lnx = nn.with_logical_constraint(attention_lnx, self.activation_axis_names) + attention_lnx = shard(attention_lnx) intermediate_inputs = inputs + attention_lnx # Fully Connected - hidden_states = self.post_self_attention_layer_norm(intermediate_inputs) - hidden_states = nn.with_logical_constraint(hidden_states, self.activation_axis_names) + hidden_states = rms_norm( + num_features=cfg.emb_dim, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + name="post_self_attention_layer_norm", + kernel_axes=("norm",), + epsilon=cfg.normalization_layer_epsilon, + )(intermediate_inputs) + hidden_states = shard(hidden_states) load_balance_loss = None # NOTE: the naming mismatch here is to ensure reverse compatibility with existing checkpoints. # The `name` represents the weight name in JAX/checkpoints and so the class name # is just for readability. - mlp_lnx, load_balance_loss, _ = self.MoeBlock_0(hidden_states) - mlp_lnx = nn.with_logical_constraint(mlp_lnx, self.activation_axis_names) + mlp_lnx, load_balance_loss, _ = moe.get_routed_moe( + config=cfg, + num_experts=cfg.num_experts, + num_experts_per_tok=cfg.num_experts_per_tok, + mesh=mesh, + kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"), + kernel_axes=("embed", None), + intermediate_dim=cfg.mlp_dim, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + quant=self.quant, + name="MoeBlock_0", + )(hidden_states) + mlp_lnx = shard(mlp_lnx) layer_output = mlp_lnx + intermediate_inputs - layer_output = self.dropout(layer_output, deterministic=deterministic) - layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names) + layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic) + layer_output = shard(layer_output) - if self.config.load_balance_loss_weight > 0.0 and load_balance_loss is not None: + if cfg.load_balance_loss_weight > 0.0 and load_balance_loss is not None: self.sow("intermediates", "moe_lb_loss", load_balance_loss) - if self.config.record_internal_nn_metrics: + if cfg.record_internal_nn_metrics: self.sow("intermediates", "activation_mean", jnp.mean(layer_output)) self.sow("intermediates", "activation_stdev", jnp.std(layer_output)) self.sow( @@ -187,13 +176,10 @@ def __call__( jnp.sum(layer_output == 0) / jnp.size(layer_output), ) - if self.config.scan_layers: + if cfg.scan_layers: return layer_output, None else: return layer_output, kv_cache -MixtralDecoderLayerToLinen = nnx_wrappers.to_linen_class( - MixtralDecoderLayer, - base_metadata_fn=initializers.variable_to_logically_partitioned, -) +MixtralDecoderLayerToLinen = MixtralDecoderLayer diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index 1011563a7b..b9a9604bac 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -36,6 +36,11 @@ import jax import jax.numpy as jnp +import flax +try: + flax.config.update("flax_always_shard_variable", False) +except LookupError: + pass from flax import linen as nn from flax.linen import partitioning as nn_partitioning @@ -355,10 +360,11 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat is_train=True, ) - raw_grads = jax.tree_util.tree_map( - lambda x: x.astype(config.grad_dtype) if x.dtype == jnp.float32 else x, - raw_grads, - ) + if config.grad_dtype != jnp.float32: + raw_grads = jax.tree_util.tree_map( + lambda x: x.astype(config.grad_dtype) if x.dtype == jnp.float32 else x, + raw_grads, + ) if config.parameter_memory_host_offload: raw_grads = jax.device_put( raw_grads, diff --git a/src/maxtext/utils/sharding.py b/src/maxtext/utils/sharding.py index d4bb64f016..38cdaa707d 100644 --- a/src/maxtext/utils/sharding.py +++ b/src/maxtext/utils/sharding.py @@ -131,7 +131,15 @@ def maybe_shard_with_pspec( def maybe_shard_with_logical( - inputs, logical_axes, mesh, shard_mode, rules=None, debug_sharding=False, extra_stack_level=0, sharding_desc="" + inputs, + logical_axes, + mesh, + shard_mode, + rules=None, + debug_sharding=False, + extra_stack_level=0, + sharding_desc="", + skip_trivial_specs=False, ): """ A wrapper of maybe_shard_with_name when logical axes are inputs @@ -146,6 +154,9 @@ def maybe_shard_with_logical( named_sharding = create_sharding(mesh, logical_axes, rules=rules) + if skip_trivial_specs and all(ax is None or ax == () for ax in named_sharding.spec): + return inputs + return maybe_shard_with_name( inputs, named_sharding,