diff --git a/docs/reference/core_concepts/quantization.md b/docs/reference/core_concepts/quantization.md index dae117a85a..608f828ff2 100644 --- a/docs/reference/core_concepts/quantization.md +++ b/docs/reference/core_concepts/quantization.md @@ -127,6 +127,9 @@ model = qwix.quantize_model(model, qwix.QtProvider(rule)) ### AQT Quantization +> [!WARNING] +> **DEPRECATION NOTICE**: AQT quantization is deprecated and will be removed in a future release. Please migrate to Qwix by setting `use_qwix_quantization=True`. + If `use_qwix_quantization` is `False` or not set, you can still apply quantization using the AQT library by setting the `quantization` flag. You can read more about AQT on this [Google Cloud blog](https://cloud.google.com/blog/products/compute/accurate-quantized-training-aqt-for-tpu-v5e). #### `quantization` values for AQT diff --git a/docs/reference/models/supported_models_and_architectures.md b/docs/reference/models/supported_models_and_architectures.md index 97268d98cc..57da9b7e33 100644 --- a/docs/reference/models/supported_models_and_architectures.md +++ b/docs/reference/models/supported_models_and_architectures.md @@ -74,7 +74,7 @@ MaxText supports a wide range of parallelism strategies for scaling training and The following summarizes observed runtime efficiency and scaling behaviors of MaxText across different hardware and model types, based on published benchmarks and large-scale runs. - **High MFU**: MaxText targets high Model FLOPs Utilization across scales; exact numbers vary by model, hardware and config. See [**Performance Metrics → MFU**](../performance_metrics.md#performance-metrics) for the definition and how we calculate it. -- **Quantization**: MaxText supports quantization via both the AQT and Qwix libraries. Qwix is the recommended approach, providing a non-intrusive way to apply various quantization techniques, including Quantization-Aware Training (QAT) and Post-Training Quantization (PTQ). +- **Quantization**: MaxText supports quantization via both the Qwix and AQT (deprecated) libraries. Qwix is the recommended approach, providing a non-intrusive way to apply various quantization techniques, including Quantization-Aware Training (QAT) and Post-Training Quantization (PTQ). - **MoE**: The Mixture-of-Experts implementation features dropless routing with efficient kernels including Megablox, `jax.lax.ragged_dot`, and Tokamax Ragged Dot. - **Multi-Token Prediction (MTP)**: This feature improves training efficiency on DeepSeek-style models by adding an auxiliary loss based on predicting multiple future tokens. - **Long-Context Optimizations**: Implements various efficient attention mechanisms, including: Grouped-Query Attention (GQA), Sliding-Window Attention (SWA), Local–Global interleaved attention, Multi-Head Latent Attention (MLA). They reduce the KV-cache size, making it possible to handle long contexts efficiently. diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 5667b6ec00..78f7a67f32 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -143,7 +143,7 @@ save_quantized_params_path: "" # when left as is, corresponds to training # accepted values are "inference" model_call_mode: "" -use_qwix_quantization: false # whether to use qwix for quantization. if set to true, the model will be quantized using qwix. +use_qwix_quantization: false # [DEPRECATED: AQT will be removed in a future release. It is strongly recommended to set use_qwix_quantization to true] whether to use qwix for quantization. if set to true, the model will be quantized using qwix. use_manual_quantization: false # a flag if to use manual quantization for batch split. Only used if use_batch_split_schedule is True. # quantization calibration method used for weights and activations. supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#l70-l80 weight_quantization_calibration_method: "absmax" diff --git a/src/maxtext/layers/quantizations.py b/src/maxtext/layers/quantizations.py index 57c54cac21..8444ca6b16 100644 --- a/src/maxtext/layers/quantizations.py +++ b/src/maxtext/layers/quantizations.py @@ -41,6 +41,7 @@ from maxtext.common.common_types import DType, Config from maxtext.inference.kvcache import KVQuant +from maxtext.utils import max_logging # Params used to define mixed precision quantization configs DEFAULT = "__default__" # default config @@ -652,6 +653,10 @@ def configure_quantization(config: Config, quant_mode_str: str = "train"): return TransformerEngineQuantization(config) quant_mode = get_quant_mode(quant_mode_str) replicate_scale = config.replicate_quant_scale if config.replicate_quant_scale else False + max_logging.log( + "WARNING: AQT quantization is deprecated and will be removed in a future release. " + "Please migrate to Qwix by setting use_qwix_quantization=True." + ) return AqtQuantization(quant_dg=quant_cfg, quant_mode=quant_mode, replicate_scale=replicate_scale) return None diff --git a/tests/unit/quantizations_test.py b/tests/unit/quantizations_test.py index 6b4a3c6295..ea032baaf2 100644 --- a/tests/unit/quantizations_test.py +++ b/tests/unit/quantizations_test.py @@ -149,6 +149,7 @@ def test_configure_quantization_replicate_scale(self): quant = _configure_quantization(quant_str="int8", mode_str=quant_mode, replicate_scale=True) self.assertEqual(quant.replicate_scale, True) + @pytest.mark.cpu_only def test_configure_quantization_is_int8(self): for quant_mode in ["train", "serve", "convert"]: quant = _configure_quantization(quant_str="int8", mode_str=quant_mode)