Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions src/maxtext/configs/models/qwen3.5-35b-a3b.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# maxtext/configs/models/qwen3.5-35b-moe.yml

decoder_block: "qwen3_5"

# Core Architectural Parameters
base_emb_dim: 2048
base_num_decoder_layers: 40
base_num_query_heads: 16
base_num_kv_heads: 2
head_dim: 256
vocab_size: 248320
normalization_layer_epsilon: 1.0e-6

# MoE Specific Parameters
# Set base_mlp_dim to match base_moe_mlp_dim to pass validation for fully MoE models.
base_mlp_dim: 512
base_moe_mlp_dim: 512
num_experts: 256
shared_experts: 1
num_experts_per_tok: 8
norm_topk_prob: True

# GatedDeltaNet Specific Parameters for Linear Attention (GDN)
inhomogeneous_layer_cycle_interval: 4
gdn_conv_kernel_dim: 4
gdn_key_head_dim: 128
gdn_value_head_dim: 128
gdn_num_key_heads: 16
gdn_num_value_heads: 32
gdn_chunk_size: 64

# RoPE Settings
rope_max_timescale: 10000000
partial_rotary_factor: 0.25

# General Model Settings
enable_dropout: False
1 change: 1 addition & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ class ProfilerType(str, Enum):
"qwen3-omni-30b-a3b",
"qwen3-custom-30b-a3b",
"qwen3.5-397b-a17b",
"qwen3.5-35b-a3b",
"gpt3-175b",
"gpt3-22b",
"gpt3-6b",
Expand Down
146 changes: 146 additions & 0 deletions src/maxtext/integration/vllm/maxtext_vllm_adapter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,158 @@
logger = init_logger(__name__)


def _patch_vllm_uses_mrope_for_maxtext() -> None:
"""Suppress vLLM's M-RoPE detection for MaxTextForCausalLM.

vLLM's ``uses_mrope`` returns True whenever the HF config has
``rope_parameters.mrope_section``, which is the case for the Qwen3 family
even for the text-only models MaxText currently serves. That flag then
drives tpu-inference to call ``get_mrope_input_positions`` on the JIT'd
model (which MaxTextForCausalLM doesn't define, so it ends up as None and
the persistent batch manager dereferences it on the first request) and to
precompile mrope-shaped position tensors that our text-only Jax model
can't consume. Force-return False when the HF config is targeting
``MaxTextForCausalLM``. Drop once MaxText supports true multimodal serving
via vLLM, or vLLM gains a per-architecture mrope opt-out.
"""
# pylint: disable=import-outside-toplevel
import vllm.config.model as _vllm_config_model
import vllm.transformers_utils.config as _vllm_config_utils

orig_uses_mrope = _vllm_config_utils.uses_mrope

def _maxtext_uses_mrope(config) -> bool:
architectures = getattr(config, "architectures", None) or []
if "MaxTextForCausalLM" in architectures:
return False
return orig_uses_mrope(config)

_vllm_config_utils.uses_mrope = _maxtext_uses_mrope
# vllm.config.model imported uses_mrope as a local name; rebind that too so
# ModelConfig.uses_mrope picks up the patch.
_vllm_config_model.uses_mrope = _maxtext_uses_mrope


def _patch_tpu_inference_jax_kv_spec_for_maxtext() -> None:
"""Have tpu-inference's JAX kv_cache_spec builder honor ``layer_types == 'linear_attention'``.

Upstream ``tpu_inference.runner.kv_cache_manager.KVCacheManager.get_kv_cache_spec``
has a TODO to unify the hybrid kv-cache path with torchax. Until that lands,
any ``"linear_attention"`` entry in the HF config's ``layer_types`` list silently
becomes a ``FullAttentionSpec`` when no torch attention modules are registered
(the JAX case — MaxTextForCausalLM is an ``nnx.Module``, registers nothing in
``static_forward_context``). That breaks Qwen3-Next / Qwen3.5 served via MaxText:
the GDN layers get paged-attention caches instead of mamba ``(conv_state,
recurrent_state)`` tuples. Replace those slots with a ``MambaSpec`` built from
the model's ``get_mamba_state_shape_from_config``. Drop once tpu-inference's
upstream JAX path supports MambaSpec natively.
"""
# pylint: disable=import-outside-toplevel
import dataclasses

