Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ logits_dot_in_fp32: false # whether to use fp32 in logits_dense or shared_embed
cast_logits_to_fp32: true # whether to cast the logits to fp32. the higher precision is generally beneficial, but it can vary slightly.
float32_qk_product: false # in dot_product attention, whether to cast to fp32 the inputs to qk product
float32_logits: false # in dot_product attention, whether to cast to fp32 the inputs to softmax
float32_weight_sum: true # whether to use full fp32 precision to sum expert weights for numerical stability
float32_weight_sum: false # whether to use fp32 for MoE expert weight summation; true adds ~2 GB f32 temp per device
float32_gate_logits: false # whether to cast inputs to fp32 to compute MoE gate logits for numerical stability

# multi-token prediction configs
Expand Down
4 changes: 2 additions & 2 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,8 +725,8 @@ class MoEGeneral(BaseModel):
description="Enable top-k probability normalization for router weights (Qwen3-specific).",
)
float32_weight_sum: bool = Field(
True,
description="Whether to use full fp32 precision to sum expert weights for numerical stability.",
False,
description="Whether to use fp32 for MoE expert weight summation; true adds ~2 GB f32 temp per device.",
)
float32_gate_logits: bool = Field(
False,
Expand Down
3 changes: 2 additions & 1 deletion src/maxtext/kernels/gather_reduce_sc.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def __getitem__(self, shape):
_BF16 = VectorTypeHelper(ir.BF16Type.get)


@jax.jit(
@functools.partial(
jax.jit,
static_argnames=[
"reduce_group_size",
"single_sc",
Expand Down
25 changes: 16 additions & 9 deletions src/maxtext/layers/attention_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1580,13 +1580,22 @@ def cudnn_flash_attention(
dummy_attn_mask = None
mask_type = "causal"
else:
# Default case: no packing, no context parallelism
dummy_attn_mask = jnp.zeros(
(1, 1, 1, self.max_target_length, self.max_target_length),
dtype=jnp.uint8,
)
attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode)
attn_mask = jnp.where((attn_mask >= DEFAULT_MASK_VALUE * 0.5), 0, 1).astype(jnp.uint8)
# Default case: no packing, no context parallelism.
# For synthetic data, segment IDs are always all-ones (one segment per sequence), so
# the segment mask is all-True and the combined mask reduces to pure causal masking.
# Use mask_type="causal" directly to avoid materializing f32/s32[seq,seq] tensors that
# XLA loop_broadcast_fusion hoists into the pipeline scan carry (+5 GiB temp memory).
if self.config.dataset_type == "synthetic":
attn_mask = None
dummy_attn_mask = None
mask_type = "causal"
else:
dummy_attn_mask = jnp.zeros(
(1, 1, 1, self.max_target_length, self.max_target_length),
dtype=jnp.uint8,
)
attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode)
attn_mask = jnp.where((attn_mask >= DEFAULT_MASK_VALUE * 0.5), 0, 1).astype(jnp.uint8)

dpa_layer = DotProductAttention(
head_dim=head_dim,
Expand All @@ -1599,12 +1608,10 @@ def cudnn_flash_attention(
dtype=self.dtype,
float32_logits=self.float32_logits,
qkv_layout=qkv_layout,
scale_factor=1.0,
transpose_batch_sequence=False,
window_size=sliding_window_size,
context_parallel_causal_load_balanced=self.config.context_parallel_load_balance,
context_parallel_axis=self.config.context_sharding,
context_parallel_strategy=self.config.context_parallel_strategy,
max_segments_per_seq=max_segments_per_seq,
)

Expand Down
1 change: 1 addition & 0 deletions src/maxtext/layers/attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,7 @@ def __init__(
mesh=mesh,
shard_mode=config.shard_mode,
debug_sharding=config.debug_sharding,
skip_trivial_specs=True,
)

def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> None:
Expand Down
37 changes: 21 additions & 16 deletions src/maxtext/layers/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding

from flax import linen as nn
from flax import nnx

from maxtext.common.common_types import ShardMode, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN, Array, Config, DType
Expand Down Expand Up @@ -156,30 +157,34 @@ def __call__(self, inputs: Array, model_mode: str = MODEL_MODE_TRAIN) -> Array:
self.dtype,
)

output_axis_names = (
(
"activation_embed_and_logits_batch",
"prefill_activation_length",
"activation_embed",
)
if model_mode == MODEL_MODE_PREFILL
else (
"activation_embed_and_logits_batch",
"activation_length",
"activation_embed",
)
)
out_pspec = logical_to_mesh_axes(output_axis_names, self.mesh, rules=getattr(self.config, "logical_axis_rules", None))
output_prefill_axis_names = ("activation_embed_and_logits_batch", "prefill_activation_length", "activation_embed")
output_default_axis_names = ("activation_embed_and_logits_batch", "activation_length", "activation_embed")

if self.config.shard_mode == ShardMode.EXPLICIT:
output_axis_names = output_prefill_axis_names if model_mode == MODEL_MODE_PREFILL else output_default_axis_names
out_pspec = logical_to_mesh_axes(output_axis_names, self.mesh, rules=getattr(self.config, "logical_axis_rules", None))
out_sharding = NamedSharding(self.mesh, out_pspec)
else:
out_sharding = None

out_sharding = NamedSharding(self.mesh, out_pspec) if self.config.shard_mode == ShardMode.EXPLICIT else None
one_hot_elements = 1
for d in inputs.shape:
one_hot_elements *= d
one_hot_elements *= self.num_embeddings
one_hot_bytes = one_hot_elements * jnp.dtype(self.dtype).itemsize
use_iota = cfg.use_iota_embed and one_hot_bytes <= 2 * 1024**3

if cfg.use_iota_embed:
if use_iota:
iota = lax.iota(jnp.int32, self.num_embeddings)
one_hot = jnp.array(inputs[..., jnp.newaxis] == iota, dtype=self.dtype)
output = jnp.dot(one_hot, embedding, out_sharding=out_sharding)
else:
output = embedding.at[inputs].get(out_sharding=out_sharding)

if model_mode == MODEL_MODE_PREFILL:
output = nn.with_logical_constraint(output, output_prefill_axis_names)
else:
output = nn.with_logical_constraint(output, output_default_axis_names)
return output

def attend(self, query: Array, out_sharding: NamedSharding | None = None) -> Array:
Expand Down
18 changes: 4 additions & 14 deletions src/maxtext/layers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,7 +714,7 @@ def apply_ffn_activation(self, layer_w0, layer_w1):
else:
layer_act = self.activation_fn(layer_w0)
intermediate_layer = jnp.multiply(layer_act, layer_w1)
return intermediate_layer.astype(self.dtype)
return intermediate_layer

def permute(self, inputs, gate_logits, pre_bias_logits, use_custom_sort_vjp=True, rngs=None, roll_to_expert_id=None):
"""Permute tokens to group by expert to fit gmm call."""
Expand Down Expand Up @@ -1444,27 +1444,17 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
group_sizes=group_sizes,
expert_assignments=selected_experts,
)
# Forward-only 3-tuple tiling. The 9-tuple form includes dlhs/drhs backward-pass
# tile values that are allocated but unused when megablox=False (JAX ragged_dot path).
wi_tile_size = (
self.config.wi_tile_fwd_batch_seq, # m (LHS batch)
self.config.wi_tile_fwd_embed_dim, # k (contracting)
self.config.wi_tile_fwd_embed_dim, # k (contracting)
self.config.wi_tile_fwd_mlp_dim, # n (RHS batch)
self.config.wi_tile_dlhs_batch_seq, # m (LHS batch)
self.config.wi_tile_dlhs_mlp_dim, # k (contracting)
self.config.wi_tile_dlhs_embed_dim, # n (RHS batch)
self.config.wi_tile_drhs_batch_seq, # Called m in megablox, but this is contracting
self.config.wi_tile_drhs_embed_dim, # Called k in megablox, but this is LHS batch dim
self.config.wi_tile_drhs_mlp_dim, # Called n in megablox, and indeed is RHS batch dim
)
wo_tile_size = (
self.config.wo_tile_fwd_batch_seq, # m (LHS batch)
self.config.wo_tile_fwd_mlp_dim, # k (contracting)
self.config.wo_tile_fwd_embed_dim, # n (RHS batch)
self.config.wo_tile_dlhs_batch_seq, # m (LHS batch)
self.config.wo_tile_dlhs_embed_dim, # k (contracting)
self.config.wo_tile_dlhs_mlp_dim, # n (RHS)
self.config.wo_tile_drhs_batch_seq, # Called m in megablox, but this is contracting
self.config.wo_tile_drhs_mlp_dim, # Called k in megablox, but this is LHS batch dim
self.config.wo_tile_drhs_embed_dim, # Called n in megablox, and indeed is the RHS batch dim
)

layer_w0 = gmm_fn(
Expand Down
7 changes: 5 additions & 2 deletions src/maxtext/layers/normalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,11 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) ->
scale = jax.device_put(scale, max_utils.device_space())

scale = jnp.asarray(scale, self.dtype)
effective_scale = scale + self.scale_offset
return jnp.einsum("...k,k->...k", y, effective_scale, out_sharding=out_sharding)
effective_scale = scale + self.scale_offset if self.scale_offset != 0.0 else scale
y = y * effective_scale
if out_sharding is not None:
y = jax.lax.with_sharding_constraint(y, out_sharding)
return y


class GlobalRMSNorm(RMSNorm):
Expand Down
71 changes: 18 additions & 53 deletions src/maxtext/layers/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def _maybe_shard_with_logical(self, inputs, logical_axes):
rules=self.config.logical_axis_rules,
debug_sharding=self.config.debug_sharding,
extra_stack_level=1,
skip_trivial_specs=True,
)

def _maybe_shard_with_name(self, inputs, sharding_name):
Expand All @@ -139,7 +140,6 @@ def get_iteration_inputs(self, loop_iteration, state_io, circ_storage, shift):
# Set up the potential input from state_io, which has a rotating microbatch index (size of microbatches_per_stage)
state_io_batch_idx = loop_iteration % self.microbatches_per_stage
state_io_slice = state_io[:, state_io_batch_idx]
shift = self._maybe_shard_with_logical(shift, self.stages_in_logical)

if self.use_circ_storage:
# Set up the potential input from circ_storage, which also has a rotating microbatch index,
Expand All @@ -154,7 +154,6 @@ def get_iteration_inputs(self, loop_iteration, state_io, circ_storage, shift):
# state_io we instead grab from the last stage's output (possibly buffered when num_microbatches > num_stages, e.g.
# from circ_storage).
first_stage_in = jnp.where(loop_iteration < self.config.num_pipeline_microbatches, state_io_slice, circular_stage_in)
first_stage_in = self._maybe_shard_with_logical(first_stage_in, self.stages_in_logical)

# Note that first_stage_in may correspond to bubble computation during the last few iterations.
# However, these bubble computation results remain in the shift buffer (do not make it back to state_io) and are
Expand All @@ -164,11 +163,7 @@ def get_iteration_inputs(self, loop_iteration, state_io, circ_storage, shift):

def select_state_or_input(first_stage_in, shift):
# Selects input for stage 0, shift for other stages
return jnp.where(
jax.lax.broadcasted_iota("int32", shift.shape, 0, out_sharding=self.stages_in_sharding) == 0,
first_stage_in,
shift,
)
return jnp.where(jax.lax.broadcasted_iota("int32", shift.shape, 0) == 0, first_stage_in, shift)

# Selects input (from stream_io) for stage 0, other stages get from shift (the rotated previous output)
stages_in = select_state_or_input(first_stage_in, shift)
Expand All @@ -180,7 +175,6 @@ def get_microbatch_and_repeat_ids(self, loop_iteration):
non-circular"""
# Stage 0 has processed one microbatch per loop iteration; each later stage lags the previous one by forwarding_delay iterations due to the pipeline bubble
microbatches_processed = jnp.maximum(loop_iteration - self.forwarding_delay * jnp.arange(self.num_stages), 0)
microbatches_processed = self._maybe_shard_with_name(microbatches_processed, NamedSharding(self.mesh, P("stage")))
microbatch_ids = microbatches_processed % self.config.num_pipeline_microbatches
repeat_ids = microbatches_processed // self.config.num_pipeline_microbatches
return microbatch_ids, repeat_ids
Expand All @@ -190,7 +184,7 @@ def get_pipeline_remat_policy(self):
if self.config.remat_policy == "custom":
return self.remat_policy

save_input_policy = jax.checkpoint_policies.save_only_these_names("iteration_input", "decoder_layer_input")
save_input_policy = jax.checkpoint_policies.save_only_these_names("iteration_input")
if self.remat_policy is not None:
remat_policy = jax.checkpoint_policies.save_from_both_policies(self.remat_policy, save_input_policy)
else:
Expand Down Expand Up @@ -247,16 +241,6 @@ def get_main_vmap_func_for_iterations(self):
def func_to_vmap(
body_instance, weights, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode
):
weights = meta.remove_axis(
weights,
0,
{
nn.PARTITION_NAME: "layers",
"sub_weight_split_dims_mapping": (None,),
"is_initializing": self.is_initializing(),
"x_times": self.num_stages,
},
)
return body_instance.apply(weights, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode)

vmap_func = nn.vmap(
Expand Down Expand Up @@ -490,9 +474,7 @@ def vmap_gather(self, xs, ids, ids_dim):
"""

def _gather_one(x, i):
idx = tuple(i if d == ids_dim else slice(None) for d in range(x.ndim))
replicated_sharding = NamedSharding(self.mesh, P())
return x.at[idx].get(out_sharding=replicated_sharding)
return jnp.squeeze(jax.lax.dynamic_slice_in_dim(x, i, 1, ids_dim), ids_dim)

ids = self.shard_dim_by_stages(ids, 0, physical_partition_spec=None)
outs = jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids)
Expand All @@ -516,21 +498,16 @@ def get_new_loop_state(self, output, loop_state):
loop_iteration = loop_state["loop_iteration"]
old_prev_outputs = loop_state["prev_outputs"]

