Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions src/maxtext/configs/models/qwen3.5-35b-a3b.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# maxtext/configs/models/qwen3.5-35b-moe.yml

decoder_block: "qwen3_5"

# Core Architectural Parameters
base_emb_dim: 2048
base_num_decoder_layers: 40
base_num_query_heads: 16
base_num_kv_heads: 2
head_dim: 256
vocab_size: 248320
normalization_layer_epsilon: 1.0e-6

# MoE Specific Parameters
# Set base_mlp_dim to match base_moe_mlp_dim to pass validation for fully MoE models.
base_mlp_dim: 512
base_moe_mlp_dim: 512
num_experts: 256
shared_experts: 1
num_experts_per_tok: 8
norm_topk_prob: True

# GatedDeltaNet Specific Parameters for Linear Attention (GDN)
inhomogeneous_layer_cycle_interval: 4
gdn_conv_kernel_dim: 4
gdn_key_head_dim: 128
gdn_value_head_dim: 128
gdn_num_key_heads: 16
gdn_num_value_heads: 32
gdn_chunk_size: 64

# RoPE Settings
rope_max_timescale: 10000000
partial_rotary_factor: 0.25

# General Model Settings
enable_dropout: False
1 change: 1 addition & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ class ProfilerType(str, Enum):
"qwen3-omni-30b-a3b",
"qwen3-custom-30b-a3b",
"qwen3.5-397b-a17b",
"qwen3.5-35b-a3b",
"gpt3-175b",
"gpt3-22b",
"gpt3-6b",
Expand Down
145 changes: 75 additions & 70 deletions src/maxtext/inference/kvcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ def kv_cache_as_linen(
key_axis_order: AxisIdxes = (2, 0, 1, 3),
use_chunked_prefill: bool = False,
model_mode: str = MODEL_MODE_PREFILL,
is_gdn: bool = False,
conv_kernel_size: int = 0,
conv_dim: int = 0,
name: str | None = None,
):
"""Initializes the KVCache module and returns it as a Linen module.
Expand Down Expand Up @@ -224,6 +227,9 @@ def kv_cache_as_linen(
key_axis_order=key_axis_order,
use_chunked_prefill=use_chunked_prefill,
model_mode=model_mode,
is_gdn=is_gdn,
conv_kernel_size=conv_kernel_size,
conv_dim=conv_dim,
metadata_fn=variable_to_logically_partitioned,
name=name,
abstract_init=False,
Expand Down Expand Up @@ -265,6 +271,9 @@ def __init__(
key_axis_order: AxisIdxes = (2, 0, 1, 3),
use_chunked_prefill: bool = False,
model_mode: str = MODEL_MODE_PREFILL,
is_gdn: bool = False, # <-- ADDED
conv_kernel_size: int = 0, # <-- ADDED
conv_dim: int = 0,
*,
# Not used in KVCache but passed in by nnx_wrappers.to_linen.
# TODO: Remove when bridge no longer needed
Expand Down Expand Up @@ -314,6 +323,9 @@ def __init__(
self.key_axis_order = key_axis_order
self.model_mode = model_mode
self.use_chunked_prefill = use_chunked_prefill
self.is_gdn = is_gdn # <-- ADDED
self.conv_kernel_size = conv_kernel_size # <-- ADDED
self.conv_dim = conv_dim

if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE):
self._initialize_prefill_caches(model_mode)
Expand Down Expand Up @@ -349,6 +361,14 @@ def _get_cache_scale_logical_shape(self, heads, cache_length):
def _initialize_prefill_caches(self, model_mode):
"""Get a shaped abstraction of the state"""

if self.is_gdn:
self.cached_prefill_key = None
self.cached_prefill_value = None
self.cache_prefill_segment_id = None
self.cached_prefill_key_scale = None
self.cached_prefill_value_scale = None
return

cache_length = self.max_prefill_length
dtype = self._get_cached_kv_dtype()

Expand Down Expand Up @@ -411,6 +431,45 @@ def _initialize_ar_cache_vars(self, model_mode):
"""get ar cache vars"""

dtype = self._get_cached_kv_dtype()

# Pre-allocate fixed-size GDN states to standard cache containers
if self.is_gdn:
cache_batch_axis_name = CACHE_BATCH_PREFILL if model_mode == MODEL_MODE_PREFILL else CACHE_BATCH

# 1. Map Recurrent State matrix directly to cached_ar_key
# Shape: [batch, key_heads, key_head_size, value_head_size]
self.cached_ar_key = nnx.Cache(
jnp.zeros((self.batch, self.key_heads, self.key_head_size, self.value_head_size), dtype=dtype),
out_sharding=(cache_batch_axis_name, CACHE_HEADS, None, None),
)

# 2. Map 1D Conv State directly to cached_ar_value
# Shape: [batch, conv_kernel_size - 1, conv_dim]
self.cached_ar_value = nnx.Cache(
jnp.zeros((self.batch, self.conv_kernel_size - 1, self.conv_dim), dtype=dtype),
out_sharding=(cache_batch_axis_name, None, None),
)

# Initialize required dummy variables to satisfy uniform engine inspection loops
segment_id_axis_names = (
(CACHE_BATCH_PREFILL, CACHE_SEQUENCE) if model_mode == MODEL_MODE_PREFILL else (CACHE_BATCH, CACHE_SEQUENCE)
)
self.cache_ar_segment_id = nnx.Cache(
jnp.zeros((self.batch, 1), dtype=jnp.int32),
out_sharding=segment_id_axis_names,
)
self.cached_ar_lengths = nnx.Cache(
jnp.zeros((self.batch,), dtype=jnp.int32),
out_sharding=(CACHE_BATCH,),
)
self.cached_ar_key_scale = None
self.cached_ar_value_scale = None
self.cache_ar_index = nnx.Cache(
jnp.zeros((1,), dtype=jnp.int32),
out_sharding=(),
)
return

if self.max_target_length <= self.max_prefill_length:
raise ValueError(
f"max_target_length: {self.max_target_length} should be greater than max_prefill_length:"
Expand Down Expand Up @@ -489,6 +548,22 @@ def _initialize_ar_cache_vars(self, model_mode):
out_sharding=(),
)

def get_gdn_states(self) -> tuple[jax.Array, jax.Array]:
"""Retrieves the recurrent state and conv state for GDN layers."""
assert self.is_gdn, "get_gdn_states called on a non-GDN cache object."
return self.cached_ar_key.get_value(), self.cached_ar_value.get_value()

def update_gdn_states(self, new_recurrent_state: Array, new_conv_state: Array) -> None:
"""Updates the recurrent state and conv state for GDN layers."""
assert self.is_gdn, "update_gdn_states called on a non-GDN cache object."
cache_batch_axis_name = CACHE_BATCH_PREFILL if self.model_mode == MODEL_MODE_PREFILL else CACHE_BATCH

# Overwrite running states while enforcing designated logical partitioning rules
self.cached_ar_key.set_value(
nn.with_logical_constraint(new_recurrent_state, (cache_batch_axis_name, CACHE_HEADS, None, None))
)
self.cached_ar_value.set_value(nn.with_logical_constraint(new_conv_state, (cache_batch_axis_name, None, None)))

def _get_ar_cache_vars(self):
return self.ar_key_vars, self.ar_value_vars, self.cache_ar_segment_id, self.cache_ar_index, self.cached_ar_lengths

Expand Down Expand Up @@ -880,76 +955,6 @@ def __call__(
raise ValueError(f"Model Mode isn't supported! {model_mode=}")


class GatedDeltaNetCache(BaseCache):
"""Cache for Linear Attention (Gated Delta Net).