from tpu_inference.runner.kv_cache_manager import KVCacheManager
from vllm.v1.attention.backends.registry import MambaAttentionBackendEnum
from vllm.v1.kv_cache_interface import MambaSpec

orig_get_kv_cache_spec = KVCacheManager.get_kv_cache_spec

def patched(self):
spec = orig_get_kv_cache_spec(self)
text_config = getattr(
self.runner.model_config,
"hf_text_config",
getattr(self.runner.model_config, "hf_config", None),
)
layer_types = getattr(text_config, "layer_types", None)
if not layer_types:
return spec

# Architectures live on the top-level HF config (not the text sub-config).
# Check both so we don't miss it when the override is applied at either layer.
architectures = list(
getattr(self.runner.model_config.hf_config, "architectures", None) or []
) + list(getattr(text_config, "architectures", None) or [])
if "MaxTextForCausalLM" not in architectures:
# Don't disturb foreign architectures sharing this process.
return spec

shapes = MaxTextForCausalLM.get_mamba_state_shape_from_config(self.runner.vllm_config)
dtypes = MaxTextForCausalLM.get_mamba_state_dtype_from_config(self.runner.vllm_config)
block_size = (
self.runner.cache_config.block_size
* self.runner.vllm_config.parallel_config.decode_context_parallel_size
)

# vLLM requires every layer's page_size_bytes to match before grouping
# (vllm.v1.core.kv_cache_utils.unify_kv_cache_spec_page_size). Full-attn
# and mamba state shapes give very different natural page sizes, so we
# mirror tpu-inference's `update_mamba_page_size_padded` (only invoked
# in the torch path) and pad both families to a common
# per-`shared_by`-group footprint.
attn_page_size = next(
(s.page_size_bytes for s in spec.values() if not isinstance(s, MambaSpec)),
None,
)
probe_mamba = MambaSpec(
block_size=block_size, shapes=tuple(shapes), dtypes=tuple(dtypes)
)
mamba_unpadded = probe_mamba.page_size_bytes

num_attn = sum(1 for lt in layer_types if lt != "linear_attention")
num_mamba = sum(1 for lt in layer_types if lt == "linear_attention")
if attn_page_size is None or num_attn == 0 or num_mamba == 0:
# Pure mamba or pure attn; nothing to unify.
uniform = None
else:
mn = min(num_attn, num_mamba)
mx = max(num_attn, num_mamba)
group_size = mx if mx < mn * 1.5 else mn
num_attn_groups = (num_attn + group_size - 1) // group_size
num_mamba_groups = (num_mamba + group_size - 1) // group_size
uniform = int(
num_attn_groups * attn_page_size + num_mamba_groups * mamba_unpadded
)
# Persist the same value tpu-inference's torch path would have stored,
# so the per-layer allocator math at kv_cache_manager.py:700-720 lines up.
self._hybrid_uniform_page_size_bytes = uniform
self.runner.cache_config.mamba_page_size_padded = uniform
for key, s in list(spec.items()):
if not isinstance(s, MambaSpec):
spec[key] = dataclasses.replace(s, page_size_padded=uniform)

replaced = 0
for i, layer_type in enumerate(layer_types):
if layer_type != "linear_attention":
continue
key = f"layer.{i}"
if key not in spec:
continue
spec[key] = MambaSpec(
block_size=block_size,
shapes=tuple(shapes),
dtypes=tuple(dtypes),
page_size_padded=uniform,
mamba_type=MambaAttentionBackendEnum.GDN_ATTN,
)
replaced += 1
logger.info(
"[mt-kv-spec-patch] replaced %d entries with MambaSpec (uniform_page_size=%s)",
replaced,
uniform,
)
return spec