@jax.shard_map(mesh=self.mesh, in_specs=self.stages_in_spec, out_specs=self.stages_in_spec, check_vma=True)
def _rotate_right(arr):
# we use +1 for right shifting
stage_size = jax.lax.axis_size("stage")
perm = [(i, (i + 1) % stage_size) for i in range(stage_size)]
arr = jax.lax.ppermute(arr, axis_name="stage", perm=perm)
return arr
# Use lax.slice to avoid generating a gather.
last = jax.lax.slice_in_dim(arr, self.num_stages - 1, self.num_stages, axis=0)
except_last = jax.lax.slice_in_dim(arr, 0, self.num_stages - 1, axis=0)
return jnp.concatenate([last, except_last], axis=0)

@jax.shard_map(mesh=self.mesh, in_specs=self.stages_in_spec, out_specs=self.stages_in_spec, check_vma=True)
def _shift_right(arr):
stage_idx = jax.lax.axis_index("stage")
stage_size = jax.lax.axis_size("stage")
perm = [(i, (i + 1) % stage_size) for i in range(stage_size)]
arr = jax.lax.ppermute(arr, axis_name="stage", perm=perm)
return jnp.where(stage_idx == 0, jnp.zeros_like(arr), arr)
padding = [[1, 0]] + [[0, 0]] * (arr.ndim - 1)
# Use lax.slice to guarantee the gradient is a pad.
return jax.lax.slice(jnp.pad(arr, padding), [0] * arr.ndim, arr.shape)

