diff --git a/src/maxtext/configs/models/qwen3.5-35b-a3b.yml b/src/maxtext/configs/models/qwen3.5-35b-a3b.yml new file mode 100644 index 0000000000..44461b8040 --- /dev/null +++ b/src/maxtext/configs/models/qwen3.5-35b-a3b.yml @@ -0,0 +1,51 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# maxtext/configs/models/qwen3.5-35b-moe.yml + +decoder_block: "qwen3_5" + +# Core Architectural Parameters +base_emb_dim: 2048 +base_num_decoder_layers: 40 +base_num_query_heads: 16 +base_num_kv_heads: 2 +head_dim: 256 +vocab_size: 248320 +normalization_layer_epsilon: 1.0e-6 + +# MoE Specific Parameters +# Set base_mlp_dim to match base_moe_mlp_dim to pass validation for fully MoE models. +base_mlp_dim: 512 +base_moe_mlp_dim: 512 +num_experts: 256 +shared_experts: 1 +num_experts_per_tok: 8 +norm_topk_prob: True + +# GatedDeltaNet Specific Parameters for Linear Attention (GDN) +inhomogeneous_layer_cycle_interval: 4 +gdn_conv_kernel_dim: 4 +gdn_key_head_dim: 128 +gdn_value_head_dim: 128 +gdn_num_key_heads: 16 +gdn_num_value_heads: 32 +gdn_chunk_size: 64 + +# RoPE Settings +rope_max_timescale: 10000000 +partial_rotary_factor: 0.25 + +# General Model Settings +enable_dropout: False \ No newline at end of file diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 20594bccc3..a2f8af4f8b 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -260,6 +260,7 @@ class ProfilerType(str, Enum): "qwen3-omni-30b-a3b", "qwen3-custom-30b-a3b", "qwen3.5-397b-a17b", + "qwen3.5-35b-a3b", "gpt3-175b", "gpt3-22b", "gpt3-6b", diff --git a/src/maxtext/inference/kvcache.py b/src/maxtext/inference/kvcache.py index 2162a9d9eb..515ffa7286 100644 --- a/src/maxtext/inference/kvcache.py +++ b/src/maxtext/inference/kvcache.py @@ -174,6 +174,9 @@ def kv_cache_as_linen( key_axis_order: AxisIdxes = (2, 0, 1, 3), use_chunked_prefill: bool = False, model_mode: str = MODEL_MODE_PREFILL, + is_gdn: bool = False, + conv_kernel_size: int = 0, + conv_dim: int = 0, name: str | None = None, ): """Initializes the KVCache module and returns it as a Linen module. @@ -224,6 +227,9 @@ def kv_cache_as_linen( key_axis_order=key_axis_order, use_chunked_prefill=use_chunked_prefill, model_mode=model_mode, + is_gdn=is_gdn, + conv_kernel_size=conv_kernel_size, + conv_dim=conv_dim, metadata_fn=variable_to_logically_partitioned, name=name, abstract_init=False, @@ -265,6 +271,9 @@ def __init__( key_axis_order: AxisIdxes = (2, 0, 1, 3), use_chunked_prefill: bool = False, model_mode: str = MODEL_MODE_PREFILL, + is_gdn: bool = False, # <-- ADDED + conv_kernel_size: int = 0, # <-- ADDED + conv_dim: int = 0, *, # Not used in KVCache but passed in by nnx_wrappers.to_linen. # TODO: Remove when bridge no longer needed @@ -314,6 +323,9 @@ def __init__( self.key_axis_order = key_axis_order self.model_mode = model_mode self.use_chunked_prefill = use_chunked_prefill + self.is_gdn = is_gdn # <-- ADDED + self.conv_kernel_size = conv_kernel_size # <-- ADDED + self.conv_dim = conv_dim if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE): self._initialize_prefill_caches(model_mode) @@ -349,6 +361,14 @@ def _get_cache_scale_logical_shape(self, heads, cache_length): def _initialize_prefill_caches(self, model_mode): """Get a shaped abstraction of the state""" + if self.is_gdn: + self.cached_prefill_key = None + self.cached_prefill_value = None + self.cache_prefill_segment_id = None + self.cached_prefill_key_scale = None + self.cached_prefill_value_scale = None + return + cache_length = self.max_prefill_length dtype = self._get_cached_kv_dtype() @@ -411,6 +431,45 @@ def _initialize_ar_cache_vars(self, model_mode): """get ar cache vars""" dtype = self._get_cached_kv_dtype() + + # Pre-allocate fixed-size GDN states to standard cache containers + if self.is_gdn: + cache_batch_axis_name = CACHE_BATCH_PREFILL if model_mode == MODEL_MODE_PREFILL else CACHE_BATCH + + # 1. Map Recurrent State matrix directly to cached_ar_key + # Shape: [batch, key_heads, key_head_size, value_head_size] + self.cached_ar_key = nnx.Cache( + jnp.zeros((self.batch, self.key_heads, self.key_head_size, self.value_head_size), dtype=dtype), + out_sharding=(cache_batch_axis_name, CACHE_HEADS, None, None), + ) + + # 2. Map 1D Conv State directly to cached_ar_value + # Shape: [batch, conv_kernel_size - 1, conv_dim] + self.cached_ar_value = nnx.Cache( + jnp.zeros((self.batch, self.conv_kernel_size - 1, self.conv_dim), dtype=dtype), + out_sharding=(cache_batch_axis_name, None, None), + ) + + # Initialize required dummy variables to satisfy uniform engine inspection loops + segment_id_axis_names = ( + (CACHE_BATCH_PREFILL, CACHE_SEQUENCE) if model_mode == MODEL_MODE_PREFILL else (CACHE_BATCH, CACHE_SEQUENCE) + ) + self.cache_ar_segment_id = nnx.Cache( + jnp.zeros((self.batch, 1), dtype=jnp.int32), + out_sharding=segment_id_axis_names, + ) + self.cached_ar_lengths = nnx.Cache( + jnp.zeros((self.batch,), dtype=jnp.int32), + out_sharding=(CACHE_BATCH,), + ) + self.cached_ar_key_scale = None + self.cached_ar_value_scale = None + self.cache_ar_index = nnx.Cache( + jnp.zeros((1,), dtype=jnp.int32), + out_sharding=(), + ) + return + if self.max_target_length <= self.max_prefill_length: raise ValueError( f"max_target_length: {self.max_target_length} should be greater than max_prefill_length:" @@ -489,6 +548,22 @@ def _initialize_ar_cache_vars(self, model_mode): out_sharding=(), ) + def get_gdn_states(self) -> tuple[jax.Array, jax.Array]: + """Retrieves the recurrent state and conv state for GDN layers.""" + assert self.is_gdn, "get_gdn_states called on a non-GDN cache object." + return self.cached_ar_key.get_value(), self.cached_ar_value.get_value() + + def update_gdn_states(self, new_recurrent_state: Array, new_conv_state: Array) -> None: + """Updates the recurrent state and conv state for GDN layers.""" + assert self.is_gdn, "update_gdn_states called on a non-GDN cache object." + cache_batch_axis_name = CACHE_BATCH_PREFILL if self.model_mode == MODEL_MODE_PREFILL else CACHE_BATCH + + # Overwrite running states while enforcing designated logical partitioning rules + self.cached_ar_key.set_value( + nn.with_logical_constraint(new_recurrent_state, (cache_batch_axis_name, CACHE_HEADS, None, None)) + ) + self.cached_ar_value.set_value(nn.with_logical_constraint(new_conv_state, (cache_batch_axis_name, None, None))) + def _get_ar_cache_vars(self): return self.ar_key_vars, self.ar_value_vars, self.cache_ar_segment_id, self.cache_ar_index, self.cached_ar_lengths @@ -880,76 +955,6 @@ def __call__( raise ValueError(f"Model Mode isn't supported! {model_mode=}") -class GatedDeltaNetCache(BaseCache): - """Cache for Linear Attention (Gated Delta Net). - - Stores the fixed-size recurrent state and the sliding window state for convolution. - """ - - def __init__( - self, - batch: int, - num_heads: int, - k_head_dim: int, - v_head_dim: int, - conv_kernel_size: int, - conv_dim: int, - dtype: DType, - cache_batch_axis_name: str = CACHE_BATCH, - cache_heads_axis_name: str = CACHE_HEADS, - ): - super().__init__() - self.batch = batch - self.dtype = dtype - - # 1. Recurrent State (S) for the Delta Rule - # Shape: [Batch, Heads, K_Dim, V_Dim] - # We maintain the running state matrix. - self.recurrent_state = nnx.Cache( - jnp.zeros((int(batch), num_heads, k_head_dim, v_head_dim), dtype=dtype), - # Sharding: Batch, Heads, None (K), None (V) - out_sharding=(cache_batch_axis_name, cache_heads_axis_name, None, None), - ) - - # 2. Convolution State for the 1D Conv - # Shape: [Batch, Kernel_Size - 1, Conv_Dim] - # We store the last (K-1) inputs to perform the sliding window conv during decoding. - self.conv_state = nnx.Cache( - jnp.zeros((int(batch), conv_kernel_size - 1, conv_dim), dtype=dtype), - # Sharding: Batch, None (Time), None (Dim) - out_sharding=(cache_batch_axis_name, None, None), - ) - - def __call__(self): - """Returns the cache variables for the layer to use.""" - return self - - -def gated_delta_net_cache_as_linen( - *, - batch: int, - num_heads: int, - head_dim: int, - conv_kernel_size: int, - conv_dim: int, - dtype: DType, - name: str | None = None, -): - """Initializes the GatedDeltaNetCache and returns it as a Linen module.""" - return nnx_wrappers.to_linen( - GatedDeltaNetCache, - batch=batch, - num_heads=num_heads, - head_dim=head_dim, - conv_kernel_size=conv_kernel_size, - conv_dim=conv_dim, - dtype=dtype, - metadata_fn=variable_to_logically_partitioned, - name=name, - abstract_init=False, - ) - - def mla_kv_cache_as_linen( *, max_prefill_length: int, diff --git a/src/maxtext/layers/decoders.py b/src/maxtext/layers/decoders.py index a2d52dd033..584a47eab1 100644 --- a/src/maxtext/layers/decoders.py +++ b/src/maxtext/layers/decoders.py @@ -1107,12 +1107,8 @@ def __call__( if cfg.decoder_block in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5): layer_kwargs = {"layer_idx": lyr} kv_cache = None - if kv_caches is not None and cfg.decoder_block not in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5): + if kv_caches is not None: kv_cache = kv_caches[lyr] - elif kv_caches is not None and cfg.decoder_block in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5): - # For Qwen3Next & Qwen3.5, kv_caches is a dictionary of lists of caches. - if (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0: - kv_cache = (kv_caches["key_cache"][lyr], kv_caches["value_cache"][lyr]) if cfg.decoder_block == DecoderBlockType.GPT_OSS: layer_kwargs = {"attention_type": gpt_oss.get_attention_type(layer_id=lyr)} @@ -1135,11 +1131,7 @@ def __call__( **layer_call_kwargs, ) if kv_caches is not None and returned_cache is not None: - if cfg.decoder_block not in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5): - kv_caches[lyr] = returned_cache - elif (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0: - kv_caches["key_cache"][lyr] = returned_cache[0] - kv_caches["value_cache"][lyr] = returned_cache[1] + kv_caches[lyr] = returned_cache if deepstack_visual_embeds is not None and lyr < len(deepstack_visual_embeds): visual_embeds = deepstack_visual_embeds[lyr] diff --git a/src/maxtext/models/qwen3.py b/src/maxtext/models/qwen3.py index bd65f04438..46566f2809 100644 --- a/src/maxtext/models/qwen3.py +++ b/src/maxtext/models/qwen3.py @@ -402,15 +402,28 @@ def __init__(self, config: Config, dtype: DType = jnp.float32, model_mode: str = self.v_heads_per_k_head = self.num_v_heads // self.num_k_heads if model_mode != MODEL_MODE_TRAIN: - self.cache = kvcache.GatedDeltaNetCache( - batch=config.per_device_batch_size, - num_heads=self.num_v_heads, - k_head_dim=self.head_k_dim, - v_head_dim=self.head_v_dim, - conv_kernel_size=self.config.gdn_conv_kernel_dim, - conv_dim=conv_dim, + # Use global batch size so cache allocation matches SPMD tracing shapes + global_batch, _ = max_utils.get_batch_seq_len_for_mode(config, model_mode) + + self.cache = kvcache.KVCache( + max_prefill_length=cfg.max_prefill_predict_length, + max_target_length=cfg.max_target_length, + batch=global_batch, # Directly matches the global trace of 8 + key_seq_len=1, + value_seq_len=1, + key_heads=self.num_v_heads, + value_heads=self.num_v_heads, + key_head_size=self.head_k_dim, + value_head_size=self.head_v_dim, dtype=dtype, + is_gdn=True, + conv_kernel_size=conv_kernel_size, + conv_dim=conv_dim, + model_mode=model_mode, + rngs=rngs, ) + else: + self.cache = None # Submodule instantiations self.in_proj_qkvz = DenseGeneral( @@ -483,6 +496,8 @@ def __call__( cfg = self.config batch, seq_len, _ = hidden_states.shape + active_cache = kv_cache if kv_cache is not None else self.cache + # ========================================================================= # STEP A: Input Projections # ========================================================================= @@ -554,33 +569,43 @@ def __call__( conv_kernel_size = self.config.gdn_conv_kernel_dim conv_state = None - if model_mode != MODEL_MODE_TRAIN: - # Retrieve state from self.cache - conv_state = self.cache.conv_state[...] + recurrent_state = None + next_conv_state = None + if model_mode != MODEL_MODE_TRAIN and active_cache is not None: + recurrent_state, conv_state = active_cache.get_gdn_states() + + # Safely align conv_state batch dimension (Broadcast, Slice, or Pad) if conv_state.shape[0] != batch: - # Assumes zero-initialized state for testing if conv_state.shape[0] == 1: conv_state = jnp.broadcast_to(conv_state, (batch,) + conv_state.shape[1:]) + elif conv_state.shape[0] < batch: + pad_amt = batch - conv_state.shape[0] + conv_state = jnp.pad(conv_state, ((0, pad_amt), (0, 0), (0, 0))) else: conv_state = conv_state[:batch] - # Concatenate previous state with new input + # Safely align recurrent_state batch dimension + if recurrent_state.shape[0] != batch: + if recurrent_state.shape[0] == 1: + recurrent_state = jnp.broadcast_to(recurrent_state, (batch,) + recurrent_state.shape[1:]) + elif recurrent_state.shape[0] < batch: + pad_amt = batch - recurrent_state.shape[0] + recurrent_state = jnp.pad(recurrent_state, ((0, pad_amt), (0, 0), (0, 0), (0, 0))) + else: + recurrent_state = recurrent_state[:batch] + conv_input = jnp.concatenate([conv_state, qkv], axis=1) if decoder_segment_ids is not None: - valid_lens = jnp.sum(decoder_segment_ids != 0, axis=1) # Shape: (B,) + valid_lens = jnp.sum(decoder_segment_ids != 0, axis=1) def extract_state(c_in, v_len): return jax.lax.dynamic_slice_in_dim(c_in, v_len, conv_kernel_size - 1, axis=0) - new_conv_state = jax.vmap(extract_state)(conv_input, valid_lens) + next_conv_state = jax.vmap(extract_state)(conv_input, valid_lens) else: - new_conv_state = conv_input[:, -(conv_kernel_size - 1) :, :] - - # Update self.cache in place - self.cache.conv_state.set_value(new_conv_state) + next_conv_state = conv_input[:, -(conv_kernel_size - 1) :, :] else: - # Train: pad with zeros conv_input = jnp.pad(qkv, ((0, 0), (conv_kernel_size - 1, 0), (0, 0))) # Perform the convolution. @@ -592,7 +617,6 @@ def extract_state(c_in, v_len): q_conv, k_conv, v_conv = jnp.split(qkv_conv, [self.key_dim, 2 * self.key_dim], axis=-1) # Reshape for multi-head processing - batch, seq_len, _ = hidden_states.shape # query shape: (B, S, H_k, D_k) query = q_conv.reshape(batch, seq_len, self.num_k_heads, self.head_k_dim) # key shape: (B, S, H_k, D_k) @@ -623,21 +647,8 @@ def extract_state(c_in, v_len): query = jnp.repeat(query, repeats, axis=2) # key shape after repeat: (B, S, H_v, D_k) key = jnp.repeat(key, repeats, axis=2) - elif self.num_k_heads > self.num_v_heads and self.num_k_heads % self.num_v_heads == 0: - pass - recurrent_state = None - if model_mode != MODEL_MODE_TRAIN: - # Retrieve state from self.cache - recurrent_state = self.cache.recurrent_state[...] - - if recurrent_state.shape[0] != batch: - if recurrent_state.shape[0] == 1: - recurrent_state = jnp.broadcast_to(recurrent_state, (batch,) + recurrent_state.shape[1:]) - else: - recurrent_state = recurrent_state[:batch] - - core_attn_out, recurrent_state_out = jax_chunk_gated_delta_rule( + core_attn_out, next_recurrent_state = jax_chunk_gated_delta_rule( query, key, value, @@ -649,9 +660,8 @@ def extract_state(c_in, v_len): compute_dtype=cfg.dtype, ) - if model_mode != MODEL_MODE_TRAIN: - # Update self.cache in place for both prefill and decode - self.cache.recurrent_state.set_value(recurrent_state_out) + if model_mode != MODEL_MODE_TRAIN and active_cache is not None: + active_cache.update_gdn_states(next_recurrent_state, next_conv_state) # ========================================================================= # STEP D: Final Output Stage @@ -669,7 +679,31 @@ def extract_state(c_in, v_len): # Final output shape: (B, S, E) output = self.out_proj(gated_output) - return output + return output, active_cache + + def init_kv_caches(self, batch_size: int): + """Initializes KVCache dynamically using the traced runtime batch size.""" + cfg = self.config + conv_dim = self.key_dim * 2 + self.value_dim + conv_kernel_size = cfg.gdn_conv_kernel_dim + + return kvcache.KVCache( + max_prefill_length=cfg.max_prefill_predict_length, + max_target_length=cfg.max_target_length, + batch=batch_size, # Injected directly from the JAX tracer + key_seq_len=1, + value_seq_len=1, + key_heads=self.num_v_heads, + value_heads=self.num_v_heads, + key_head_size=self.head_k_dim, + value_head_size=self.head_v_dim, + dtype=self.dtype, + is_gdn=True, + conv_kernel_size=conv_kernel_size, + conv_dim=conv_dim, + model_mode=self.model_mode, + rngs=self.rngs, + ) class Qwen3NextFullAttention(nnx.Module): @@ -1034,13 +1068,12 @@ def __call__( attention_metadata=attention_metadata, ) else: - attention_output = cast(Qwen3NextGatedDeltaNet, self.attention)( + attention_output, new_kv_cache = cast(Qwen3NextGatedDeltaNet, self.attention)( hidden_states, model_mode=model_mode, - kv_cache=None, + kv_cache=kv_cache, decoder_segment_ids=decoder_segment_ids, ) - new_kv_cache = None # First residual connection after attention hidden_states = residual + attention_output diff --git a/src/maxtext/models/qwen3_5.py b/src/maxtext/models/qwen3_5.py index b25ecf09e8..b0de507860 100644 --- a/src/maxtext/models/qwen3_5.py +++ b/src/maxtext/models/qwen3_5.py @@ -207,13 +207,12 @@ def __call__( attention_metadata=attention_metadata, ) else: - attention_output = cast(Qwen3_5GatedDeltaNet, self.attention)( + attention_output, new_kv_cache = cast(Qwen3_5GatedDeltaNet, self.attention)( hidden_states, model_mode=model_mode, - kv_cache=None, + kv_cache=kv_cache, decoder_segment_ids=decoder_segment_ids, ) - new_kv_cache = None # First residual connection after attention hidden_states = residual + attention_output diff --git a/src/maxtext/utils/globals.py b/src/maxtext/utils/globals.py index db41a116f2..9b92eb642f 100644 --- a/src/maxtext/utils/globals.py +++ b/src/maxtext/utils/globals.py @@ -77,6 +77,8 @@ "gpt-oss-120b": "openai/gpt-oss-120b", "qwen3-omni-30b-a3b": "Qwen/Qwen3-Omni-30B-A3B-Instruct", "qwen3-next-80b-a3b": "Qwen/Qwen3-Next-80B-A3B-Instruct", + "qwen3.5-397b-a17b": "Qwen/Qwen3.5-397B-A17B", + "qwen3.5-35b-a3b": "Qwen/Qwen3.5-35B-A3B", "mixtral-8x7b": "mistralai/Mixtral-8x7B-Instruct-v0.1", "mixtral-8x22b": "mistralai/Mixtral-8x22B-Instruct-v0.1", "olmo3-7b": "allenai/Olmo-3-7B-Instruct",