KVCacheManager.get_kv_cache_spec = patched


def register():
"""Register MaxTextForCausalLM model with tpu_inference and vllm.

Note, this function is invoked directly by the vLLM engine during startup. As such,
it leverages vLLM logging to report its status.
"""
logger.info("Registering MaxTextForCausalLM model with tpu_inference and vllm.")
_patch_vllm_uses_mrope_for_maxtext()
_patch_tpu_inference_jax_kv_spec_for_maxtext()
register_model("MaxTextForCausalLM", MaxTextForCausalLM)
logger.info("Successfully registered MaxTextForCausalLM model.")
94 changes: 94 additions & 0 deletions src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ class AttentionMetadata:


from vllm.config import VllmConfig
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)


def next_power_of_two(x: int) -> int:
Expand Down Expand Up @@ -161,13 +166,77 @@ class MaxTextForCausalLM(nnx.Module):
into the vLLM serving framework, specifically for causal language modeling
tasks. It handles configuration generation, model initialization, and execution
of the decoding step.

Advertises ``is_hybrid = True`` and ``has_inner_state = True`` (vLLM's
``IsHybrid`` / ``HasInnerState`` are runtime_checkable Protocols probed via
``getattr(model, 'is_hybrid', False)`` — so plain class attributes suffice
and we avoid the metaclass conflict that explicit Protocol inheritance has
with ``nnx.Module``). Together with the three ``get_mamba_state_*``
classmethods, this lets tpu-inference treat models with mixed full-attention
+ linear-attention (GDN) blocks — Qwen3-Next / Qwen3.5 — as hybrid and
allocate MambaSpec slots for the GDN layers. The classmethods mirror
upstream vLLM's ``Qwen3NextForCausalLM`` and read ``linear_*`` fields from
the HF text config; they're a no-op for non-hybrid configs since
tpu-inference only consults them when ``layer_types`` is present.
"""

# IsHybrid / HasInnerState Protocol markers (duck-typed via getattr upstream).
is_hybrid: bool = True
has_inner_state: bool = True

# Signal to tpu-inference model_loader that this class manages its own
# JIT-sharded initialization (via create_nnx_model with out_shardings).
# When True, model_loader skips wrapping __init__ in an outer bare @jax.jit,
_self_manages_sharding: bool = True

# IsHybrid / HasInnerState protocol markers. Bodies of the classmethods
# below mirror upstream vLLM's Qwen3NextForCausalLM — they consult
# ``hf_text_config.linear_*`` so they're only meaningful for configs that
# actually carry those fields (Qwen3-Next / Qwen3.5). Non-hybrid configs
# never reach these methods because tpu-inference's spec patch only invokes
# them on layers whose ``layer_types`` entry is ``"linear_attention"``.

@classmethod
def get_mamba_state_shape_from_config(cls, vllm_config: VllmConfig):
"""Conv and recurrent state shapes for the GDN layers of a hybrid model.

Returns the *global* (unsharded) shape. tpu-inference's JAX KVCacheManager
allocates the cache with this shape and shards it via the partition spec
over ``ShardingAxisName.ATTN_HEAD = ('model', 'expert', 'dcp')``. Passing
the real ``tensor_parallel_size`` here would pre-divide on top of that JAX
sharding, leaving the cache half its expected per-device size and causing
a ``conv_state`` vs ``mixed_qkv`` shape mismatch inside ragged_conv1d.
"""
hf = vllm_config.model_config.hf_text_config
num_spec = (
vllm_config.speculative_config.num_speculative_tokens
if vllm_config.speculative_config
else 0
)
return MambaStateShapeCalculator.gated_delta_net_state_shape(
1, # global shape — JAX sharding does the per-device divide.
hf.linear_num_key_heads,
hf.linear_num_value_heads,
hf.linear_key_head_dim,
hf.linear_value_head_dim,
hf.linear_conv_kernel_dim,
num_spec,
)

@classmethod
def get_mamba_state_dtype_from_config(cls, vllm_config: VllmConfig):
"""Conv and recurrent state dtypes for the GDN layers of a hybrid model."""
return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
vllm_config.model_config.dtype,
vllm_config.cache_config.mamba_cache_dtype,
vllm_config.cache_config.mamba_ssm_cache_dtype,
)

@classmethod
def get_mamba_state_copy_func(cls):
"""Per-state copy callables used by MambaPrefixCachingManager."""
return MambaStateCopyFuncCalculator.gated_delta_net_state_copy_func()

def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh):
"""Initializes the MaxTextForCausalLM model.