# The shift either rotates or shifts, depending on whether the last stage must immediately send to the first stage.
# For non-circular pipelines, the last stage does not need to send to the first stage.
Expand Down Expand Up @@ -574,29 +551,17 @@ def _rotate_right_and_update(circ_storage_mover_in, circ_storage_in):
stream_buf_idx = loop_iteration % self.microbatches_per_stage
stream_slice = old_state_io[:, stream_buf_idx]

def _rotate_left(arr, stage_size):
# we use -1 for left shifting
perm = [(i, (i - 1) % stage_size) for i in range(stage_size)]
return jax.lax.ppermute(arr, axis_name="stage", perm=perm)

def _shift_left(arr, stage_size, output):
stage_idx = jax.lax.axis_index("stage")
arr = _rotate_left(arr, stage_size)
return jnp.where(stage_idx == stage_size - 1, output, arr)

@jax.shard_map(
mesh=self.mesh,
in_specs=(self.state_io_spec, self.stages_in_spec, self.stages_in_spec, P()),
out_specs=self.state_io_spec,
)
def _update_state_io(state_in, stream_slice, output, stream_buf_idx):
def _update_state_io(state_in, stream_slice, output):
# Shift the current slice to the left, then fill the last stage with the final output.
stage_size = jax.lax.axis_size("stage")
stream_slice = _shift_left(stream_slice, stage_size, output)
padding = [[0, 1]] + [[0, 0]] * (stream_slice.ndim - 1)
stream_slice = jax.lax.slice_in_dim(jnp.pad(stream_slice, padding), 1, stream_slice.shape[0] + 1, axis=0)
stream_slice = jnp.where(
jax.lax.broadcasted_iota("int32", stream_slice.shape, 0) == self.num_stages - 1, output, stream_slice
)
stream_slice = jnp.expand_dims(stream_slice, 1)
return jax.lax.dynamic_update_slice_in_dim(state_in, stream_slice, stream_buf_idx, axis=1)

