Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion src/maxtext/models/gpt3.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from maxtext.layers.attentions import AttentionOp, KVQuant
from maxtext.layers.initializers import Initializer, NdInitializer, nd_dense_init
from maxtext.layers.quantizations import AqtQuantization as Quant
from maxtext.inference import kvcache
from maxtext.utils import max_logging
from maxtext.utils import max_utils

Expand Down Expand Up @@ -236,6 +237,11 @@ def __init__(
self.value_axis_names = value_axis_names
self.out_axis_names = out_axis_names
self.rngs = rngs
self.model_mode = model_mode
self.prefill_cache_axis_order = (1, 2, 0, 3)
self.ar_cache_axis_order = (1, 2, 0, 3)
self.use_ragged_attention = False
self.KVCache_0 = self.init_kv_caches(inputs_kv_shape=feature_dim) if self.model_mode != MODEL_MODE_TRAIN else None
if self.fused_qkv:
self.qkv_proj = self.create_projection_layer(
feature_dim, (3, self.num_heads, self.head_dim), ("embed", "qkv", "heads", "kv")
Expand Down Expand Up @@ -299,13 +305,48 @@ def projection(self, projection_layer: Any, inputs: Array) -> Array:
proj = projection_layer(inputs)
return proj

def init_kv_caches(self, inputs_kv_shape: tuple[int, ...]):
  """Construct the `kvcache.KVCache` module for this attention layer.

  Args:
    inputs_kv_shape: A three-entry shape. Only the first entry is read and it
      is forwarded as the cache's `batch` argument; unpacking fails for any
      other number of entries.

  Returns:
    A `kvcache.KVCache` built from this layer's prefill/target lengths, head
    count, head size, dtype, KV quantization, cache axis orders, model mode
    and rngs.
  """
  batch, _, _ = inputs_kv_shape
  # Key and value sequence lengths are both passed as this placeholder of 1.
  seq_len_placeholder = 1

  make_cache = kvcache.KVCache
  # Keys and values are given the same head count and the same head size.
  cache_kwargs = {
      "max_prefill_length": self.max_prefill_predict_length,
      "max_target_length": self.max_target_length,
      "batch": batch,
      "key_seq_len": seq_len_placeholder,
      "value_seq_len": seq_len_placeholder,
      "key_heads": self.num_heads,
      "value_heads": self.num_heads,
      "key_head_size": self.head_dim,
      "value_head_size": self.head_dim,
      "dtype": self.dtype,
      "kv_quant": self.kv_quant,
      "prefill_cache_axis_order": self.prefill_cache_axis_order,
      "ar_cache_axis_order": self.ar_cache_axis_order,
      "use_chunked_prefill": self.config.use_chunked_prefill,
      "model_mode": self.model_mode,
      "rngs": self.rngs,
  }
  return make_cache(**cache_kwargs)

def update_kv_caches(self, key, value, decoder_segment_ids, model_mode, previous_chunk):
  """Pass the key/value projections through this layer's KV cache module.

  Args:
    key: Key projection forwarded to `self.KVCache_0`.
    value: Value projection forwarded to `self.KVCache_0`.
    decoder_segment_ids: Segment ids forwarded unchanged.
    model_mode: Model mode forwarded unchanged.
    previous_chunk: Previous-chunk information forwarded unchanged.

  Returns:
    A two-element list `[prefill_cache, ar_cache]` holding the pair returned by
    `self.KVCache_0`, in the same order.
  """
  cache_call_kwargs = {
      "key": key,
      "value": value,
      "decoder_segment_ids": decoder_segment_ids,
      "model_mode": model_mode,
      "use_ragged_attention": self.use_ragged_attention,
      "previous_chunk": previous_chunk,
  }
  # The cache module must yield exactly two items; the unpacking enforces that.
  prefill_cache, ar_cache = self.KVCache_0(**cache_call_kwargs)
  return [prefill_cache, ar_cache]

def __call__(
self,
inputs_q: Array,
decoder_segment_ids: Array | None = None,
*,
deterministic: bool = False,
model_mode: str = MODEL_MODE_TRAIN,
previous_chunk: Any = None,
kv_cache: Array | None = None,
attention_metadata: dict[str, Any] | None = None,
):
Expand All @@ -328,7 +369,11 @@ def __call__(
value = nn.with_logical_constraint(value, self.value_axis_names)
value = checkpoint_name(value, "value_proj")

out = self.attention_op(query, key, value, decoder_segment_ids, None, model_mode)
cached_values = [None, None]
if model_mode != MODEL_MODE_TRAIN:
cached_values = self.update_kv_caches(key, value, decoder_segment_ids, model_mode, previous_chunk)

out = self.attention_op(query, key, value, decoder_segment_ids, None, model_mode, cached_values)

out = nn.with_logical_constraint(out, self.out_axis_names)

Expand Down Expand Up @@ -448,6 +493,7 @@ def __call__(
decoder_segment_ids=decoder_segment_ids,
model_mode=model_mode,
deterministic=deterministic,
previous_chunk=previous_chunk,
kv_cache=kv_cache,
attention_metadata=attention_metadata,
)
Expand Down
81 changes: 78 additions & 3 deletions tests/unit/gpt3_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

""" Tests for GPT3. """
"""Tests for GPT3."""

import sys
import unittest
Expand All @@ -21,11 +21,12 @@
import jax.numpy as jnp
from jax.sharding import Mesh
from maxtext.configs import pyconfig
from maxtext.common.common_types import MODEL_MODE_TRAIN
from maxtext.common.common_types import MODEL_MODE_TRAIN, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE
from maxtext.layers import quantizations
from maxtext.models import models
from maxtext.utils import maxtext_utils
from tests.utils.test_helpers import get_test_config_path
import numpy as np
import pytest


Expand Down Expand Up @@ -61,10 +62,11 @@ def setUp(self):
enable_checkpointing=False,
model_name="gpt3-52k",
dtype="float32",
per_device_batch_size=1.0 / jax.device_count(),
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is the purpose?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a way to set the global batch size to 1 regardless of device count.

)
self.rng = jax.random.PRNGKey(1234)

devices_array = maxtext_utils.create_device_mesh(self.cfg)
devices_array = maxtext_utils.create_device_mesh(self.cfg, devices=[jax.devices()[0]])
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we only use one device for testing? Is there no sharding involved?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is for the KV cache test. Running everything on a single device eliminates communication and ensures the test is fully deterministic.

mesh = Mesh(devices_array, self.cfg.mesh_axes)
quant = quantizations.configure_quantization(self.cfg)
self.model = models.transformer_as_linen(config=self.cfg, mesh=mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
Expand Down Expand Up @@ -106,3 +108,76 @@ def test_logits_numerically(self):
jax.numpy.allclose(per_example_xent, per_example_xent_truth, rtol=1e-03, atol=1e-03),
msg=f"per_example_xent:\n{per_example_xent}\n\nper_example_xent_truth:\n{per_example_xent_truth}",
)

@pytest.mark.tpu_only
def test_prefill_and_autoregress(self):
  """Checks prefill and one autoregressive step against train-mode logits.

  A prefill-mode model is built on a mesh made from the first device only. A
  prefill pass is run over the first PREFILL_RANGE tokens with the "cache"
  collection mutable, and the returned cache is merged into the variables for
  one autoregressive step on the next token. Both sets of logits are compared
  with the matching slice of the train-mode logits from `self.model`, using
  rtol=atol=1e-01.
  """
  # Number of leading tokens fed to the prefill pass; the token at this index
  # is the one decoded autoregressively afterwards.
  PREFILL_RANGE = 2
  devices_array = maxtext_utils.create_device_mesh(self.cfg, devices=[jax.devices()[0]])
  mesh = Mesh(devices_array, self.cfg.mesh_axes)
  quant = quantizations.configure_quantization(self.cfg)
  prefill_model = models.transformer_as_linen(config=self.cfg, mesh=mesh, quant=quant, model_mode=MODEL_MODE_PREFILL)

  # 0. Compute baseline full training logits for numerical equivalence comparison.
  full_train_logits, _ = self.model.apply(
      self.model_vars,
      self.example_batch["inputs"],
      self.example_batch["inputs_position"],
      decoder_segment_ids=self.example_batch["inputs_segmentation"],
      enable_dropout=False,
      rngs={"dropout": self.rng, "aqt": self.rng},
      mutable="intermediates",
  )

  # 1. Initialize model variables and KV cache structures in prefill mode.
  # NOTE(review): the variables returned by this init call are overwritten by
  # init_random_model_vars immediately below and are never read — confirm
  # whether this call is still needed.
  prefill_transformer_vars = prefill_model.init(
      {"params": self.rng, "aqt": self.rng},
      self.example_batch["inputs"],
      self.example_batch["inputs_position"],
      model_mode=MODEL_MODE_PREFILL,
      decoder_segment_ids=self.example_batch["inputs_segmentation"],
      enable_dropout=False,
  )
  # Replace zero initializers with normal distribution to ensure strong numerical test cases
  # NOTE(review): the comparison below presumably relies on these variables
  # matching self.model_vars (same rng) — verify against setUp.
  prefill_transformer_vars = init_random_model_vars(prefill_model, self.rng, self.example_batch)

  # 2. Execute a partial prefill pass to populate the KV cache.
  partial_prefill_logits, partial_cache = prefill_model.apply(
      prefill_transformer_vars,
      self.example_batch["inputs"][:, :PREFILL_RANGE],
      self.example_batch["inputs_position"][:, :PREFILL_RANGE],
      model_mode=MODEL_MODE_PREFILL,
      decoder_segment_ids=self.example_batch["inputs_segmentation"][:, :PREFILL_RANGE],
      enable_dropout=False,
      rngs={"aqt": self.rng},
      mutable=["cache"],
  )
  # Verify partial prefill logits match the full training logits for the same
  # positions within rtol=atol=1e-01 (a loose tolerance, not exact equality).
  np.testing.assert_allclose(
      full_train_logits[:, :PREFILL_RANGE, :],
      partial_prefill_logits,
      rtol=1e-01,
      atol=1e-01,
  )

  # 3. Perform an autoregressive decoding step using the updated KV cache.
  idx = PREFILL_RANGE
  # Slice with idx : idx + 1 so the sequence axis is kept with length one.
  ids_idx = self.example_batch["inputs"][:, idx : idx + 1]
  decoder_positions_idx = self.example_batch["inputs_position"][:, idx : idx + 1]
  # Merge the collections returned by the prefill pass into the variables.
  prefill_transformer_vars.update(partial_cache)
  ar_logits, _ = prefill_model.apply(
      prefill_transformer_vars,
      ids_idx,
      decoder_positions_idx,
      model_mode=MODEL_MODE_AUTOREGRESSIVE,
      enable_dropout=False,
      rngs={"aqt": self.rng},
      mutable=["cache"],
  )
  # Verify autoregressive logits match the full training logits at the decoded
  # position within rtol=atol=1e-01 (a loose tolerance, not exact equality).
  np.testing.assert_allclose(
      full_train_logits[:, idx : idx + 1, :],
      ar_logits,
      rtol=1e-01,
      atol=1e-01,
  )
Loading