From 573398bf25d7bd80133932e3d75ef01948eb33c6 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Wed, 13 May 2026 20:43:43 +0000 Subject: [PATCH] Try new caching strategy for gdn layer Add mini config model support for q3.5 Wrong config name updated Remove special casing for caching since using existing kvcache class return kvcache instead of active_cache Use kvcache class and remove extra logic in decoders.py Add logic for proper batching of gdn caches Update for nnx issue when batch size > 1 Remove GDN specific cache Fixed linter issues Run linter on qwen3.py --- .../configs/models/qwen3.5-35b-a3b.yml | 51 ++++++ src/maxtext/configs/types.py | 1 + src/maxtext/inference/kvcache.py | 145 +++++++++--------- src/maxtext/layers/decoders.py | 12 +- src/maxtext/models/qwen3.py | 115 +++++++++----- src/maxtext/models/qwen3_5.py | 5 +- src/maxtext/utils/globals.py | 2 + 7 files changed, 207 insertions(+), 124 deletions(-) create mode 100644 src/maxtext/configs/models/qwen3.5-35b-a3b.yml diff --git a/src/maxtext/configs/models/qwen3.5-35b-a3b.yml b/src/maxtext/configs/models/qwen3.5-35b-a3b.yml new file mode 100644 index 0000000000..44461b8040 --- /dev/null +++ b/src/maxtext/configs/models/qwen3.5-35b-a3b.yml @@ -0,0 +1,51 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# maxtext/configs/models/qwen3.5-35b-a3b.yml + +decoder_block: "qwen3_5" + +# Core Architectural Parameters +base_emb_dim: 2048 +base_num_decoder_layers: 40 +base_num_query_heads: 16 +base_num_kv_heads: 2 +head_dim: 256 +vocab_size: 248320 +normalization_layer_epsilon: 1.0e-6 + +# MoE Specific Parameters +# Set base_mlp_dim to match base_moe_mlp_dim to pass validation for fully MoE models. +base_mlp_dim: 512 +base_moe_mlp_dim: 512 +num_experts: 256 +shared_experts: 1 +num_experts_per_tok: 8 +norm_topk_prob: True + +# GatedDeltaNet Specific Parameters for Linear Attention (GDN) +inhomogeneous_layer_cycle_interval: 4 +gdn_conv_kernel_dim: 4 +gdn_key_head_dim: 128 +gdn_value_head_dim: 128 +gdn_num_key_heads: 16 +gdn_num_value_heads: 32 +gdn_chunk_size: 64 + +# RoPE Settings +rope_max_timescale: 10000000 +partial_rotary_factor: 0.25 + +# General Model Settings +enable_dropout: False \ No newline at end of file diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 20594bccc3..a2f8af4f8b 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -260,6 +260,7 @@ class ProfilerType(str, Enum): "qwen3-omni-30b-a3b", "qwen3-custom-30b-a3b", "qwen3.5-397b-a17b", + "qwen3.5-35b-a3b", "gpt3-175b", "gpt3-22b", "gpt3-6b", diff --git a/src/maxtext/inference/kvcache.py b/src/maxtext/inference/kvcache.py index 2162a9d9eb..515ffa7286 100644 --- a/src/maxtext/inference/kvcache.py +++ b/src/maxtext/inference/kvcache.py @@ -174,6 +174,9 @@ def kv_cache_as_linen( key_axis_order: AxisIdxes = (2, 0, 1, 3), use_chunked_prefill: bool = False, model_mode: str = MODEL_MODE_PREFILL, + is_gdn: bool = False, + conv_kernel_size: int = 0, + conv_dim: int = 0, name: str | None = None, ): """Initializes the KVCache module and returns it as a Linen module. 
@@ -224,6 +227,9 @@ def kv_cache_as_linen( key_axis_order=key_axis_order, use_chunked_prefill=use_chunked_prefill, model_mode=model_mode, + is_gdn=is_gdn, + conv_kernel_size=conv_kernel_size, + conv_dim=conv_dim, metadata_fn=variable_to_logically_partitioned, name=name, abstract_init=False, @@ -265,6 +271,9 @@ def __init__( key_axis_order: AxisIdxes = (2, 0, 1, 3), use_chunked_prefill: bool = False, model_mode: str = MODEL_MODE_PREFILL, + is_gdn: bool = False, # GDN layer: cache holds recurrent + conv state + conv_kernel_size: int = 0, # GDN 1D-conv kernel width + conv_dim: int = 0, *, # Not used in KVCache but passed in by nnx_wrappers.to_linen. # TODO: Remove when bridge no longer needed @@ -314,6 +323,9 @@ def __init__( self.key_axis_order = key_axis_order self.model_mode = model_mode self.use_chunked_prefill = use_chunked_prefill + self.is_gdn = is_gdn + self.conv_kernel_size = conv_kernel_size + self.conv_dim = conv_dim if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE): self._initialize_prefill_caches(model_mode) @@ -349,6 +361,14 @@ def _get_cache_scale_logical_shape(self, heads, cache_length): def _initialize_prefill_caches(self, model_mode): """Get a shaped abstraction of the state""" + if self.is_gdn: + self.cached_prefill_key = None + self.cached_prefill_value = None + self.cache_prefill_segment_id = None + self.cached_prefill_key_scale = None + self.cached_prefill_value_scale = None + return + cache_length = self.max_prefill_length dtype = self._get_cached_kv_dtype() @@ -411,6 +431,45 @@ def _initialize_ar_cache_vars(self, model_mode): """get ar cache vars""" dtype = self._get_cached_kv_dtype() + + # Pre-allocate fixed-size GDN states to standard cache containers + if self.is_gdn: + cache_batch_axis_name = CACHE_BATCH_PREFILL if model_mode == MODEL_MODE_PREFILL else CACHE_BATCH + + # 1. 
Map Recurrent State matrix directly to cached_ar_key + # Shape: [batch, key_heads, key_head_size, value_head_size] + self.cached_ar_key = nnx.Cache( + jnp.zeros((self.batch, self.key_heads, self.key_head_size, self.value_head_size), dtype=dtype), + out_sharding=(cache_batch_axis_name, CACHE_HEADS, None, None), + ) + + # 2. Map 1D Conv State directly to cached_ar_value + # Shape: [batch, conv_kernel_size - 1, conv_dim] + self.cached_ar_value = nnx.Cache( + jnp.zeros((self.batch, self.conv_kernel_size - 1, self.conv_dim), dtype=dtype), + out_sharding=(cache_batch_axis_name, None, None), + ) + + # Initialize required dummy variables to satisfy uniform engine inspection loops + segment_id_axis_names = ( + (CACHE_BATCH_PREFILL, CACHE_SEQUENCE) if model_mode == MODEL_MODE_PREFILL else (CACHE_BATCH, CACHE_SEQUENCE) + ) + self.cache_ar_segment_id = nnx.Cache( + jnp.zeros((self.batch, 1), dtype=jnp.int32), + out_sharding=segment_id_axis_names, + ) + self.cached_ar_lengths = nnx.Cache( + jnp.zeros((self.batch,), dtype=jnp.int32), + out_sharding=(CACHE_BATCH,), + ) + self.cached_ar_key_scale = None + self.cached_ar_value_scale = None + self.cache_ar_index = nnx.Cache( + jnp.zeros((1,), dtype=jnp.int32), + out_sharding=(), + ) + return + if self.max_target_length <= self.max_prefill_length: raise ValueError( f"max_target_length: {self.max_target_length} should be greater than max_prefill_length:" @@ -489,6 +548,22 @@ def _initialize_ar_cache_vars(self, model_mode): out_sharding=(), ) + def get_gdn_states(self) -> tuple[jax.Array, jax.Array]: + """Retrieves the recurrent state and conv state for GDN layers.""" + assert self.is_gdn, "get_gdn_states called on a non-GDN cache object." 
+ return self.cached_ar_key.get_value(), self.cached_ar_value.get_value() + + def update_gdn_states(self, new_recurrent_state: Array, new_conv_state: Array) -> None: + """Updates the recurrent state and conv state for GDN layers.""" + assert self.is_gdn, "update_gdn_states called on a non-GDN cache object." + cache_batch_axis_name = CACHE_BATCH_PREFILL if self.model_mode == MODEL_MODE_PREFILL else CACHE_BATCH + + # Overwrite running states while enforcing designated logical partitioning rules + self.cached_ar_key.set_value( + nn.with_logical_constraint(new_recurrent_state, (cache_batch_axis_name, CACHE_HEADS, None, None)) + ) + self.cached_ar_value.set_value(nn.with_logical_constraint(new_conv_state, (cache_batch_axis_name, None, None))) + def _get_ar_cache_vars(self): return self.ar_key_vars, self.ar_value_vars, self.cache_ar_segment_id, self.cache_ar_index, self.cached_ar_lengths @@ -880,76 +955,6 @@ def __call__( raise ValueError(f"Model Mode isn't supported! {model_mode=}") -class GatedDeltaNetCache(BaseCache): - """Cache for Linear Attention (Gated Delta Net). - - Stores the fixed-size recurrent state and the sliding window state for convolution. - """ - - def __init__( - self, - batch: int, - num_heads: int, - k_head_dim: int, - v_head_dim: int, - conv_kernel_size: int, - conv_dim: int, - dtype: DType, - cache_batch_axis_name: str = CACHE_BATCH, - cache_heads_axis_name: str = CACHE_HEADS, - ): - super().__init__() - self.batch = batch - self.dtype = dtype - - # 1. Recurrent State (S) for the Delta Rule - # Shape: [Batch, Heads, K_Dim, V_Dim] - # We maintain the running state matrix. - self.recurrent_state = nnx.Cache( - jnp.zeros((int(batch), num_heads, k_head_dim, v_head_dim), dtype=dtype), - # Sharding: Batch, Heads, None (K), None (V) - out_sharding=(cache_batch_axis_name, cache_heads_axis_name, None, None), - ) - - # 2. 
Convolution State for the 1D Conv - # Shape: [Batch, Kernel_Size - 1, Conv_Dim] - # We store the last (K-1) inputs to perform the sliding window conv during decoding. - self.conv_state = nnx.Cache( - jnp.zeros((int(batch), conv_kernel_size - 1, conv_dim), dtype=dtype), - # Sharding: Batch, None (Time), None (Dim) - out_sharding=(cache_batch_axis_name, None, None), - ) - - def __call__(self): - """Returns the cache variables for the layer to use.""" - return self - - -def gated_delta_net_cache_as_linen( - *, - batch: int, - num_heads: int, - head_dim: int, - conv_kernel_size: int, - conv_dim: int, - dtype: DType, - name: str | None = None, -): - """Initializes the GatedDeltaNetCache and returns it as a Linen module.""" - return nnx_wrappers.to_linen( - GatedDeltaNetCache, - batch=batch, - num_heads=num_heads, - head_dim=head_dim, - conv_kernel_size=conv_kernel_size, - conv_dim=conv_dim, - dtype=dtype, - metadata_fn=variable_to_logically_partitioned, - name=name, - abstract_init=False, - ) - - def mla_kv_cache_as_linen( *, max_prefill_length: int, diff --git a/src/maxtext/layers/decoders.py b/src/maxtext/layers/decoders.py index a2d52dd033..584a47eab1 100644 --- a/src/maxtext/layers/decoders.py +++ b/src/maxtext/layers/decoders.py @@ -1107,12 +1107,8 @@ def __call__( if cfg.decoder_block in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5): layer_kwargs = {"layer_idx": lyr} kv_cache = None - if kv_caches is not None and cfg.decoder_block not in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5): + if kv_caches is not None: kv_cache = kv_caches[lyr] - elif kv_caches is not None and cfg.decoder_block in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5): - # For Qwen3Next & Qwen3.5, kv_caches is a dictionary of lists of caches. 
- if (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0: - kv_cache = (kv_caches["key_cache"][lyr], kv_caches["value_cache"][lyr]) if cfg.decoder_block == DecoderBlockType.GPT_OSS: layer_kwargs = {"attention_type": gpt_oss.get_attention_type(layer_id=lyr)} @@ -1135,11 +1131,7 @@ def __call__( **layer_call_kwargs, ) if kv_caches is not None and returned_cache is not None: - if cfg.decoder_block not in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5): - kv_caches[lyr] = returned_cache - elif (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0: - kv_caches["key_cache"][lyr] = returned_cache[0] - kv_caches["value_cache"][lyr] = returned_cache[1] + kv_caches[lyr] = returned_cache if deepstack_visual_embeds is not None and lyr < len(deepstack_visual_embeds): visual_embeds = deepstack_visual_embeds[lyr] diff --git a/src/maxtext/models/qwen3.py b/src/maxtext/models/qwen3.py index bd65f04438..46566f2809 100644 --- a/src/maxtext/models/qwen3.py +++ b/src/maxtext/models/qwen3.py @@ -402,15 +402,28 @@ def __init__(self, config: Config, dtype: DType = jnp.float32, model_mode: str = self.v_heads_per_k_head = self.num_v_heads // self.num_k_heads if model_mode != MODEL_MODE_TRAIN: - self.cache = kvcache.GatedDeltaNetCache( - batch=config.per_device_batch_size, - num_heads=self.num_v_heads, - k_head_dim=self.head_k_dim, - v_head_dim=self.head_v_dim, - conv_kernel_size=self.config.gdn_conv_kernel_dim, - conv_dim=conv_dim, + # Use global batch size so cache allocation matches SPMD tracing shapes + global_batch, _ = max_utils.get_batch_seq_len_for_mode(config, model_mode) + + self.cache = kvcache.KVCache( + max_prefill_length=cfg.max_prefill_predict_length, + max_target_length=cfg.max_target_length, + batch=global_batch, # Global batch, so the cache shape matches the traced input batch + key_seq_len=1, + value_seq_len=1, + key_heads=self.num_v_heads, + value_heads=self.num_v_heads, + key_head_size=self.head_k_dim, + value_head_size=self.head_v_dim, dtype=dtype, + is_gdn=True, + 
conv_kernel_size=conv_kernel_size, + conv_dim=conv_dim, + model_mode=model_mode, + rngs=rngs, ) + else: + self.cache = None # Submodule instantiations self.in_proj_qkvz = DenseGeneral( @@ -483,6 +496,8 @@ def __call__( cfg = self.config batch, seq_len, _ = hidden_states.shape + active_cache = kv_cache if kv_cache is not None else self.cache + # ========================================================================= # STEP A: Input Projections # ========================================================================= @@ -554,33 +569,43 @@ def __call__( conv_kernel_size = self.config.gdn_conv_kernel_dim conv_state = None - if model_mode != MODEL_MODE_TRAIN: - # Retrieve state from self.cache - conv_state = self.cache.conv_state[...] + recurrent_state = None + next_conv_state = None + if model_mode != MODEL_MODE_TRAIN and active_cache is not None: + recurrent_state, conv_state = active_cache.get_gdn_states() + + # Safely align conv_state batch dimension (Broadcast, Slice, or Pad) if conv_state.shape[0] != batch: - # Assumes zero-initialized state for testing if conv_state.shape[0] == 1: conv_state = jnp.broadcast_to(conv_state, (batch,) + conv_state.shape[1:]) + elif conv_state.shape[0] < batch: + pad_amt = batch - conv_state.shape[0] + conv_state = jnp.pad(conv_state, ((0, pad_amt), (0, 0), (0, 0))) else: conv_state = conv_state[:batch] - # Concatenate previous state with new input + # Safely align recurrent_state batch dimension + if recurrent_state.shape[0] != batch: + if recurrent_state.shape[0] == 1: + recurrent_state = jnp.broadcast_to(recurrent_state, (batch,) + recurrent_state.shape[1:]) + elif recurrent_state.shape[0] < batch: + pad_amt = batch - recurrent_state.shape[0] + recurrent_state = jnp.pad(recurrent_state, ((0, pad_amt), (0, 0), (0, 0), (0, 0))) + else: + recurrent_state = recurrent_state[:batch] + conv_input = jnp.concatenate([conv_state, qkv], axis=1) if decoder_segment_ids is not None: - valid_lens = jnp.sum(decoder_segment_ids != 0, axis=1) # 
Shape: (B,) + valid_lens = jnp.sum(decoder_segment_ids != 0, axis=1) def extract_state(c_in, v_len): return jax.lax.dynamic_slice_in_dim(c_in, v_len, conv_kernel_size - 1, axis=0) - new_conv_state = jax.vmap(extract_state)(conv_input, valid_lens) + next_conv_state = jax.vmap(extract_state)(conv_input, valid_lens) else: - new_conv_state = conv_input[:, -(conv_kernel_size - 1) :, :] - - # Update self.cache in place - self.cache.conv_state.set_value(new_conv_state) + next_conv_state = conv_input[:, -(conv_kernel_size - 1) :, :] else: - # Train: pad with zeros conv_input = jnp.pad(qkv, ((0, 0), (conv_kernel_size - 1, 0), (0, 0))) # Perform the convolution. @@ -592,7 +617,6 @@ def extract_state(c_in, v_len): q_conv, k_conv, v_conv = jnp.split(qkv_conv, [self.key_dim, 2 * self.key_dim], axis=-1) # Reshape for multi-head processing - batch, seq_len, _ = hidden_states.shape # query shape: (B, S, H_k, D_k) query = q_conv.reshape(batch, seq_len, self.num_k_heads, self.head_k_dim) # key shape: (B, S, H_k, D_k) @@ -623,21 +647,8 @@ def extract_state(c_in, v_len): query = jnp.repeat(query, repeats, axis=2) # key shape after repeat: (B, S, H_v, D_k) key = jnp.repeat(key, repeats, axis=2) - elif self.num_k_heads > self.num_v_heads and self.num_k_heads % self.num_v_heads == 0: - pass - recurrent_state = None - if model_mode != MODEL_MODE_TRAIN: - # Retrieve state from self.cache - recurrent_state = self.cache.recurrent_state[...] 
- - if recurrent_state.shape[0] != batch: - if recurrent_state.shape[0] == 1: - recurrent_state = jnp.broadcast_to(recurrent_state, (batch,) + recurrent_state.shape[1:]) - else: - recurrent_state = recurrent_state[:batch] - - core_attn_out, recurrent_state_out = jax_chunk_gated_delta_rule( + core_attn_out, next_recurrent_state = jax_chunk_gated_delta_rule( query, key, value, @@ -649,9 +660,8 @@ def extract_state(c_in, v_len): compute_dtype=cfg.dtype, ) - if model_mode != MODEL_MODE_TRAIN: - # Update self.cache in place for both prefill and decode - self.cache.recurrent_state.set_value(recurrent_state_out) + if model_mode != MODEL_MODE_TRAIN and active_cache is not None: + active_cache.update_gdn_states(next_recurrent_state, next_conv_state) # ========================================================================= # STEP D: Final Output Stage @@ -669,7 +679,31 @@ def extract_state(c_in, v_len): # Final output shape: (B, S, E) output = self.out_proj(gated_output) - return output + return output, active_cache + + def init_kv_caches(self, batch_size: int): + """Initializes KVCache dynamically using the traced runtime batch size.""" + cfg = self.config + conv_dim = self.key_dim * 2 + self.value_dim + conv_kernel_size = cfg.gdn_conv_kernel_dim + + return kvcache.KVCache( + max_prefill_length=cfg.max_prefill_predict_length, + max_target_length=cfg.max_target_length, + batch=batch_size, # Injected directly from the JAX tracer + key_seq_len=1, + value_seq_len=1, + key_heads=self.num_v_heads, + value_heads=self.num_v_heads, + key_head_size=self.head_k_dim, + value_head_size=self.head_v_dim, + dtype=self.dtype, + is_gdn=True, + conv_kernel_size=conv_kernel_size, + conv_dim=conv_dim, + model_mode=self.model_mode, + rngs=self.rngs, + ) class Qwen3NextFullAttention(nnx.Module): @@ -1034,13 +1068,12 @@ def __call__( attention_metadata=attention_metadata, ) else: - attention_output = cast(Qwen3NextGatedDeltaNet, self.attention)( + attention_output, new_kv_cache = 
cast(Qwen3NextGatedDeltaNet, self.attention)( hidden_states, model_mode=model_mode, - kv_cache=None, + kv_cache=kv_cache, decoder_segment_ids=decoder_segment_ids, ) - new_kv_cache = None # First residual connection after attention hidden_states = residual + attention_output diff --git a/src/maxtext/models/qwen3_5.py b/src/maxtext/models/qwen3_5.py index b25ecf09e8..b0de507860 100644 --- a/src/maxtext/models/qwen3_5.py +++ b/src/maxtext/models/qwen3_5.py @@ -207,13 +207,12 @@ def __call__( attention_metadata=attention_metadata, ) else: - attention_output = cast(Qwen3_5GatedDeltaNet, self.attention)( + attention_output, new_kv_cache = cast(Qwen3_5GatedDeltaNet, self.attention)( hidden_states, model_mode=model_mode, - kv_cache=None, + kv_cache=kv_cache, decoder_segment_ids=decoder_segment_ids, ) - new_kv_cache = None # First residual connection after attention hidden_states = residual + attention_output diff --git a/src/maxtext/utils/globals.py b/src/maxtext/utils/globals.py index db41a116f2..9b92eb642f 100644 --- a/src/maxtext/utils/globals.py +++ b/src/maxtext/utils/globals.py @@ -77,6 +77,8 @@ "gpt-oss-120b": "openai/gpt-oss-120b", "qwen3-omni-30b-a3b": "Qwen/Qwen3-Omni-30B-A3B-Instruct", "qwen3-next-80b-a3b": "Qwen/Qwen3-Next-80B-A3B-Instruct", + "qwen3.5-397b-a17b": "Qwen/Qwen3.5-397B-A17B", + "qwen3.5-35b-a3b": "Qwen/Qwen3.5-35B-A3B", "mixtral-8x7b": "mistralai/Mixtral-8x7B-Instruct-v0.1", "mixtral-8x22b": "mistralai/Mixtral-8x22B-Instruct-v0.1", "olmo3-7b": "allenai/Olmo-3-7B-Instruct",