-
Notifications
You must be signed in to change notification settings - Fork 517
Fix GPT3 attention missing KV cache initialization and handling #3927
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,7 +12,7 @@ | |
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """ Tests for GPT3. """ | ||
| """Tests for GPT3.""" | ||
|
|
||
| import sys | ||
| import unittest | ||
|
|
@@ -21,11 +21,12 @@ | |
| import jax.numpy as jnp | ||
| from jax.sharding import Mesh | ||
| from maxtext.configs import pyconfig | ||
| from maxtext.common.common_types import MODEL_MODE_TRAIN | ||
| from maxtext.common.common_types import MODEL_MODE_TRAIN, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE | ||
| from maxtext.layers import quantizations | ||
| from maxtext.models import models | ||
| from maxtext.utils import maxtext_utils | ||
| from tests.utils.test_helpers import get_test_config_path | ||
| import numpy as np | ||
| import pytest | ||
|
|
||
|
|
||
|
|
@@ -61,10 +62,11 @@ def setUp(self): | |
| enable_checkpointing=False, | ||
| model_name="gpt3-52k", | ||
| dtype="float32", | ||
| per_device_batch_size=1.0 / jax.device_count(), | ||
| ) | ||
| self.rng = jax.random.PRNGKey(1234) | ||
|
|
||
| devices_array = maxtext_utils.create_device_mesh(self.cfg) | ||
| devices_array = maxtext_utils.create_device_mesh(self.cfg, devices=[jax.devices()[0]]) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why do we only use one device for testing? Is no sharding involved?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is for the KV cache test. Running everything on a single device eliminates communication and ensures the test is fully deterministic. |
||
| mesh = Mesh(devices_array, self.cfg.mesh_axes) | ||
| quant = quantizations.configure_quantization(self.cfg) | ||
| self.model = models.transformer_as_linen(config=self.cfg, mesh=mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) | ||
|
|
@@ -106,3 +108,76 @@ def test_logits_numerically(self): | |
| jax.numpy.allclose(per_example_xent, per_example_xent_truth, rtol=1e-03, atol=1e-03), | ||
| msg=f"per_example_xent:\n{per_example_xent}\n\nper_example_xent_truth:\n{per_example_xent_truth}", | ||
| ) | ||
|
|
||
| @pytest.mark.tpu_only | ||
| def test_prefill_and_autoregress(self): | ||
| """Verifies that GPT-3 attention correctly initializes and updates the KV cache during decoding.""" | ||
| PREFILL_RANGE = 2 | ||
| devices_array = maxtext_utils.create_device_mesh(self.cfg, devices=[jax.devices()[0]]) | ||
| mesh = Mesh(devices_array, self.cfg.mesh_axes) | ||
| quant = quantizations.configure_quantization(self.cfg) | ||
| prefill_model = models.transformer_as_linen(config=self.cfg, mesh=mesh, quant=quant, model_mode=MODEL_MODE_PREFILL) | ||
|
|
||
| # 0. Compute baseline full training logits for numerical equivalence comparison. | ||
| full_train_logits, _ = self.model.apply( | ||
| self.model_vars, | ||
| self.example_batch["inputs"], | ||
| self.example_batch["inputs_position"], | ||
| decoder_segment_ids=self.example_batch["inputs_segmentation"], | ||
| enable_dropout=False, | ||
| rngs={"dropout": self.rng, "aqt": self.rng}, | ||
| mutable="intermediates", | ||
| ) | ||
|
|
||
| # 1. Initialize model variables and KV cache structures in prefill mode. | ||
| prefill_transformer_vars = prefill_model.init( | ||
| {"params": self.rng, "aqt": self.rng}, | ||
| self.example_batch["inputs"], | ||
| self.example_batch["inputs_position"], | ||
| model_mode=MODEL_MODE_PREFILL, | ||
| decoder_segment_ids=self.example_batch["inputs_segmentation"], | ||
| enable_dropout=False, | ||
| ) | ||
| # Replace zero initializers with normally distributed values to ensure strong numerical test cases (this overwrites the variables returned by init above) | ||
| prefill_transformer_vars = init_random_model_vars(prefill_model, self.rng, self.example_batch) | ||
|
|
||
| # 2. Execute a partial prefill pass to populate the KV cache. | ||
| partial_prefill_logits, partial_cache = prefill_model.apply( | ||
| prefill_transformer_vars, | ||
| self.example_batch["inputs"][:, :PREFILL_RANGE], | ||
| self.example_batch["inputs_position"][:, :PREFILL_RANGE], | ||
| model_mode=MODEL_MODE_PREFILL, | ||
| decoder_segment_ids=self.example_batch["inputs_segmentation"][:, :PREFILL_RANGE], | ||
| enable_dropout=False, | ||
| rngs={"aqt": self.rng}, | ||
| mutable=["cache"], | ||
| ) | ||
| # Verify partial prefill logits match the full training logits within tolerance (rtol/atol = 1e-1) | ||
| np.testing.assert_allclose( | ||
| full_train_logits[:, :PREFILL_RANGE, :], | ||
| partial_prefill_logits, | ||
| rtol=1e-01, | ||
| atol=1e-01, | ||
| ) | ||
|
|
||
| # 3. Perform an autoregressive decoding step using the updated KV cache. | ||
| idx = PREFILL_RANGE | ||
| ids_idx = self.example_batch["inputs"][:, idx : idx + 1] | ||
| decoder_positions_idx = self.example_batch["inputs_position"][:, idx : idx + 1] | ||
| prefill_transformer_vars.update(partial_cache) | ||
| ar_logits, _ = prefill_model.apply( | ||
| prefill_transformer_vars, | ||
| ids_idx, | ||
| decoder_positions_idx, | ||
| model_mode=MODEL_MODE_AUTOREGRESSIVE, | ||
| enable_dropout=False, | ||
| rngs={"aqt": self.rng}, | ||
| mutable=["cache"], | ||
| ) | ||
| # Verify autoregressive decoding matches the full training logits at the decoded position within tolerance (rtol/atol = 1e-1) | ||
| np.testing.assert_allclose( | ||
| full_train_logits[:, idx : idx + 1, :], | ||
| ar_logits, | ||
| rtol=1e-01, | ||
| atol=1e-01, | ||
| ) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the purpose?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's a way to set the global batch size to 1 regardless of device count.