Expand Down Expand Up @@ -228,6 +297,31 @@ def __call__(
if not isinstance(self.model, nnx.Module):
raise ValueError("Model must be an instance of type nnx.Module.")

# Hybrid models (Qwen3-Next / Qwen3.5) get ``attention_metadata`` as a
# ``dict[layer_name, AttentionMetadata]`` so each layer can pick the
# ``block_tables`` for its own kv_cache_group. Every value shares the
# other fields (input_positions, seq_lens, query_start_loc,
# mamba_state_indices, request_distribution). For the in-model dispatch
# below, GDN layers don't touch block_tables — they index via
# ``mamba_state_indices`` — and all full-attn layers belong to the same
# kv_cache_group so they share one block_tables. Pick a metadata from a
# full-attn (non-linear_attention) layer when possible; otherwise any
# value works.
if isinstance(attention_metadata, dict):
hf_text_config = getattr(
self.cfg, "hf_text_config", getattr(self.cfg, "hf_config", None)
)
layer_types = getattr(hf_text_config, "layer_types", None) or []
attention_metadata_picked = None
for i, lt in enumerate(layer_types):
if lt != "linear_attention":
attention_metadata_picked = attention_metadata.get(f"layer.{i}")
if attention_metadata_picked is not None:
break
if attention_metadata_picked is None:
attention_metadata_picked = next(iter(attention_metadata.values()))
attention_metadata = attention_metadata_picked

# Ensure inputs are at least 2D with a batch dimension
input_ids = jnp.expand_dims(input_ids, axis=1)
input_positions = jnp.expand_dims(attention_metadata.input_positions, axis=1)
Expand Down
15 changes: 5 additions & 10 deletions src/maxtext/layers/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,12 +1107,11 @@ def __call__(
if cfg.decoder_block in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5):
layer_kwargs = {"layer_idx": lyr}
kv_cache = None
if kv_caches is not None and cfg.decoder_block not in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5):
if kv_caches is not None:
# tpu-inference packs per-layer slots: a single Array for
# full-attention layers and a (conv_state, recurrent_state) tuple
# for GDN/mamba layers. The downstream layer dispatches on shape.
kv_cache = kv_caches[lyr]
elif kv_caches is not None and cfg.decoder_block in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5):
# For Qwen3Next & Qwen3.5, kv_caches is a dictionary of lists of caches.
if (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0:
kv_cache = (kv_caches["key_cache"][lyr], kv_caches["value_cache"][lyr])

if cfg.decoder_block == DecoderBlockType.GPT_OSS:
layer_kwargs = {"attention_type": gpt_oss.get_attention_type(layer_id=lyr)}
Expand All @@ -1135,11 +1134,7 @@ def __call__(
**layer_call_kwargs,
)
if kv_caches is not None and returned_cache is not None:
if cfg.decoder_block not in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5):
kv_caches[lyr] = returned_cache
elif (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0:
kv_caches["key_cache"][lyr] = returned_cache[0]
kv_caches["value_cache"][lyr] = returned_cache[1]
kv_caches[lyr] = returned_cache

if deepstack_visual_embeds is not None and lyr < len(deepstack_visual_embeds):
visual_embeds = deepstack_visual_embeds[lyr]
Expand Down
Loading
Loading