From 7c4da17c14ae7522b654a9327f3c0d8f48e9ee6d Mon Sep 17 00:00:00 2001 From: Nicolas Grande Date: Thu, 14 May 2026 23:53:02 +0000 Subject: [PATCH] adding GDN support for vllm. --- .../configs/models/qwen3.5-35b-a3b.yml | 51 ++++++ src/maxtext/configs/types.py | 1 + .../vllm/maxtext_vllm_adapter/__init__.py | 146 +++++++++++++++ .../vllm/maxtext_vllm_adapter/adapter.py | 94 ++++++++++ src/maxtext/layers/decoders.py | 15 +- src/maxtext/layers/nnx_decoders.py | 18 +- src/maxtext/models/qwen3.py | 166 +++++++++++++++++- src/maxtext/models/qwen3_5.py | 10 +- src/maxtext/utils/globals.py | 2 + tests/unit/attention_test.py | 6 +- tests/unit/qwen3_next_vs_reference_test.py | 6 +- 11 files changed, 475 insertions(+), 40 deletions(-) create mode 100644 src/maxtext/configs/models/qwen3.5-35b-a3b.yml diff --git a/src/maxtext/configs/models/qwen3.5-35b-a3b.yml b/src/maxtext/configs/models/qwen3.5-35b-a3b.yml new file mode 100644 index 0000000000..65e5cfa761 --- /dev/null +++ b/src/maxtext/configs/models/qwen3.5-35b-a3b.yml @@ -0,0 +1,51 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# maxtext/configs/models/qwen3.5-35b-a3b.yml + +decoder_block: "qwen3_5" + +# Core Architectural Parameters +base_emb_dim: 2048 +base_num_decoder_layers: 40 +base_num_query_heads: 16 +base_num_kv_heads: 2 +head_dim: 256 +vocab_size: 248320 +normalization_layer_epsilon: 1.0e-6 + +# MoE Specific Parameters +# Set base_mlp_dim to match base_moe_mlp_dim to pass validation for fully MoE models. +base_mlp_dim: 512 +base_moe_mlp_dim: 512 +num_experts: 256 +shared_experts: 1 +num_experts_per_tok: 8 +norm_topk_prob: True + +# GatedDeltaNet Specific Parameters for Linear Attention (GDN) +inhomogeneous_layer_cycle_interval: 4 +gdn_conv_kernel_dim: 4 +gdn_key_head_dim: 128 +gdn_value_head_dim: 128 +gdn_num_key_heads: 16 +gdn_num_value_heads: 32 +gdn_chunk_size: 64 + +# RoPE Settings +rope_max_timescale: 10000000 +partial_rotary_factor: 0.25 + +# General Model Settings +enable_dropout: False diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index bb18a81a5f..dd46553552 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -259,6 +259,7 @@ class ProfilerType(str, Enum): "qwen3-omni-30b-a3b", "qwen3-custom-30b-a3b", "qwen3.5-397b-a17b", + "qwen3.5-35b-a3b", "gpt3-175b", "gpt3-22b", "gpt3-6b", diff --git a/src/maxtext/integration/vllm/maxtext_vllm_adapter/__init__.py b/src/maxtext/integration/vllm/maxtext_vllm_adapter/__init__.py index 216737a7fc..f4d4408935 100644 --- a/src/maxtext/integration/vllm/maxtext_vllm_adapter/__init__.py +++ b/src/maxtext/integration/vllm/maxtext_vllm_adapter/__init__.py @@ -22,6 +22,150 @@ logger = init_logger(__name__) +def _patch_vllm_uses_mrope_for_maxtext() -> None: + """Suppress vLLM's M-RoPE detection for MaxTextForCausalLM. + + vLLM's ``uses_mrope`` returns True whenever the HF config has + ``rope_parameters.mrope_section``, which is the case for the Qwen3 family + even for the text-only models MaxText currently serves. 
That flag then + drives tpu-inference to call ``get_mrope_input_positions`` on the JIT'd + model (which MaxTextForCausalLM doesn't define, so it ends up as None and + the persistent batch manager dereferences it on the first request) and to + precompile mrope-shaped position tensors that our text-only Jax model + can't consume. Force-return False when the HF config is targeting + ``MaxTextForCausalLM``. Drop once MaxText supports true multimodal serving + via vLLM, or vLLM gains a per-architecture mrope opt-out. + """ + # pylint: disable=import-outside-toplevel + import vllm.config.model as _vllm_config_model + import vllm.transformers_utils.config as _vllm_config_utils + + orig_uses_mrope = _vllm_config_utils.uses_mrope + + def _maxtext_uses_mrope(config) -> bool: + architectures = getattr(config, "architectures", None) or [] + if "MaxTextForCausalLM" in architectures: + return False + return orig_uses_mrope(config) + + _vllm_config_utils.uses_mrope = _maxtext_uses_mrope + # vllm.config.model imported uses_mrope as a local name; rebind that too so + # ModelConfig.uses_mrope picks up the patch. + _vllm_config_model.uses_mrope = _maxtext_uses_mrope + + +def _patch_tpu_inference_jax_kv_spec_for_maxtext() -> None: + """Have tpu-inference's JAX kv_cache_spec builder honor ``layer_types == 'linear_attention'``. + + Upstream ``tpu_inference.runner.kv_cache_manager.KVCacheManager.get_kv_cache_spec`` + has a TODO to unify the hybrid kv-cache path with torchax. Until that lands, + any ``"linear_attention"`` entry in the HF config's ``layer_types`` list silently + becomes a ``FullAttentionSpec`` when no torch attention modules are registered + (the JAX case — MaxTextForCausalLM is an ``nnx.Module``, registers nothing in + ``static_forward_context``). That breaks Qwen3-Next / Qwen3.5 served via MaxText: + the GDN layers get paged-attention caches instead of mamba ``(conv_state, + recurrent_state)`` tuples. 
Replace those slots with a ``MambaSpec`` built from + the model's ``get_mamba_state_shape_from_config``. Drop once tpu-inference's + upstream JAX path supports MambaSpec natively. + """ + # pylint: disable=import-outside-toplevel + import dataclasses + + from tpu_inference.runner.kv_cache_manager import KVCacheManager + from vllm.v1.attention.backends.registry import MambaAttentionBackendEnum + from vllm.v1.kv_cache_interface import MambaSpec + + orig_get_kv_cache_spec = KVCacheManager.get_kv_cache_spec + + def patched(self): + spec = orig_get_kv_cache_spec(self) + text_config = getattr( + self.runner.model_config, + "hf_text_config", + getattr(self.runner.model_config, "hf_config", None), + ) + layer_types = getattr(text_config, "layer_types", None) + if not layer_types or "linear_attention" not in layer_types: # No GDN layers: keep the upstream spec untouched. + return spec + + # Architectures live on the top-level HF config (not the text sub-config). + # Check both so we don't miss it when the override is applied at either layer. + architectures = list( + getattr(self.runner.model_config.hf_config, "architectures", None) or [] + ) + list(getattr(text_config, "architectures", None) or []) + if "MaxTextForCausalLM" not in architectures: + # Don't disturb foreign architectures sharing this process. + return spec + + shapes = MaxTextForCausalLM.get_mamba_state_shape_from_config(self.runner.vllm_config) + dtypes = MaxTextForCausalLM.get_mamba_state_dtype_from_config(self.runner.vllm_config) + block_size = ( + self.runner.cache_config.block_size + * self.runner.vllm_config.parallel_config.decode_context_parallel_size + ) + + # vLLM requires every layer's page_size_bytes to match before grouping + # (vllm.v1.core.kv_cache_utils.unify_kv_cache_spec_page_size). Full-attn + # and mamba state shapes give very different natural page sizes, so we + # mirror tpu-inference's `update_mamba_page_size_padded` (only invoked + # in the torch path) and pad both families to a common + # per-`shared_by`-group footprint. 
+ attn_page_size = next( + (s.page_size_bytes for s in spec.values() if not isinstance(s, MambaSpec)), + None, + ) + probe_mamba = MambaSpec( + block_size=block_size, shapes=tuple(shapes), dtypes=tuple(dtypes) + ) + mamba_unpadded = probe_mamba.page_size_bytes + + num_attn = sum(1 for lt in layer_types if lt != "linear_attention") + num_mamba = sum(1 for lt in layer_types if lt == "linear_attention") + if attn_page_size is None or num_attn == 0 or num_mamba == 0: + # Pure mamba or pure attn; nothing to unify. + uniform = None + else: + mn = min(num_attn, num_mamba) + mx = max(num_attn, num_mamba) + group_size = mx if mx < mn * 1.5 else mn + num_attn_groups = (num_attn + group_size - 1) // group_size + num_mamba_groups = (num_mamba + group_size - 1) // group_size + uniform = int( + num_attn_groups * attn_page_size + num_mamba_groups * mamba_unpadded + ) + # Persist the same value tpu-inference's torch path would have stored, + # so the per-layer allocator math at kv_cache_manager.py:700-720 lines up. + self._hybrid_uniform_page_size_bytes = uniform + self.runner.cache_config.mamba_page_size_padded = uniform + for key, s in list(spec.items()): + if not isinstance(s, MambaSpec): + spec[key] = dataclasses.replace(s, page_size_padded=uniform) + + replaced = 0 + for i, layer_type in enumerate(layer_types): + if layer_type != "linear_attention": + continue + key = f"layer.{i}" + if key not in spec: + continue + spec[key] = MambaSpec( + block_size=block_size, + shapes=tuple(shapes), + dtypes=tuple(dtypes), + page_size_padded=uniform, + mamba_type=MambaAttentionBackendEnum.GDN_ATTN, + ) + replaced += 1 + logger.info( + "[mt-kv-spec-patch] replaced %d entries with MambaSpec (uniform_page_size=%s)", + replaced, + uniform, + ) + return spec + + KVCacheManager.get_kv_cache_spec = patched + + def register(): """Register MaxTextForCausalLM model with tpu_inference and vllm. @@ -29,5 +173,7 @@ def register(): it leverages vLLM logging to report its status. 
""" logger.info("Registering MaxTextForCausalLM model with tpu_inference and vllm.") + _patch_vllm_uses_mrope_for_maxtext() + _patch_tpu_inference_jax_kv_spec_for_maxtext() register_model("MaxTextForCausalLM", MaxTextForCausalLM) logger.info("Successfully registered MaxTextForCausalLM model.") diff --git a/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py b/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py index 07231f965e..3f386ce5f1 100644 --- a/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py +++ b/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py @@ -39,6 +39,11 @@ class AttentionMetadata: from vllm.config import VllmConfig +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateCopyFuncCalculator, + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) def next_power_of_two(x: int) -> int: @@ -161,13 +166,77 @@ class MaxTextForCausalLM(nnx.Module): into the vLLM serving framework, specifically for causal language modeling tasks. It handles configuration generation, model initialization, and execution of the decoding step. + + Advertises ``is_hybrid = True`` and ``has_inner_state = True`` (vLLM's + ``IsHybrid`` / ``HasInnerState`` are runtime_checkable Protocols probed via + ``getattr(model, 'is_hybrid', False)`` — so plain class attributes suffice + and we avoid the metaclass conflict that explicit Protocol inheritance has + with ``nnx.Module``). Together with the three ``get_mamba_state_*`` + classmethods, this lets tpu-inference treat models with mixed full-attention + + linear-attention (GDN) blocks — Qwen3-Next / Qwen3.5 — as hybrid and + allocate MambaSpec slots for the GDN layers. The classmethods mirror + upstream vLLM's ``Qwen3NextForCausalLM`` and read ``linear_*`` fields from + the HF text config; they're a no-op for non-hybrid configs since + tpu-inference only consults them when ``layer_types`` is present. 
""" + # IsHybrid / HasInnerState Protocol markers (duck-typed via getattr upstream). + is_hybrid: bool = True + has_inner_state: bool = True + # Signal to tpu-inference model_loader that this class manages its own # JIT-sharded initialization (via create_nnx_model with out_shardings). # When True, model_loader skips wrapping __init__ in an outer bare @jax.jit, _self_manages_sharding: bool = True + # IsHybrid / HasInnerState protocol markers. Bodies of the classmethods + # below mirror upstream vLLM's Qwen3NextForCausalLM — they consult + # ``hf_text_config.linear_*`` so they're only meaningful for configs that + # actually carry those fields (Qwen3-Next / Qwen3.5). Non-hybrid configs + # never reach these methods because tpu-inference's spec patch only invokes + # them on layers whose ``layer_types`` entry is ``"linear_attention"``. + + @classmethod + def get_mamba_state_shape_from_config(cls, vllm_config: VllmConfig): + """Conv and recurrent state shapes for the GDN layers of a hybrid model. + + Returns the *global* (unsharded) shape. tpu-inference's JAX KVCacheManager + allocates the cache with this shape and shards it via the partition spec + over ``ShardingAxisName.ATTN_HEAD = ('model', 'expert', 'dcp')``. Passing + the real ``tensor_parallel_size`` here would pre-divide on top of that JAX + sharding, leaving the cache half its expected per-device size and causing + a ``conv_state`` vs ``mixed_qkv`` shape mismatch inside ragged_conv1d. + """ + hf = vllm_config.model_config.hf_text_config + num_spec = ( + vllm_config.speculative_config.num_speculative_tokens + if vllm_config.speculative_config + else 0 + ) + return MambaStateShapeCalculator.gated_delta_net_state_shape( + 1, # global shape — JAX sharding does the per-device divide. 
+ hf.linear_num_key_heads, + hf.linear_num_value_heads, + hf.linear_key_head_dim, + hf.linear_value_head_dim, + hf.linear_conv_kernel_dim, + num_spec, + ) + + @classmethod + def get_mamba_state_dtype_from_config(cls, vllm_config: VllmConfig): + """Conv and recurrent state dtypes for the GDN layers of a hybrid model.""" + return MambaStateDtypeCalculator.gated_delta_net_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + vllm_config.cache_config.mamba_ssm_cache_dtype, + ) + + @classmethod + def get_mamba_state_copy_func(cls): + """Per-state copy callables used by MambaPrefixCachingManager.""" + return MambaStateCopyFuncCalculator.gated_delta_net_state_copy_func() + def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh): """Initializes the MaxTextForCausalLM model. @@ -228,6 +297,31 @@ def __call__( if not isinstance(self.model, nnx.Module): raise ValueError("Model must be an instance of type nnx.Module.") + # Hybrid models (Qwen3-Next / Qwen3.5) get ``attention_metadata`` as a + # ``dict[layer_name, AttentionMetadata]`` so each layer can pick the + # ``block_tables`` for its own kv_cache_group. Every value shares the + # other fields (input_positions, seq_lens, query_start_loc, + # mamba_state_indices, request_distribution). For the in-model dispatch + # below, GDN layers don't touch block_tables — they index via + # ``mamba_state_indices`` — and all full-attn layers belong to the same + # kv_cache_group so they share one block_tables. Pick a metadata from a + # full-attn (non-linear_attention) layer when possible; otherwise any + # value works. 
+ if isinstance(attention_metadata, dict): + hf_text_config = getattr( + self.cfg, "hf_text_config", getattr(self.cfg, "hf_config", None) + ) + layer_types = getattr(hf_text_config, "layer_types", None) or [] + attention_metadata_picked = None + for i, lt in enumerate(layer_types): + if lt != "linear_attention": + attention_metadata_picked = attention_metadata.get(f"layer.{i}") + if attention_metadata_picked is not None: + break + if attention_metadata_picked is None: + attention_metadata_picked = next(iter(attention_metadata.values())) + attention_metadata = attention_metadata_picked + # Ensure inputs are at least 2D with a batch dimension input_ids = jnp.expand_dims(input_ids, axis=1) input_positions = jnp.expand_dims(attention_metadata.input_positions, axis=1) diff --git a/src/maxtext/layers/decoders.py b/src/maxtext/layers/decoders.py index a2d52dd033..c885c8bc1a 100644 --- a/src/maxtext/layers/decoders.py +++ b/src/maxtext/layers/decoders.py @@ -1107,12 +1107,11 @@ def __call__( if cfg.decoder_block in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5): layer_kwargs = {"layer_idx": lyr} kv_cache = None - if kv_caches is not None and cfg.decoder_block not in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5): + if kv_caches is not None: + # tpu-inference packs per-layer slots: a single Array for + # full-attention layers and a (conv_state, recurrent_state) tuple + # for GDN/mamba layers. The downstream layer dispatches on shape. kv_cache = kv_caches[lyr] - elif kv_caches is not None and cfg.decoder_block in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5): - # For Qwen3Next & Qwen3.5, kv_caches is a dictionary of lists of caches. 
- if (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0: - kv_cache = (kv_caches["key_cache"][lyr], kv_caches["value_cache"][lyr]) if cfg.decoder_block == DecoderBlockType.GPT_OSS: layer_kwargs = {"attention_type": gpt_oss.get_attention_type(layer_id=lyr)} @@ -1135,11 +1134,7 @@ def __call__( **layer_call_kwargs, ) if kv_caches is not None and returned_cache is not None: - if cfg.decoder_block not in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5): - kv_caches[lyr] = returned_cache - elif (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0: - kv_caches["key_cache"][lyr] = returned_cache[0] - kv_caches["value_cache"][lyr] = returned_cache[1] + kv_caches[lyr] = returned_cache if deepstack_visual_embeds is not None and lyr < len(deepstack_visual_embeds): visual_embeds = deepstack_visual_embeds[lyr] diff --git a/src/maxtext/layers/nnx_decoders.py b/src/maxtext/layers/nnx_decoders.py index 262eb62277..6da7a20e11 100644 --- a/src/maxtext/layers/nnx_decoders.py +++ b/src/maxtext/layers/nnx_decoders.py @@ -1244,13 +1244,10 @@ def pure_layer_fn(graphdef, state_in, y_in, kv_in): for lyr, layer in enumerate(self.layers): graphdef, state = nnx.split(layer) if kv_caches is not None: - if cfg.decoder_block == DecoderBlockType.QWEN3_NEXT: - if (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0: - kv_cache = (kv_caches["key_cache"][lyr], kv_caches["value_cache"][lyr]) - else: - kv_cache = None - else: - kv_cache = kv_caches[lyr] + # tpu-inference packs per-layer slots: a single Array for + # full-attention layers and a (conv_state, recurrent_state) tuple + # for GDN/mamba layers. The downstream layer dispatches on shape. 
+ kv_cache = kv_caches[lyr] else: kv_cache = None @@ -1262,12 +1259,7 @@ def pure_layer_fn(graphdef, state_in, y_in, kv_in): nnx.update(layer, new_state) if kv_caches is not None and kv_cache is not None: - if cfg.decoder_block == DecoderBlockType.QWEN3_NEXT: - if (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0: - kv_caches["key_cache"][lyr] = kv_cache[0] - kv_caches["value_cache"][lyr] = kv_cache[1] - else: - kv_caches[lyr] = kv_cache + kv_caches[lyr] = kv_cache if deepstack_visual_embeds is not None and lyr < len(deepstack_visual_embeds): visual_embeds = deepstack_visual_embeds[lyr] diff --git a/src/maxtext/models/qwen3.py b/src/maxtext/models/qwen3.py index bd65f04438..f5376130a5 100644 --- a/src/maxtext/models/qwen3.py +++ b/src/maxtext/models/qwen3.py @@ -381,13 +381,25 @@ class Qwen3NextGatedDeltaNet(nnx.Module): 2. output = Linear_out(y) """ - def __init__(self, config: Config, dtype: DType = jnp.float32, model_mode: str = MODEL_MODE_TRAIN, *, rngs: nnx.Rngs): + def __init__( + self, + config: Config, + dtype: DType = jnp.float32, + model_mode: str = MODEL_MODE_TRAIN, + mesh: Mesh | None = None, + *, + rngs: nnx.Rngs, + ): """ Args: config: MaxText configuration object. + mesh: Optional device mesh. Required only when serving via tpu-inference + (``config.attention == "vllm_rpa"``); the tpu-inference GDN kernel uses + it for ``jax.shard_map``. rngs: The random number generators for initialization, passed by the nnx.to_linen wrapper. 
""" self.config = config + self.mesh = mesh cfg = self.config in_features = cfg.emb_dim @@ -401,7 +413,7 @@ def __init__(self, config: Config, dtype: DType = jnp.float32, model_mode: str = conv_kernel_size = cfg.gdn_conv_kernel_dim self.v_heads_per_k_head = self.num_v_heads // self.num_k_heads - if model_mode != MODEL_MODE_TRAIN: + if model_mode != MODEL_MODE_TRAIN and cfg.attention != "vllm_rpa": self.cache = kvcache.GatedDeltaNetCache( batch=config.per_device_batch_size, num_heads=self.num_v_heads, @@ -477,10 +489,17 @@ def __call__( model_mode: str = MODEL_MODE_TRAIN, kv_cache=None, decoder_segment_ids: None | Array = None, + attention_metadata: None | Any = None, **kwargs, - ) -> Array: + ) -> tuple[Array, tuple[Array, Array] | None]: # hidden_states: (B, S, E) cfg = self.config + + # vLLM/tpu-inference serving path: cache is externally managed and the + # tpu-inference GDN kernel consumes ragged scheduling metadata. + if cfg.attention == "vllm_rpa" and attention_metadata is not None and model_mode != MODEL_MODE_TRAIN: + return self._forward_serve_vllm(hidden_states, kv_cache, attention_metadata) + batch, seq_len, _ = hidden_states.shape # ========================================================================= @@ -669,7 +688,136 @@ def extract_state(c_in, v_len): # Final output shape: (B, S, E) output = self.out_proj(gated_output) - return output + return output, None + + def _forward_serve_vllm( + self, + hidden_states: Array, + kv_cache: tuple[Array, Array], + attention_metadata: Any, + ) -> tuple[Array, tuple[Array, Array]]: + """vLLM/tpu-inference GDN path. + + Externally-managed ``(conv_state, recurrent_state)`` cache flows in via + ``kv_cache`` and back out via the return value. Inputs are flattened to + the ragged ``(num_tokens, dim)`` layout expected by tpu-inference, and + weights are reordered from ``[Q|K|V]`` to the per-shard interleaved + layout the kernel requires (see the pytorch reference in + ``tpu_inference.layers.common.gdn_attention``). 
+ """ + try: + # pylint: disable=import-outside-toplevel + # pytype: disable=import-error + from tpu_inference.layers.common.gdn_attention import GdnAttentionConfig, run_jax_gdn_attention + from tpu_inference.layers.common.sharding import ShardingAxisName + from tpu_inference.layers.common.utils import reorder_concatenated_tensor_for_sharding + from tpu_inference.utils import get_mesh_shape_product + except ImportError as e: + raise ImportError( + "GDN vLLM serving requires the tpu-inference package. Install it with `pip install vllm-tpu`." + ) from e + + cfg = self.config + batch, seq_len, _ = hidden_states.shape + num_tokens = batch * seq_len + + # ========================================================================= + # STEP A: Input projections (same as training path, but ragged-flattened). + # ========================================================================= + qkvz = self.in_proj_qkvz(hidden_states) + ba = self.in_proj_ba(hidden_states) + + new_shape_qkvz = ( + batch, + seq_len, + self.num_k_heads, + 2 * self.head_k_dim + 2 * self.head_v_dim * self.v_heads_per_k_head, + ) + mixed_qkvz = qkvz.reshape(new_shape_qkvz) + split_indices_qkvz = [ + self.head_k_dim, + 2 * self.head_k_dim, + 2 * self.head_k_dim + self.v_heads_per_k_head * self.head_v_dim, + ] + query, key, value_raw, z_raw = jnp.split(mixed_qkvz, split_indices_qkvz, axis=3) + # z: (B, S, H_v, D_v) — kept multi-head for the post-kernel gated norm. 
+ z = z_raw.reshape(batch, seq_len, self.num_v_heads, self.head_v_dim) + + mixed_ba = ba.reshape(batch, seq_len, self.num_k_heads, 2 * self.v_heads_per_k_head) + b_raw, a_raw = jnp.split(mixed_ba, [self.v_heads_per_k_head], axis=3) + # b, a: (num_tokens, H_v) + b = b_raw.reshape(num_tokens, self.num_v_heads) + a = a_raw.reshape(num_tokens, self.num_v_heads) + + # Flat [Q | K | V] mixed_qkv: (num_tokens, 2*key_dim + value_dim) + q_flat = query.reshape(num_tokens, self.key_dim) + k_flat = key.reshape(num_tokens, self.key_dim) + v_flat = value_raw.reshape(num_tokens, self.value_dim) + mixed_qkv = jnp.concatenate([q_flat, k_flat, v_flat], axis=-1) + + # ========================================================================= + # STEP B: Reorder mixed_qkv and conv weights for per-shard interleaved + # layout, then call tpu-inference's fused conv + ragged delta-rule kernel. + # ========================================================================= + tp_size = get_mesh_shape_product(self.mesh, ShardingAxisName.ATTN_HEAD) + + mixed_qkv = reorder_concatenated_tensor_for_sharding( + mixed_qkv, [self.key_dim, self.key_dim, self.value_dim], tp_size, -1 + ) + + # nnx.Conv kernel layout is (kernel_size, in_features // groups, out_features). + # With feature_group_count == conv_dim this is (kernel_size, 1, conv_dim). + # tpu-inference expects (conv_dim, 1, kernel_size). + conv_weight = jnp.transpose(self.conv1d.kernel.value, (2, 1, 0)) + conv_weight = reorder_concatenated_tensor_for_sharding( + conv_weight, [self.key_dim, self.key_dim, self.value_dim], tp_size, 0 + ) + + conv_state, recurrent_state = kv_cache + + # tpu-inference's JAX AttentionMetadata carries the full (max_num_seqs,) + # tensors. Padded slots are masked by the kernel via request_distribution + # and zero-seq-len entries — no manual ``padded_num_reqs`` slicing needed + # here (that field only exists on the torch path's wrapper metadata). 
+ # Same shape contract as ``attentions.forward_serve_vllm`` passing into + # ``rpa_ops`` for full attention. + state_indices = attention_metadata.mamba_state_indices.astype(jnp.int32) + query_start_loc = attention_metadata.query_start_loc + seq_lens = attention_metadata.seq_lens + + (new_conv_state, new_recurrent_state), output = run_jax_gdn_attention( + mixed_qkv, + b, + a, + conv_state, + recurrent_state, + conv_weight, + None, # conv_bias: MaxText conv1d uses use_bias=False. + self.A_log.value, + self.dt_bias.value, + state_indices, + query_start_loc, + attention_metadata.request_distribution, + seq_lens, + self.num_k_heads, + self.num_v_heads, + self.head_k_dim, + self.head_v_dim, + cfg.gdn_conv_kernel_dim, + mesh=self.mesh, + config=GdnAttentionConfig(), + ) + + # ========================================================================= + # STEP C: Reshape kernel output, apply gated norm with z, project out. + # ========================================================================= + # output: (num_tokens, n_v * d_v) -> (B, S, H_v, D_v) + output = output.reshape(batch, seq_len, self.num_v_heads, self.head_v_dim) + gated_output = self.norm(output, z) + gated_output = gated_output.reshape(batch, seq_len, -1) + output = self.out_proj(gated_output) + + return output, (new_conv_state, new_recurrent_state) class Qwen3NextFullAttention(nnx.Module): @@ -986,7 +1134,9 @@ def __init__( rngs=rngs, ) else: - self.attention = Qwen3NextGatedDeltaNet(config=cfg, dtype=cfg.dtype, model_mode=model_mode, rngs=rngs) + self.attention = Qwen3NextGatedDeltaNet( + config=cfg, dtype=cfg.dtype, model_mode=model_mode, mesh=self.mesh, rngs=rngs + ) # Second LayerNorm, applied before the MoE block. 
self.post_attention_layernorm = Qwen3NextRMSNorm( @@ -1034,13 +1184,13 @@ def __call__( attention_metadata=attention_metadata, ) else: - attention_output = cast(Qwen3NextGatedDeltaNet, self.attention)( + attention_output, new_kv_cache = cast(Qwen3NextGatedDeltaNet, self.attention)( hidden_states, model_mode=model_mode, - kv_cache=None, + kv_cache=kv_cache, decoder_segment_ids=decoder_segment_ids, + attention_metadata=attention_metadata, ) - new_kv_cache = None # First residual connection after attention hidden_states = residual + attention_output diff --git a/src/maxtext/models/qwen3_5.py b/src/maxtext/models/qwen3_5.py index b25ecf09e8..14b8f65608 100644 --- a/src/maxtext/models/qwen3_5.py +++ b/src/maxtext/models/qwen3_5.py @@ -159,7 +159,9 @@ def __init__( rngs=rngs, ) else: - self.attention = Qwen3_5GatedDeltaNet(config=cfg, dtype=cfg.dtype, model_mode=model_mode, rngs=rngs) + self.attention = Qwen3_5GatedDeltaNet( + config=cfg, dtype=cfg.dtype, model_mode=model_mode, mesh=self.mesh, rngs=rngs + ) # Second LayerNorm, applied before the MoE block. 
self.post_attention_layernorm = Qwen3NextRMSNorm( @@ -207,13 +209,13 @@ def __call__( attention_metadata=attention_metadata, ) else: - attention_output = cast(Qwen3_5GatedDeltaNet, self.attention)( + attention_output, new_kv_cache = cast(Qwen3_5GatedDeltaNet, self.attention)( hidden_states, model_mode=model_mode, - kv_cache=None, + kv_cache=kv_cache, decoder_segment_ids=decoder_segment_ids, + attention_metadata=attention_metadata, ) - new_kv_cache = None # First residual connection after attention hidden_states = residual + attention_output diff --git a/src/maxtext/utils/globals.py b/src/maxtext/utils/globals.py index db41a116f2..9b92eb642f 100644 --- a/src/maxtext/utils/globals.py +++ b/src/maxtext/utils/globals.py @@ -77,6 +77,8 @@ "gpt-oss-120b": "openai/gpt-oss-120b", "qwen3-omni-30b-a3b": "Qwen/Qwen3-Omni-30B-A3B-Instruct", "qwen3-next-80b-a3b": "Qwen/Qwen3-Next-80B-A3B-Instruct", + "qwen3.5-397b-a17b": "Qwen/Qwen3.5-397B-A17B", + "qwen3.5-35b-a3b": "Qwen/Qwen3.5-35B-A3B", "mixtral-8x7b": "mistralai/Mixtral-8x7B-Instruct-v0.1", "mixtral-8x22b": "mistralai/Mixtral-8x22B-Instruct-v0.1", "olmo3-7b": "allenai/Olmo-3-7B-Instruct", diff --git a/tests/unit/attention_test.py b/tests/unit/attention_test.py index 2e8d470cd6..67783c2753 100644 --- a/tests/unit/attention_test.py +++ b/tests/unit/attention_test.py @@ -1931,7 +1931,7 @@ def test_autoregression(self): ) # 3. Full / Train mode - gdn_full = gdn( + gdn_full, _ = gdn( lnx, model_mode=MODEL_MODE_TRAIN, ) @@ -1939,7 +1939,7 @@ def test_autoregression(self): # 4. 
Prefill mode lnx_prefill = lnx[:, 0:prefill_length, :] - gdn_prefill = gdn( + gdn_prefill, _ = gdn( lnx_prefill, model_mode=MODEL_MODE_PREFILL, ) @@ -1952,7 +1952,7 @@ def test_autoregression(self): for idx in range(prefill_length, decode_total_length): lnx_idx = lnx[:, idx : idx + 1, :] - gdn_idx = gdn( + gdn_idx, _ = gdn( lnx_idx, model_mode=MODEL_MODE_AUTOREGRESSIVE, ) diff --git a/tests/unit/qwen3_next_vs_reference_test.py b/tests/unit/qwen3_next_vs_reference_test.py index f22cf72a17..7cb5497ec4 100644 --- a/tests/unit/qwen3_next_vs_reference_test.py +++ b/tests/unit/qwen3_next_vs_reference_test.py @@ -909,7 +909,8 @@ def test_gated_delta_net_structure(self): @jax.jit def run_jax(hidden_states): """Runs the JAX GatedDeltaNet model.""" - return jax_model(hidden_states) + output, _ = jax_model(hidden_states) + return output output_jax = run_jax(hidden_states_jax) @@ -1070,7 +1071,8 @@ def test_gated_delta_net_full(self): @jax.jit def run_jax(x): """Runs the JAX GatedDeltaNet model.""" - return jax_model(x) + output, _ = jax_model(x) + return output actual_output = run_jax(hidden_states_jax)