diff --git a/src/maxtext/models/gemma2.py b/src/maxtext/models/gemma2.py index a7315763eb..f2c0814c82 100644 --- a/src/maxtext/models/gemma2.py +++ b/src/maxtext/models/gemma2.py @@ -235,14 +235,27 @@ def __call__( lnx = self.pre_self_attention_norm_local(inputs) lnx = nn.with_logical_constraint(lnx, self.activation_axis_names) - attention_lnx, kv_cache = self.self_attention_local( + local_kv_cache = None + global_kv_cache = None + if isinstance(kv_cache, (list, tuple)): + if len(kv_cache) == 2: + local_kv_cache = kv_cache[0] + global_kv_cache = kv_cache[1] + else: + local_kv_cache = kv_cache[0] + global_kv_cache = kv_cache[0] + else: + local_kv_cache = kv_cache + global_kv_cache = kv_cache + + attention_lnx, kv_cache_local = self.self_attention_local( lnx, lnx, decoder_positions, decoder_segment_ids=decoder_segment_ids, deterministic=deterministic, model_mode=model_mode, - kv_cache=kv_cache, + kv_cache=local_kv_cache, attention_metadata=attention_metadata, ) if self.config.use_post_attn_norm: @@ -275,13 +288,15 @@ def __call__( lnx = self.pre_self_attention_norm_global(inputs) lnx = nn.with_logical_constraint(lnx, self.activation_axis_names) - attention_lnx, kv_cache = self.self_attention_global( + attention_lnx, kv_cache_global = self.self_attention_global( lnx, lnx, decoder_positions, decoder_segment_ids=decoder_segment_ids, deterministic=deterministic, model_mode=model_mode, + kv_cache=global_kv_cache, + attention_metadata=attention_metadata, ) if self.config.use_post_attn_norm: attention_lnx = self.post_self_attention_norm_global(attention_lnx) @@ -315,10 +330,15 @@ def __call__( jnp.sum(layer_output == 0) / jnp.size(layer_output), ) + if isinstance(kv_cache, (list, tuple)) and len(kv_cache) == 2: + returned_cache = [kv_cache_local, kv_cache_global] + else: + returned_cache = kv_cache_local + if self.config.scan_layers: return layer_output, None else: - return layer_output, kv_cache + return layer_output, returned_cache Gemma2DecoderLayerToLinen = nnx_wrappers.to_linen_class(