diff --git a/src/maxtext/models/gpt3.py b/src/maxtext/models/gpt3.py index 2736b8aafb..c67e68f9e1 100644 --- a/src/maxtext/models/gpt3.py +++ b/src/maxtext/models/gpt3.py @@ -35,6 +35,7 @@ from maxtext.layers.attentions import AttentionOp, KVQuant from maxtext.layers.initializers import Initializer, NdInitializer, nd_dense_init from maxtext.layers.quantizations import AqtQuantization as Quant +from maxtext.inference import kvcache from maxtext.utils import max_logging from maxtext.utils import max_utils @@ -236,6 +237,11 @@ def __init__( self.value_axis_names = value_axis_names self.out_axis_names = out_axis_names self.rngs = rngs + self.model_mode = model_mode + self.prefill_cache_axis_order = (1, 2, 0, 3) + self.ar_cache_axis_order = (1, 2, 0, 3) + self.use_ragged_attention = False + self.KVCache_0 = self.init_kv_caches(inputs_kv_shape=feature_dim) if self.model_mode != MODEL_MODE_TRAIN else None if self.fused_qkv: self.qkv_proj = self.create_projection_layer( feature_dim, (3, self.num_heads, self.head_dim), ("embed", "qkv", "heads", "kv") @@ -299,6 +305,40 @@ def projection(self, projection_layer: Any, inputs: Array) -> Array: proj = projection_layer(inputs) return proj + def init_kv_caches(self, inputs_kv_shape: tuple[int, ...]): + batch_size, _, _ = inputs_kv_shape + placeholder_seq_len = 1 + + return kvcache.KVCache( + max_prefill_length=self.max_prefill_predict_length, + max_target_length=self.max_target_length, + batch=batch_size, + key_seq_len=placeholder_seq_len, + value_seq_len=placeholder_seq_len, + key_heads=self.num_heads, + value_heads=self.num_heads, + key_head_size=self.head_dim, + value_head_size=self.head_dim, + dtype=self.dtype, + kv_quant=self.kv_quant, + prefill_cache_axis_order=self.prefill_cache_axis_order, + ar_cache_axis_order=self.ar_cache_axis_order, + use_chunked_prefill=self.config.use_chunked_prefill, + model_mode=self.model_mode, + rngs=self.rngs, + ) + + def update_kv_caches(self, key, value, decoder_segment_ids, model_mode, 
previous_chunk): + prefill_kv_cache, ar_kv_cache = self.KVCache_0( + key=key, + value=value, + decoder_segment_ids=decoder_segment_ids, + model_mode=model_mode, + use_ragged_attention=self.use_ragged_attention, + previous_chunk=previous_chunk, + ) + return [prefill_kv_cache, ar_kv_cache] + def __call__( self, inputs_q: Array, @@ -306,6 +346,7 @@ def __call__( *, deterministic: bool = False, model_mode: str = MODEL_MODE_TRAIN, + previous_chunk: Any = None, kv_cache: Array | None = None, attention_metadata: dict[str, Any] | None = None, ): @@ -328,7 +369,11 @@ def __call__( value = nn.with_logical_constraint(value, self.value_axis_names) value = checkpoint_name(value, "value_proj") - out = self.attention_op(query, key, value, decoder_segment_ids, None, model_mode) + cached_values = [None, None] + if model_mode != MODEL_MODE_TRAIN: + cached_values = self.update_kv_caches(key, value, decoder_segment_ids, model_mode, previous_chunk) + + out = self.attention_op(query, key, value, decoder_segment_ids, None, model_mode, cached_values) out = nn.with_logical_constraint(out, self.out_axis_names) @@ -448,6 +493,7 @@ def __call__( decoder_segment_ids=decoder_segment_ids, model_mode=model_mode, deterministic=deterministic, + previous_chunk=previous_chunk, kv_cache=kv_cache, attention_metadata=attention_metadata, ) diff --git a/tests/unit/gpt3_test.py b/tests/unit/gpt3_test.py index 952f148183..f8941292ce 100644 --- a/tests/unit/gpt3_test.py +++ b/tests/unit/gpt3_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" Tests for GPT3. 
""" +"""Tests for GPT3.""" import sys import unittest @@ -21,11 +21,12 @@ import jax.numpy as jnp from jax.sharding import Mesh from maxtext.configs import pyconfig -from maxtext.common.common_types import MODEL_MODE_TRAIN +from maxtext.common.common_types import MODEL_MODE_TRAIN, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE from maxtext.layers import quantizations from maxtext.models import models from maxtext.utils import maxtext_utils from tests.utils.test_helpers import get_test_config_path +import numpy as np import pytest @@ -61,10 +62,11 @@ def setUp(self): enable_checkpointing=False, model_name="gpt3-52k", dtype="float32", + per_device_batch_size=1.0 / jax.device_count(), ) self.rng = jax.random.PRNGKey(1234) - devices_array = maxtext_utils.create_device_mesh(self.cfg) + devices_array = maxtext_utils.create_device_mesh(self.cfg, devices=[jax.devices()[0]]) mesh = Mesh(devices_array, self.cfg.mesh_axes) quant = quantizations.configure_quantization(self.cfg) self.model = models.transformer_as_linen(config=self.cfg, mesh=mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) @@ -106,3 +108,76 @@ def test_logits_numerically(self): jax.numpy.allclose(per_example_xent, per_example_xent_truth, rtol=1e-03, atol=1e-03), msg=f"per_example_xent:\n{per_example_xent}\n\nper_example_xent_truth:\n{per_example_xent_truth}", ) + + @pytest.mark.tpu_only + def test_prefill_and_autoregress(self): + """Verifies that GPT-3 attention correctly initializes and updates the KV cache during decoding.""" + PREFILL_RANGE = 2 + devices_array = maxtext_utils.create_device_mesh(self.cfg, devices=[jax.devices()[0]]) + mesh = Mesh(devices_array, self.cfg.mesh_axes) + quant = quantizations.configure_quantization(self.cfg) + prefill_model = models.transformer_as_linen(config=self.cfg, mesh=mesh, quant=quant, model_mode=MODEL_MODE_PREFILL) + + # 0. Compute baseline full training logits for numerical equivalence comparison. 
+ full_train_logits, _ = self.model.apply( + self.model_vars, + self.example_batch["inputs"], + self.example_batch["inputs_position"], + decoder_segment_ids=self.example_batch["inputs_segmentation"], + enable_dropout=False, + rngs={"dropout": self.rng, "aqt": self.rng}, + mutable="intermediates", + ) + + # 1. Initialize model variables and KV cache structures in prefill mode. + prefill_transformer_vars = prefill_model.init( + {"params": self.rng, "aqt": self.rng}, + self.example_batch["inputs"], + self.example_batch["inputs_position"], + model_mode=MODEL_MODE_PREFILL, + decoder_segment_ids=self.example_batch["inputs_segmentation"], + enable_dropout=False, + ) + # Replace zero initializers with normal distribution to ensure strong numerical test cases + prefill_transformer_vars = init_random_model_vars(prefill_model, self.rng, self.example_batch) + + # 2. Execute a partial prefill pass to populate the KV cache. + partial_prefill_logits, partial_cache = prefill_model.apply( + prefill_transformer_vars, + self.example_batch["inputs"][:, :PREFILL_RANGE], + self.example_batch["inputs_position"][:, :PREFILL_RANGE], + model_mode=MODEL_MODE_PREFILL, + decoder_segment_ids=self.example_batch["inputs_segmentation"][:, :PREFILL_RANGE], + enable_dropout=False, + rngs={"aqt": self.rng}, + mutable=["cache"], + ) + # Verify partial prefill matches the full training logits within the loose rtol/atol of 1e-1 set below (not an exact match) + np.testing.assert_allclose( + full_train_logits[:, :PREFILL_RANGE, :], + partial_prefill_logits, + rtol=1e-01, + atol=1e-01, + ) + + # 3. Perform an autoregressive decoding step using the updated KV cache. 
+ idx = PREFILL_RANGE + ids_idx = self.example_batch["inputs"][:, idx : idx + 1] + decoder_positions_idx = self.example_batch["inputs_position"][:, idx : idx + 1] + prefill_transformer_vars.update(partial_cache) + ar_logits, _ = prefill_model.apply( + prefill_transformer_vars, + ids_idx, + decoder_positions_idx, + model_mode=MODEL_MODE_AUTOREGRESSIVE, + enable_dropout=False, + rngs={"aqt": self.rng}, + mutable=["cache"], + ) + # Verify autoregressive decoding matches the full training logits at the decoded position within the loose rtol/atol of 1e-1 set below (not an exact match) + np.testing.assert_allclose( + full_train_logits[:, idx : idx + 1, :], + ar_logits, + rtol=1e-01, + atol=1e-01, + )