Stores the fixed-size recurrent state and the sliding window state for convolution.
"""

def __init__(
self,
batch: int,
num_heads: int,
k_head_dim: int,
v_head_dim: int,
conv_kernel_size: int,
conv_dim: int,
dtype: DType,
cache_batch_axis_name: str = CACHE_BATCH,
cache_heads_axis_name: str = CACHE_HEADS,
):
super().__init__()
self.batch = batch
self.dtype = dtype

# 1. Recurrent State (S) for the Delta Rule
# Shape: [Batch, Heads, K_Dim, V_Dim]
# We maintain the running state matrix.
self.recurrent_state = nnx.Cache(
jnp.zeros((int(batch), num_heads, k_head_dim, v_head_dim), dtype=dtype),
# Sharding: Batch, Heads, None (K), None (V)
out_sharding=(cache_batch_axis_name, cache_heads_axis_name, None, None),
)

# 2. Convolution State for the 1D Conv
# Shape: [Batch, Kernel_Size - 1, Conv_Dim]
# We store the last (K-1) inputs to perform the sliding window conv during decoding.
self.conv_state = nnx.Cache(
jnp.zeros((int(batch), conv_kernel_size - 1, conv_dim), dtype=dtype),
# Sharding: Batch, None (Time), None (Dim)
out_sharding=(cache_batch_axis_name, None, None),
)

def __call__(self):
"""Returns the cache variables for the layer to use."""
return self


def gated_delta_net_cache_as_linen(
*,
batch: int,
num_heads: int,
head_dim: int,
conv_kernel_size: int,
conv_dim: int,
dtype: DType,
name: str | None = None,
):
"""Initializes the GatedDeltaNetCache and returns it as a Linen module."""
return nnx_wrappers.to_linen(
GatedDeltaNetCache,
batch=batch,
num_heads=num_heads,
head_dim=head_dim,
conv_kernel_size=conv_kernel_size,
conv_dim=conv_dim,
dtype=dtype,
metadata_fn=variable_to_logically_partitioned,
name=name,
abstract_init=False,
)


def mla_kv_cache_as_linen(
*,
max_prefill_length: int,
Expand Down
12 changes: 2 additions & 10 deletions src/maxtext/layers/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,12 +1107,8 @@ def __call__(
if cfg.decoder_block in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5):
layer_kwargs = {"layer_idx": lyr}
kv_cache = None
if kv_caches is not None and cfg.decoder_block not in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5):
if kv_caches is not None:
kv_cache = kv_caches[lyr]
elif kv_caches is not None and cfg.decoder_block in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5):
# For Qwen3Next & Qwen3.5, kv_caches is a dictionary of lists of caches.
if (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0:
kv_cache = (kv_caches["key_cache"][lyr], kv_caches["value_cache"][lyr])

if cfg.decoder_block == DecoderBlockType.GPT_OSS:
layer_kwargs = {"attention_type": gpt_oss.get_attention_type(layer_id=lyr)}
Expand All @@ -1135,11 +1131,7 @@ def __call__(
**layer_call_kwargs,
)
if kv_caches is not None and returned_cache is not None:
if cfg.decoder_block not in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5):
kv_caches[lyr] = returned_cache
elif (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0:
kv_caches["key_cache"][lyr] = returned_cache[0]
kv_caches["value_cache"][lyr] = returned_cache[1]
kv_caches[lyr] = returned_cache

if deepstack_visual_embeds is not None and lyr < len(deepstack_visual_embeds):
visual_embeds = deepstack_visual_embeds[lyr]
Expand Down
Loading
Loading