new_state = _update_state_io(old_state_io, stream_slice, output, stream_buf_idx)
new_state = _update_state_io(old_state_io, stream_slice, output)

new_loop_state = {
"state_io": new_state,
Expand Down
19 changes: 9 additions & 10 deletions src/maxtext/models/deepseek.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from maxtext.utils import max_utils
from maxtext.utils.sharding import create_sharding
from maxtext.utils.sharding import maybe_shard_with_logical
from maxtext.utils.sharding import remove_size_one_mesh_axis

import transformers

Expand Down Expand Up @@ -419,7 +420,7 @@ def __init__(
self.DeepSeekMoeBlock_0 = moe.RoutedAndSharedMoE(
config=self.config,
mesh=mesh,
kernel_init=initializers.nd_dense_init(self.config.dense_init_scale, "fan_in", "truncated_normal"),
kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"),
kernel_axes=("embed", None),
dtype=self.config.dtype,
weight_dtype=self.config.weight_dtype,
Expand Down Expand Up @@ -492,15 +493,14 @@ def __call__(
return outputs, None

# bf16 and fp8 code path for pure-JAX batch-split.
# fp8 code path supports both manual quantization and qwix
# quantization.
input_sharding = jax.typeof(inputs).sharding
activation_pspec = jax.sharding.PartitionSpec(
("data", "fsdp", "expert"),
None,
None,
activation_pspec = remove_size_one_mesh_axis(
jax.sharding.PartitionSpec(
("data", "fsdp", "fsdp_transpose", "expert", "context"),
None,
None,
),
self.mesh,
)
inputs = jax.reshard(inputs, jax.sharding.NamedSharding(self.mesh, activation_pspec))
yarn_freqs = deepseek_batchsplit.initialize_yarn_freqs(
decoder_positions,
embedding_dims=self.config.qk_rope_head_dim,
Expand Down Expand Up @@ -572,7 +572,6 @@ def extract_fn(x):
in_specs=([activation_pspec] * self.config.batch_split_factor,),
out_specs=activation_pspec,
)(outputs)
outputs = jax.reshard(outputs, input_sharding)
return outputs, None

x = self.with_logical_constraint(inputs)
Expand Down
Loading
Loading