3 changes: 3 additions & 0 deletions docs/reference/core_concepts/quantization.md
@@ -127,6 +127,9 @@ model = qwix.quantize_model(model, qwix.QtProvider(rule))

### AQT Quantization

> [!WARNING]
> **DEPRECATION NOTICE**: AQT quantization is deprecated and will be removed in a future release. Please migrate to Qwix by setting `use_qwix_quantization=True`.

If `use_qwix_quantization` is `False` or not set, you can still apply quantization using the AQT library by setting the `quantization` flag. You can read more about AQT on this [Google Cloud blog](https://cloud.google.com/blog/products/compute/accurate-quantized-training-aqt-for-tpu-v5e).
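
The flag interaction described above can be sketched as follows. This is a minimal illustration, not MaxText's implementation: `QuantSettings` and `select_quantization_backend` are hypothetical names, and treating an empty `quantization` string as "no quantization" is an assumption.

```python
import warnings
from dataclasses import dataclass
from typing import Optional


@dataclass
class QuantSettings:
  """Hypothetical stand-in for the two config flags discussed above."""

  use_qwix_quantization: bool = False
  quantization: str = ""  # assumption: an empty string means no quantization


def select_quantization_backend(settings: QuantSettings) -> Optional[str]:
  """Returns "qwix", "aqt", or None for the given flags (illustrative only)."""
  if not settings.quantization:
    return None
  if settings.use_qwix_quantization:
    return "qwix"
  # AQT path: still available, but deprecated.
  warnings.warn(
      "AQT quantization is deprecated and will be removed in a future release. "
      "Please migrate to Qwix by setting use_qwix_quantization=True.",
      DeprecationWarning,
      stacklevel=2,
  )
  return "aqt"
```

With these stand-ins, `QuantSettings(quantization="int8")` selects the AQT path and emits the warning, while adding `use_qwix_quantization=True` selects Qwix without it.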

#### `quantization` values for AQT
@@ -74,7 +74,7 @@ MaxText supports a wide range of parallelism strategies for scaling training and
The following summarizes observed runtime efficiency and scaling behaviors of MaxText across different hardware and model types, based on published benchmarks and large-scale runs.

- **High MFU**: MaxText targets high Model FLOPs Utilization across scales; exact numbers vary by model, hardware and config. See [**Performance Metrics → MFU**](../performance_metrics.md#performance-metrics) for the definition and how we calculate it.
-- **Quantization**: MaxText supports quantization via both the AQT and Qwix libraries. Qwix is the recommended approach, providing a non-intrusive way to apply various quantization techniques, including Quantization-Aware Training (QAT) and Post-Training Quantization (PTQ).
+- **Quantization**: MaxText supports quantization via both the Qwix and AQT (deprecated) libraries. Qwix is the recommended approach, providing a non-intrusive way to apply various quantization techniques, including Quantization-Aware Training (QAT) and Post-Training Quantization (PTQ).
- **MoE**: The Mixture-of-Experts implementation features dropless routing with efficient kernels including Megablox, `jax.lax.ragged_dot`, and Tokamax Ragged Dot.
- **Multi-Token Prediction (MTP)**: This feature improves training efficiency on DeepSeek-style models by adding an auxiliary loss based on predicting multiple future tokens.
- **Long-Context Optimizations**: Implements various efficient attention mechanisms, including: Grouped-Query Attention (GQA), Sliding-Window Attention (SWA), Local–Global interleaved attention, Multi-Head Latent Attention (MLA). They reduce the KV-cache size, making it possible to handle long contexts efficiently.
2 changes: 1 addition & 1 deletion src/maxtext/configs/base.yml
@@ -143,7 +143,7 @@ save_quantized_params_path: ""
# when left as is, corresponds to training
# accepted values are "inference"
model_call_mode: ""
-use_qwix_quantization: false # whether to use qwix for quantization. if set to true, the model will be quantized using qwix.
+use_qwix_quantization: false # [DEPRECATED: AQT will be removed in a future release. It is strongly recommended to set use_qwix_quantization to true] whether to use qwix for quantization. if set to true, the model will be quantized using qwix.
use_manual_quantization: false # a flag if to use manual quantization for batch split. Only used if use_batch_split_schedule is True.
# quantization calibration method used for weights and activations. supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#l70-l80
weight_quantization_calibration_method: "absmax"
5 changes: 5 additions & 0 deletions src/maxtext/layers/quantizations.py
@@ -41,6 +41,7 @@

from maxtext.common.common_types import DType, Config
from maxtext.inference.kvcache import KVQuant
+from maxtext.utils import max_logging

# Params used to define mixed precision quantization configs
DEFAULT = "__default__" # default config
@@ -652,6 +653,10 @@ def configure_quantization(config: Config, quant_mode_str: str = "train"):
       return TransformerEngineQuantization(config)
     quant_mode = get_quant_mode(quant_mode_str)
     replicate_scale = config.replicate_quant_scale if config.replicate_quant_scale else False
+    max_logging.log(
+        "WARNING: AQT quantization is deprecated and will be removed in a future release. "
+        "Please migrate to Qwix by setting use_qwix_quantization=True."
+    )
     return AqtQuantization(quant_dg=quant_cfg, quant_mode=quant_mode, replicate_scale=replicate_scale)
   return None
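
One way to check that the notice fires only on the AQT path is sketched below. It is self-contained and does not touch MaxText: `configure_quantization_sketch` is a hypothetical stand-in for the function in this hunk, and the logger is passed in as an argument (an assumption made so it can be replaced with a mock) rather than called through `max_logging`.

```python
from unittest import mock

DEPRECATION_MSG = (
    "WARNING: AQT quantization is deprecated and will be removed in a future release. "
    "Please migrate to Qwix by setting use_qwix_quantization=True."
)


def configure_quantization_sketch(quant_cfg, log):
  """Hypothetical stand-in: logs the notice only when an AQT config is built."""
  if not quant_cfg:
    return None  # no quantization requested, so nothing to warn about
  log(DEPRECATION_MSG)
  return ("aqt", quant_cfg)


# Usage: a mock logger records whether the notice was emitted.
fake_log = mock.Mock()
configure_quantization_sketch("", fake_log)
fake_log.assert_not_called()
configure_quantization_sketch("int8", fake_log)
fake_log.assert_called_once_with(DEPRECATION_MSG)
```

Injecting the logger keeps the sketch independent of how `max_logging.log` is implemented. A test against the real module would more likely patch that function in place.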

1 change: 1 addition & 0 deletions tests/unit/quantizations_test.py
@@ -149,6 +149,7 @@ def test_configure_quantization_replicate_scale(self):
       quant = _configure_quantization(quant_str="int8", mode_str=quant_mode, replicate_scale=True)
       self.assertEqual(quant.replicate_scale, True)

+  @pytest.mark.cpu_only
   def test_configure_quantization_is_int8(self):
     for quant_mode in ["train", "serve", "convert"]:
       quant = _configure_quantization(quant_str="int8", mode_str=quant_mode)