From 2826e908688f5d0b3def45441a3bdb90e0a20eab Mon Sep 17 00:00:00 2001 From: learning-to-play <66660475+learning-to-play@users.noreply.github.com> Date: Sat, 16 May 2026 18:34:27 -0700 Subject: [PATCH 1/2] Gemma 2 Local/Global KV Cache Split Patch to support list-structured local sliding window and global KV caches, routing them correctly to the local/global attention blocks. Without this patch, Gemma 2 JAX serving fails immediately at runtime on the first execution step. --- src/maxtext/models/gemma2.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/maxtext/models/gemma2.py b/src/maxtext/models/gemma2.py index a7315763eb..d75d423c25 100644 --- a/src/maxtext/models/gemma2.py +++ b/src/maxtext/models/gemma2.py @@ -235,6 +235,19 @@ def __call__( lnx = self.pre_self_attention_norm_local(inputs) lnx = nn.with_logical_constraint(lnx, self.activation_axis_names) + local_kv_cache = None + global_kv_cache = None + if isinstance(kv_cache, (list, tuple)): + if len(kv_cache) == 2: + local_kv_cache = kv_cache[0] + global_kv_cache = kv_cache[1] + else: + local_kv_cache = kv_cache[0] + global_kv_cache = kv_cache[0] + else: + local_kv_cache = kv_cache + global_kv_cache = kv_cache + attention_lnx, kv_cache = self.self_attention_local( lnx, lnx, @@ -282,6 +295,8 @@ def __call__( decoder_segment_ids=decoder_segment_ids, deterministic=deterministic, model_mode=model_mode, + kv_cache=global_kv_cache, + attention_metadata=attention_metadata, ) if self.config.use_post_attn_norm: attention_lnx = self.post_self_attention_norm_global(attention_lnx) @@ -315,6 +330,11 @@ def __call__( jnp.sum(layer_output == 0) / jnp.size(layer_output), ) + if isinstance(kv_cache, (list, tuple)) and len(kv_cache) == 2: + returned_cache = [kv_cache_local, kv_cache_global] + else: + returned_cache = kv_cache_local + if self.config.scan_layers: return layer_output, None else: From 138c2fa6d9d0adeff6d393ceb747eb1afc782c91 Mon Sep 17 00:00:00 2001 From: learning-to-play <66660475+learning-to-play@users.noreply.github.com> Date: Sat, 16 May 2026 18:48:35 -0700 Subject: [PATCH 2/2] Gemma 2 Local/Global KV Cache Split --- src/maxtext/models/gemma2.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/maxtext/models/gemma2.py b/src/maxtext/models/gemma2.py index d75d423c25..f2c0814c82 100644 --- a/src/maxtext/models/gemma2.py +++ b/src/maxtext/models/gemma2.py @@ -248,14 +248,14 @@ def __call__( local_kv_cache = kv_cache global_kv_cache = kv_cache - attention_lnx, kv_cache = self.self_attention_local( + attention_lnx, kv_cache_local = self.self_attention_local( lnx, lnx, decoder_positions, decoder_segment_ids=decoder_segment_ids, deterministic=deterministic, model_mode=model_mode, - kv_cache=kv_cache, + kv_cache=local_kv_cache, attention_metadata=attention_metadata, ) if self.config.use_post_attn_norm: @@ -288,7 +288,7 @@ def __call__( lnx = self.pre_self_attention_norm_global(inputs) lnx = nn.with_logical_constraint(lnx, self.activation_axis_names) - attention_lnx, kv_cache = self.self_attention_global( + attention_lnx, kv_cache_global = self.self_attention_global( lnx, lnx, decoder_positions, @@ -338,7 +338,7 @@ def __call__( if self.config.scan_layers: return layer_output, None else: - return layer_output, kv_cache + return layer_output, returned_cache Gemma2DecoderLayerToLinen = nnx_wrappers.to_linen_class(