Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ class MaxTextForCausalLM(nnx.Module):
of the decoding step.
"""

# Signal to tpu-inference model_loader that this class manages its own
# JIT-sharded initialization (via create_nnx_model with out_shardings).
# When True, model_loader skips wrapping __init__ in an outer bare @jax.jit,
_self_manages_sharding: bool = True

def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh):
"""Initializes the MaxTextForCausalLM model.

Expand Down Expand Up @@ -232,7 +237,7 @@ def load_weights(self, rng_key: jax.Array) -> None:
if self.model is not None:
return

with self.mesh, nn.logical_axis_rules(""):
with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
model, _ = model_creation_utils.create_nnx_model(
self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
)
Expand Down
7 changes: 4 additions & 3 deletions src/maxtext/layers/attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,13 +956,14 @@ def forward_serve_vllm(
"vLLM RPA attention ops require the vllm-tpu package. Please install it with `pip install vllm-tpu`."
) from e

if rpa_kv_cache is None or rpa_metadata is None:
raise ValueError("kv_cache and attention_metadata must be provided when using vLLM.")

query = query.reshape(-1, query.shape[2], query.shape[3])
key = key.reshape(-1, key.shape[2], key.shape[3])
value = value.reshape(-1, value.shape[2], value.shape[3])

if rpa_kv_cache is None or rpa_metadata is None:
# Return dummy values for dry runs (e.g. during model initialization or JIT tracing)
return [], query

