diff --git a/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py b/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py index a0f3afba76..49bc08d48a 100644 --- a/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py +++ b/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py @@ -86,6 +86,11 @@ class MaxTextForCausalLM(nnx.Module): of the decoding step. """ + # Signal to tpu-inference model_loader that this class manages its own + # JIT-sharded initialization (via create_nnx_model with out_shardings). + # When True, model_loader skips wrapping __init__ in an outer bare @jax.jit. + _self_manages_sharding: bool = True + def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh): """Initializes the MaxTextForCausalLM model. @@ -232,7 +237,7 @@ def load_weights(self, rng_key: jax.Array) -> None: if self.model is not None: return - with self.mesh, nn.logical_axis_rules(""): + with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules): model, _ = model_creation_utils.create_nnx_model( self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key ) diff --git a/src/maxtext/layers/attentions.py b/src/maxtext/layers/attentions.py index 900bc3f617..c224af506e 100644 --- a/src/maxtext/layers/attentions.py +++ b/src/maxtext/layers/attentions.py @@ -956,13 +956,14 @@ def forward_serve_vllm( "vLLM RPA attention ops require the vllm-tpu package. Please install it with `pip install vllm-tpu`." ) from e - if rpa_kv_cache is None or rpa_metadata is None: - raise ValueError("kv_cache and attention_metadata must be provided when using vLLM.") - query = query.reshape(-1, query.shape[2], query.shape[3]) key = key.reshape(-1, key.shape[2], key.shape[3]) value = value.reshape(-1, value.shape[2], value.shape[3]) + if rpa_kv_cache is None or rpa_metadata is None: + # Return dummy values for dry runs (e.g. 
during model initialization or JIT tracing) + return [], query + if self.config.sliding_window_size > 0: attention_chunk_size = self.config.sliding_window_size else: diff --git a/src/maxtext/layers/decoders.py b/src/maxtext/layers/decoders.py index 90cbe58c34..b20635e0d2 100644 --- a/src/maxtext/layers/decoders.py +++ b/src/maxtext/layers/decoders.py @@ -792,7 +792,13 @@ def __call__( decoder_positions, deterministic, model_mode, + previous_chunk, + page_state, + slot, ) + in_axes_tuple = (nn.broadcast,) * len(broadcast_args) + # Pipeline module only accepts (segment_ids, positions, deterministic, model_mode) + pipeline_broadcast_args = broadcast_args[:4] if cfg.using_pipeline_parallelism: if cfg.pipeline_fsdp_ag_once: logical_partition_spec = self.pipeline_module.get_weight_sharding( @@ -828,9 +834,9 @@ def __call__( in_axes_tuple=(nn.broadcast,) * len(broadcast_args), model_mode=model_mode, )(y, *broadcast_args) - y = self.pipeline_module(y, *broadcast_args, logical_partition_spec=logical_partition_spec) + y = self.pipeline_module(y, *pipeline_broadcast_args, logical_partition_spec=logical_partition_spec) else: # Not DeepSeek - y = self.pipeline_module(y, *broadcast_args, logical_partition_spec=logical_partition_spec) + y = self.pipeline_module(y, *pipeline_broadcast_args, logical_partition_spec=logical_partition_spec) remaining_layers = self.config.num_decoder_layers - self.config.pipeline_parallel_layers if remaining_layers > 0: logical_axis_rules_pp_as_dp = sharding.logical_axis_rules_pp_act_as_dp(self.config.logical_axis_rules) @@ -847,26 +853,12 @@ def __call__( else: if cfg.scan_layers: if cfg.decoder_block == DecoderBlockType.DEEPSEEK: - assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek." 
- layer_call_kwargs = { - "page_state": page_state, - "previous_chunk": previous_chunk, - "slot": slot, - } dense_layer = RemattedBlockLayers[0] moe_layer = RemattedBlockLayers[1] if cfg.engram_layers: - original_dense_call = dense_layer.__call__ - original_moe_call = moe_layer.__call__ - dense_layer.__call__ = functools.partial(dense_layer.__call__, **layer_call_kwargs) - moe_layer.__call__ = functools.partial(moe_layer.__call__, **layer_call_kwargs) - common_kwargs = { "dense_layer": dense_layer, "moe_layer": moe_layer, - "original_dense_call": original_dense_call, - "original_moe_call": original_moe_call, - "layer_call_kwargs": layer_call_kwargs, "decoder_segment_ids": decoder_segment_ids, "decoder_positions": decoder_positions, "deterministic": deterministic, @@ -895,7 +887,6 @@ def __call__( **common_kwargs, ) else: - dense_layer.__call__ = functools.partial(dense_layer.__call__, **layer_call_kwargs) y, _ = self.scan_decoder_layers( cfg, dense_layer, @@ -905,7 +896,6 @@ def __call__( in_axes_tuple=(nn.broadcast,) * len(broadcast_args), model_mode=model_mode, )(y, *broadcast_args) - moe_layer.__call__ = functools.partial(moe_layer.__call__, **layer_call_kwargs) num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers # If batch-split schedule is used and initialization is complete, @@ -954,16 +944,38 @@ def __call__( "nope_layer_interval": self.config.nope_layer_interval, "interleave_moe_layer_step": self.config.interleave_moe_layer_step, } - y, _ = self.scan_decoder_layers( + + # Update broadcast_args and in_axes_tuple for vLLM RPA + current_broadcast_args = list(broadcast_args) + current_in_axes_tuple = list(in_axes_tuple) + + if kv_caches is not None: + # Stack kv_caches for scan: [num_layers, ...] 
+ stacked_kv_cache = jnp.stack(kv_caches, axis=0) + current_broadcast_args.append(stacked_kv_cache) + current_in_axes_tuple.append(0) # Scan over the layer dimension + else: + current_broadcast_args.append(None) + current_in_axes_tuple.append(nn.broadcast) + + current_broadcast_args.append(attention_metadata) + current_in_axes_tuple.append(nn.broadcast) + + y, returned_kv_cache = self.scan_decoder_layers( cfg, RemattedBlockLayer, scan_length, "layers", mesh, - in_axes_tuple=(nn.broadcast,) * len(broadcast_args), + in_axes_tuple=tuple(current_in_axes_tuple), model_mode=model_mode, **layer_kwargs, - )(y, *broadcast_args) + )(y, *current_broadcast_args) + + if kv_caches is not None and returned_kv_cache is not None: + # Update the list of KV caches from the scanned results + for i, cache in enumerate(returned_kv_cache): + kv_caches[i] = cache else: if cfg.decoder_block == DecoderBlockType.DEEPSEEK: assert len(RemattedBlockLayers) == 2, "Unscanned layers must have a length of 2 using deepseek." 
@@ -1173,10 +1185,8 @@ def _apply_single_engram_layer(self, y, current_idx, layer_type, **kwargs): """Applies a single, unscanned Engram layer.""" layer = kwargs["dense_layer"] if layer_type == "dense" else kwargs["moe_layer"] layer_prefix = "dense_layers" if layer_type == "dense" else "moe_layers" - original_call = kwargs["original_dense_call"] if layer_type == "dense" else kwargs["original_moe_call"] - layer_call_kwargs = kwargs["layer_call_kwargs"] + broadcast_args = kwargs["broadcast_args"] - layer.__call__ = original_call y, _ = layer( config=self.config, mesh=self.mesh, @@ -1186,14 +1196,9 @@ def _apply_single_engram_layer(self, y, current_idx, layer_type, **kwargs): layer_idx=current_idx, )( y, - kwargs["decoder_segment_ids"], - kwargs["decoder_positions"], - kwargs["deterministic"], - kwargs["model_mode"], + *broadcast_args, decoder_input_tokens=kwargs["decoder_input_tokens"], - **layer_call_kwargs, ) - layer.__call__ = functools.partial(original_call, **layer_call_kwargs) return y def _apply_scanned_chunk(self, y, current_idx, next_boundary, layer_type, **kwargs): diff --git a/src/maxtext/models/gemma.py b/src/maxtext/models/gemma.py index f73fd12ced..6465ea54d2 100644 --- a/src/maxtext/models/gemma.py +++ b/src/maxtext/models/gemma.py @@ -30,6 +30,7 @@ from maxtext.layers.linears import Dropout, MlpBlock from maxtext.layers.normalizations import RMSNorm from maxtext.layers.quantizations import AqtQuantization as Quant +from maxtext.inference import page_manager from maxtext.utils import max_utils @@ -126,8 +127,7 @@ def __call__( deterministic, model_mode, previous_chunk=None, - page_manager=None, - page_state=None, + page_state: None | page_manager.PageState = None, slot=None, kv_cache=None, attention_metadata=None, diff --git a/src/maxtext/models/gpt_oss.py b/src/maxtext/models/gpt_oss.py index 58a0a2db8f..af1a04c562 100644 --- a/src/maxtext/models/gpt_oss.py +++ b/src/maxtext/models/gpt_oss.py @@ -34,6 +34,7 @@ from maxtext.layers.attentions import 
Attention from maxtext.layers.normalizations import RMSNorm from maxtext.layers.quantizations import AqtQuantization as Quant +from maxtext.inference import page_manager from maxtext.utils import max_utils # ----------------------------------------- @@ -138,7 +139,7 @@ def __call__( deterministic, model_mode, previous_chunk=None, - page_state=None, + page_state: None | page_manager.PageState = None, slot=None, kv_cache=None, attention_metadata=None, @@ -258,6 +259,11 @@ def __call__( decoder_positions, deterministic, model_mode, + previous_chunk=None, + page_state: None | page_manager.PageState = None, + slot=None, + kv_cache=None, + attention_metadata=None, ): cfg = self.config @@ -267,19 +273,19 @@ def __call__( for layer_id in range(cfg.inhomogeneous_layer_cycle_interval): layer_name = f"layers_{layer_id}" layer = getattr(self, layer_name) - y = layer( + y, kv_cache = layer( y, decoder_segment_ids, decoder_positions, deterministic, model_mode, + previous_chunk=previous_chunk, + page_state=page_state, + slot=slot, + kv_cache=kv_cache, + attention_metadata=attention_metadata, ) - if cfg.scan_layers: - y = y[0] - if cfg.scan_layers: - return y, None - else: - return y + return y, kv_cache GptOssScannableBlockToLinen = nnx_wrappers.to_linen_class( diff --git a/src/maxtext/models/llama2.py b/src/maxtext/models/llama2.py index 252dadc768..9961e4048f 100644 --- a/src/maxtext/models/llama2.py +++ b/src/maxtext/models/llama2.py @@ -143,9 +143,9 @@ def __call__( decoder_positions, deterministic, model_mode, + previous_chunk=None, + page_state: None | page_manager.PageState = None, slot: None | int = None, - page_state: None | page_manager.PageState = None, - previous_chunk=None, kv_cache=None, attention_metadata=None, ): diff --git a/src/maxtext/models/llama4.py b/src/maxtext/models/llama4.py index c66e80440b..8c4d025948 100644 --- a/src/maxtext/models/llama4.py +++ b/src/maxtext/models/llama4.py @@ -442,9 +442,9 @@ def __call__( decoder_positions, deterministic, model_mode, + previous_chunk=None, + page_state: None | page_manager.PageState = None, slot: None | int = None, 
- page_state: None | page_manager.PageState = None, - previous_chunk=None, kv_cache=None, attention_metadata=None, ): @@ -570,9 +570,11 @@ def __call__( decoder_positions, deterministic, model_mode, + previous_chunk=None, + page_state: None | page_manager.PageState = None, slot: None | int = None, - page_state: None | page_manager.PageState = None, - previous_chunk=None, + kv_cache=None, + attention_metadata=None, ): cfg = self.config @@ -590,6 +592,8 @@ def __call__( previous_chunk=previous_chunk, page_state=page_state, slot=slot, + kv_cache=kv_cache, + attention_metadata=attention_metadata, ) if cfg.scan_layers: y = y[0] diff --git a/src/maxtext/models/mistral.py b/src/maxtext/models/mistral.py index c590a36f85..73e810fc93 100644 --- a/src/maxtext/models/mistral.py +++ b/src/maxtext/models/mistral.py @@ -22,6 +22,7 @@ import jax.numpy as jnp from jax.sharding import Mesh from maxtext.common.common_types import Config +from maxtext.inference import page_manager from maxtext.layers import initializers, nnx_wrappers from maxtext.layers import quantizations from maxtext.layers.attentions import Attention @@ -126,9 +127,9 @@ def __call__( decoder_positions, deterministic, model_mode, - page_state: None | int = None, - slot: None | int = None, previous_chunk=None, + page_state: None | page_manager.PageState = None, + slot: None | int = None, kv_cache=None, attention_metadata=None, ): diff --git a/src/maxtext/models/olmo3.py b/src/maxtext/models/olmo3.py index c28020d781..83b31aec7e 100644 --- a/src/maxtext/models/olmo3.py +++ b/src/maxtext/models/olmo3.py @@ -267,6 +267,11 @@ def __call__( decoder_positions, deterministic, model_mode, + previous_chunk=None, + page_state=None, + slot=None, + kv_cache=None, + attention_metadata=None, ): cfg = self.config @@ -282,6 +287,11 @@ def __call__( decoder_positions, deterministic, model_mode, + previous_chunk=previous_chunk, + page_state=page_state, + slot=slot, + kv_cache=kv_cache, + attention_metadata=attention_metadata, ) if cfg.scan_layers: y = y[0] diff --git 
a/src/maxtext/models/qwen3.py b/src/maxtext/models/qwen3.py index eb15747fc2..5408746a09 100644 --- a/src/maxtext/models/qwen3.py +++ b/src/maxtext/models/qwen3.py @@ -896,6 +896,8 @@ def __call__( previous_chunk=None, page_state: None | page_manager.PageState = None, slot: None | int = None, + kv_cache=None, + attention_metadata=None, ) -> tuple[Array, None]: """Applies the block of decoder layers to the input carry. @@ -924,6 +926,8 @@ def __call__( previous_chunk, page_state, slot, + kv_cache=kv_cache, + attention_metadata=attention_metadata, ) # The output of the block is the carry for the next scan iteration. @@ -1235,10 +1239,7 @@ def __call__( layer_output = intermediate_inputs + mlp_lnx layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names) - if self.config.scan_layers: - return layer_output, None - else: - return layer_output, kv_cache + return layer_output, kv_cache # ----------------------------------------- @@ -1304,10 +1305,7 @@ def __call__( layer_output = intermediate_inputs + mlp_lnx layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names) - if self.config.scan_layers: - return layer_output, None - else: - return layer_output, kv_cache + return layer_output, kv_cache class Qwen3OmniMoeVisionPatchMerger(nnx.Module): diff --git a/src/maxtext/models/simple_layer.py b/src/maxtext/models/simple_layer.py index 41ac327281..70be9a5ee1 100644 --- a/src/maxtext/models/simple_layer.py +++ b/src/maxtext/models/simple_layer.py @@ -58,7 +58,17 @@ def __init__( ) def __call__( - self, inputs: jnp.ndarray, positions, segmentation, deterministic, model_mode, previous_chunk=None, page_state=None + self, + inputs: jnp.ndarray, + positions, + segmentation, + deterministic, + model_mode, + previous_chunk=None, + page_state=None, + slot=None, + kv_cache=None, + attention_metadata=None, ): # Unpack inputs if it's a tuple (e.g. 
from a previous layer returning (hidden_states, kv_cache)) if isinstance(inputs, tuple): @@ -121,6 +131,8 @@ def __call__( previous_chunk=None, page_state=None, slot=0, + kv_cache=None, + attention_metadata=None, ): # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache)) if isinstance(inputs, tuple):