if self.config.sliding_window_size > 0:
attention_chunk_size = self.config.sliding_window_size
else:
Expand Down
65 changes: 35 additions & 30 deletions src/maxtext/layers/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,7 +792,13 @@ def __call__(
decoder_positions,
deterministic,
model_mode,
previous_chunk,
page_state,
slot,
)
in_axes_tuple = (nn.broadcast,) * len(broadcast_args)
# Pipeline module only accepts (segment_ids, positions, deterministic, model_mode)
pipeline_broadcast_args = broadcast_args[:4]
if cfg.using_pipeline_parallelism:
if cfg.pipeline_fsdp_ag_once:
logical_partition_spec = self.pipeline_module.get_weight_sharding(
Expand Down Expand Up @@ -828,9 +834,9 @@ def __call__(
in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
model_mode=model_mode,
)(y, *broadcast_args)
y = self.pipeline_module(y, *broadcast_args, logical_partition_spec=logical_partition_spec)
y = self.pipeline_module(y, *pipeline_broadcast_args, logical_partition_spec=logical_partition_spec)
else: # Not DeepSeek
y = self.pipeline_module(y, *broadcast_args, logical_partition_spec=logical_partition_spec)
y = self.pipeline_module(y, *pipeline_broadcast_args, logical_partition_spec=logical_partition_spec)
remaining_layers = self.config.num_decoder_layers - self.config.pipeline_parallel_layers
if remaining_layers > 0:
logical_axis_rules_pp_as_dp = sharding.logical_axis_rules_pp_act_as_dp(self.config.logical_axis_rules)
Expand All @@ -847,26 +853,12 @@ def __call__(
else:
if cfg.scan_layers:
if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek."
layer_call_kwargs = {
"page_state": page_state,
"previous_chunk": previous_chunk,
"slot": slot,
}
dense_layer = RemattedBlockLayers[0]
moe_layer = RemattedBlockLayers[1]
if cfg.engram_layers:
original_dense_call = dense_layer.__call__
original_moe_call = moe_layer.__call__
dense_layer.__call__ = functools.partial(dense_layer.__call__, **layer_call_kwargs)
moe_layer.__call__ = functools.partial(moe_layer.__call__, **layer_call_kwargs)

common_kwargs = {
"dense_layer": dense_layer,
"moe_layer": moe_layer,
"original_dense_call": original_dense_call,
"original_moe_call": original_moe_call,
"layer_call_kwargs": layer_call_kwargs,
"decoder_segment_ids": decoder_segment_ids,
"decoder_positions": decoder_positions,
"deterministic": deterministic,
Expand Down Expand Up @@ -895,7 +887,6 @@ def __call__(
**common_kwargs,
)
else:
dense_layer.__call__ = functools.partial(dense_layer.__call__, **layer_call_kwargs)
y, _ = self.scan_decoder_layers(
cfg,
dense_layer,
Expand All @@ -905,7 +896,6 @@ def __call__(
in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
model_mode=model_mode,
)(y, *broadcast_args)
moe_layer.__call__ = functools.partial(moe_layer.__call__, **layer_call_kwargs)
num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers

# If batch-split schedule is used and initialization is complete,
Expand Down Expand Up @@ -954,16 +944,38 @@ def __call__(
"nope_layer_interval": self.config.nope_layer_interval,
"interleave_moe_layer_step": self.config.interleave_moe_layer_step,
}
y, _ = self.scan_decoder_layers(

# Update broadcast_args and in_axes_tuple for vLLM RPA
current_broadcast_args = list(broadcast_args)
current_in_axes_tuple = list(in_axes_tuple)

if kv_caches is not None:
# Stack kv_caches for scan: [num_layers, ...]
stacked_kv_cache = jnp.stack(kv_caches, axis=0)
current_broadcast_args.append(stacked_kv_cache)
current_in_axes_tuple.append(0) # Scan over the layer dimension
else:
current_broadcast_args.append(None)
current_in_axes_tuple.append(nn.broadcast)

current_broadcast_args.append(attention_metadata)
current_in_axes_tuple.append(nn.broadcast)

y, returned_kv_cache = self.scan_decoder_layers(
cfg,
RemattedBlockLayer,
scan_length,
"layers",
mesh,
in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
in_axes_tuple=tuple(current_in_axes_tuple),
model_mode=model_mode,
**layer_kwargs,
)(y, *broadcast_args)
)(y, *current_broadcast_args)

if kv_caches is not None and returned_kv_cache is not None:
# Update the list of KV caches from the scanned results
for i, cache in enumerate(returned_kv_cache):
kv_caches[i] = cache
else:
if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
assert len(RemattedBlockLayers) == 2, "Unscanned layers must have a length of 2 using deepseek."
Expand Down Expand Up @@ -1173,10 +1185,8 @@ def _apply_single_engram_layer(self, y, current_idx, layer_type, **kwargs):
"""Applies a single, unscanned Engram layer."""
layer = kwargs["dense_layer"] if layer_type == "dense" else kwargs["moe_layer"]
layer_prefix = "dense_layers" if layer_type == "dense" else "moe_layers"
original_call = kwargs["original_dense_call"] if layer_type == "dense" else kwargs["original_moe_call"]
layer_call_kwargs = kwargs["layer_call_kwargs"]
broadcast_args = kwargs["broadcast_args"]

layer.__call__ = original_call
y, _ = layer(
config=self.config,
mesh=self.mesh,
Expand All @@ -1186,14 +1196,9 @@ def _apply_single_engram_layer(self, y, current_idx, layer_type, **kwargs):
layer_idx=current_idx,
)(
y,
kwargs["decoder_segment_ids"],
kwargs["decoder_positions"],
kwargs["deterministic"],
kwargs["model_mode"],
*broadcast_args,
decoder_input_tokens=kwargs["decoder_input_tokens"],
**layer_call_kwargs,
)
layer.__call__ = functools.partial(original_call, **layer_call_kwargs)
return y

def _apply_scanned_chunk(self, y, current_idx, next_boundary, layer_type, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions src/maxtext/models/gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from maxtext.layers.linears import Dropout, MlpBlock
from maxtext.layers.normalizations import RMSNorm
from maxtext.layers.quantizations import AqtQuantization as Quant
from maxtext.inference import page_manager
from maxtext.utils import max_utils


Expand Down Expand Up @@ -126,8 +127,7 @@ def __call__(
deterministic,
model_mode,
previous_chunk=None,
page_manager=None,
page_state=None,
page_state: None | page_manager.PageState = None,
slot=None,
kv_cache=None,
attention_metadata=None,
Expand Down
22 changes: 14 additions & 8 deletions src/maxtext/models/gpt_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from maxtext.layers.attentions import Attention
from maxtext.layers.normalizations import RMSNorm
from maxtext.layers.quantizations import AqtQuantization as Quant
from maxtext.inference import page_manager
from maxtext.utils import max_utils

# -----------------------------------------
Expand Down Expand Up @@ -138,7 +139,7 @@ def __call__(
deterministic,
model_mode,
previous_chunk=None,
page_state=None,
page_state: None | page_manager.PageState = None,
slot=None,
kv_cache=None,
attention_metadata=None,
Expand Down Expand Up @@ -258,6 +259,11 @@ def __call__(
decoder_positions,
deterministic,
model_mode,
previous_chunk=None,
page_state: None | page_manager.PageState = None,
slot=None,
kv_cache=None,
attention_metadata=None,
):
cfg = self.config

Expand All @@ -267,19 +273,19 @@ def __call__(
for layer_id in range(cfg.inhomogeneous_layer_cycle_interval):
layer_name = f"layers_{layer_id}"
layer = getattr(self, layer_name)
y = layer(
y, kv_cache = layer(
y,
decoder_segment_ids,
decoder_positions,
deterministic,
model_mode,
previous_chunk=previous_chunk,
page_state=page_state,
slot=slot,
kv_cache=kv_cache,
attention_metadata=attention_metadata,
)
if cfg.scan_layers:
y = y[0]
if cfg.scan_layers:
return y, None
else:
return y
return y, kv_cache


GptOssScannableBlockToLinen = nnx_wrappers.to_linen_class(
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/models/llama2.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,9 @@ def __call__(
decoder_positions,
deterministic,
model_mode,
previous_chunk=None,
slot: None | int = None,
page_state: None | page_manager.PageState = None,
previous_chunk=None,
kv_cache=None,
attention_metadata=None,
):
Expand Down
8 changes: 6 additions & 2 deletions src/maxtext/models/llama4.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,9 +442,9 @@ def __call__(
decoder_positions,
deterministic,
model_mode,
previous_chunk=None,
slot: None | int = None,
page_state: None | page_manager.PageState = None,
previous_chunk=None,
kv_cache=None,
attention_metadata=None,
):
Expand Down Expand Up @@ -570,9 +570,11 @@ def __call__(
decoder_positions,
deterministic,
model_mode,
previous_chunk=None,
slot: None | int = None,
page_state: None | page_manager.PageState = None,
previous_chunk=None,
kv_cache=None,
attention_metadata=None,
):

cfg = self.config
Expand All @@ -590,6 +592,8 @@ def __call__(
previous_chunk=previous_chunk,
page_state=page_state,
slot=slot,
kv_cache=kv_cache,
attention_metadata=attention_metadata,
)
if cfg.scan_layers:
y = y[0]
Expand Down
5 changes: 3 additions & 2 deletions src/maxtext/models/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import jax.numpy as jnp
from jax.sharding import Mesh
from maxtext.common.common_types import Config
from maxtext.inference import page_manager
from maxtext.layers import initializers, nnx_wrappers
from maxtext.layers import quantizations
from maxtext.layers.attentions import Attention
Expand Down Expand Up @@ -126,9 +127,9 @@ def __call__(
decoder_positions,
deterministic,
model_mode,
page_state: None | int = None,
slot: None | int = None,
previous_chunk=None,
slot: None | int = None,
page_state: None | page_manager.PageState = None,
kv_cache=None,
attention_metadata=None,
):
Expand Down
10 changes: 10 additions & 0 deletions src/maxtext/models/olmo3.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,11 @@ def __call__(
decoder_positions,
deterministic,
model_mode,
previous_chunk=None,
page_state=None,
slot=None,
kv_cache=None,
attention_metadata=None,
):
cfg = self.config

Expand All @@ -282,6 +287,11 @@ def __call__(
decoder_positions,
deterministic,
model_mode,
previous_chunk=previous_chunk,
page_state=page_state,
slot=slot,
kv_cache=kv_cache,
attention_metadata=attention_metadata,
)
if cfg.scan_layers:
y = y[0]
Expand Down
14 changes: 6 additions & 8 deletions src/maxtext/models/qwen3.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,8 @@ def __call__(
previous_chunk=None,
page_state: None | page_manager.PageState = None,
slot: None | int = None,
kv_cache=None,
attention_metadata=None,
) -> tuple[Array, None]:
"""Applies the block of decoder layers to the input carry.

Expand Down Expand Up @@ -924,6 +926,8 @@ def __call__(
previous_chunk,
page_state,
slot,
kv_cache=kv_cache,
attention_metadata=attention_metadata,
)

# The output of the block is the carry for the next scan iteration.
Expand Down Expand Up @@ -1235,10 +1239,7 @@ def __call__(
layer_output = intermediate_inputs + mlp_lnx
layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)

if self.config.scan_layers:
return layer_output, None
else:
return layer_output, kv_cache
return layer_output, kv_cache


# -----------------------------------------
Expand Down Expand Up @@ -1304,10 +1305,7 @@ def __call__(
layer_output = intermediate_inputs + mlp_lnx
layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)

if self.config.scan_layers:
return layer_output, None
else:
return layer_output, kv_cache
return layer_output, kv_cache


class Qwen3OmniMoeVisionPatchMerger(nnx.Module):
Expand Down
Loading
Loading