From d9e458176f17cafcae3d415d97fecce2e354100a Mon Sep 17 00:00:00 2001 From: Sarun Singla Date: Thu, 14 May 2026 22:08:18 +0000 Subject: [PATCH 1/3] Deprecate AQT quantization in MaxText --- docs/reference/core_concepts/quantization.md | 3 +++ docs/reference/models/supported_models_and_architectures.md | 4 ++-- src/maxtext/configs/base.yml | 2 +- src/maxtext/layers/quantizations.py | 2 ++ 4 files changed, 8 insertions(+), 3 deletions(-) diff --git a/docs/reference/core_concepts/quantization.md b/docs/reference/core_concepts/quantization.md index dae117a85a..608f828ff2 100644 --- a/docs/reference/core_concepts/quantization.md +++ b/docs/reference/core_concepts/quantization.md @@ -127,6 +127,9 @@ model = qwix.quantize_model(model, qwix.QtProvider(rule)) ### AQT Quantization +> [!WARNING] +> **DEPRECATION NOTICE**: AQT quantization is deprecated and will be removed in a future release. Please migrate to Qwix by setting `use_qwix_quantization=True`. + If `use_qwix_quantization` is `False` or not set, you can still apply quantization using the AQT library by setting the `quantization` flag. You can read more about AQT on this [Google Cloud blog](https://cloud.google.com/blog/products/compute/accurate-quantized-training-aqt-for-tpu-v5e). #### `quantization` values for AQT diff --git a/docs/reference/models/supported_models_and_architectures.md b/docs/reference/models/supported_models_and_architectures.md index fb80002df5..77b7fbedbe 100644 --- a/docs/reference/models/supported_models_and_architectures.md +++ b/docs/reference/models/supported_models_and_architectures.md @@ -10,7 +10,7 @@ MaxText is an open-source, high-performance LLM framework written in Python/JAX. - **Supported Precisions**: FP32, BF16, INT8, and FP8. - **Ahead-of-Time Compilation (AOT)**: For faster model development/prototyping and earlier OOM detection. -- **Quantization**: Via **Qwix** (recommended) and AQT. 
See Quantization [Guide](https://maxtext.readthedocs.io/en/maxtext-v0.2.1/reference/core_concepts/quantization.html). +- **Quantization**: Via **Qwix** (recommended) and AQT (deprecated). See Quantization [Guide](https://maxtext.readthedocs.io/en/maxtext-v0.2.1/reference/core_concepts/quantization.html). - **Diagnostics**: Structured error context via **`cloud_tpu_diagnostics`** (filters stack traces to user code), simple logging via `max_logging`, profiling in **XProf**, and visualization in **TensorBoard**. - **Multi-Token Prediction (MTP)**: Enables token efficient training with multi-token prediction. - **Elastic Training**: Fault-tolerant and dynamic scale-up/scale-down on Cloud TPUs with Pathways. @@ -74,7 +74,7 @@ MaxText supports a wide range of parallelism strategies for scaling training and The following summarizes observed runtime efficiency and scaling behaviors of MaxText across different hardware and model types, based on published benchmarks and large-scale runs. - **High MFU**: MaxText targets high Model FLOPs Utilization across scales; exact numbers vary by model, hardware and config. See [**Performance Metrics → MFU**](../performance_metrics.md#performance-metrics) for the definition and how we calculate it. -- **Quantization**: MaxText supports quantization via both the AQT and Qwix libraries. Qwix is the recommended approach, providing a non-intrusive way to apply various quantization techniques, including Quantization-Aware Training (QAT) and Post-Training Quantization (PTQ). +- **Quantization**: MaxText supports quantization via both the Qwix and AQT (deprecated) libraries. Qwix is the recommended approach, providing a non-intrusive way to apply various quantization techniques, including Quantization-Aware Training (QAT) and Post-Training Quantization (PTQ). - **MoE**: The Mixture-of-Experts implementation features dropless routing with efficient kernels including Megablox, `jax.lax.ragged_dot`, and Tokamax Ragged Dot. 
- **Multi-Token Prediction (MTP)**: This feature improves training efficiency on DeepSeek-style models by adding an auxiliary loss based on predicting multiple future tokens. - **Long-Context Optimizations**: Implements various efficient attention mechanisms, including: Grouped-Query Attention (GQA), Sliding-Window Attention (SWA), Local–Global interleaved attention, Multi-Head Latent Attention (MLA). They reduce the KV-cache size, making it possible to handle long contexts efficiently. diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index ecf03133cc..62df56fa48 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -142,7 +142,7 @@ save_quantized_params_path: "" # when left as is, corresponds to training # accepted values are "inference" model_call_mode: "" -use_qwix_quantization: false # whether to use qwix for quantization. if set to true, the model will be quantized using qwix. +use_qwix_quantization: false # whether to use qwix for quantization. if set to true, the model will be quantized using qwix. [DEPRECATION NOTICE: the AQT path used when this is false is deprecated and will be removed in a future release; setting use_qwix_quantization to true is strongly recommended.] use_manual_quantization: false # a flag if to use manual quantization for batch split. Only used if use_batch_split_schedule is True. # quantization calibration method used for weights and activations. 
supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#l70-l80 weight_quantization_calibration_method: "absmax" diff --git a/src/maxtext/layers/quantizations.py b/src/maxtext/layers/quantizations.py index 57c54cac21..faa87f842a 100644 --- a/src/maxtext/layers/quantizations.py +++ b/src/maxtext/layers/quantizations.py @@ -41,6 +41,7 @@ from maxtext.common.common_types import DType, Config from maxtext.inference.kvcache import KVQuant +from maxtext.utils import max_logging # Params used to define mixed precision quantization configs DEFAULT = "__default__" # default config @@ -652,6 +653,7 @@ def configure_quantization(config: Config, quant_mode_str: str = "train"): return TransformerEngineQuantization(config) quant_mode = get_quant_mode(quant_mode_str) replicate_scale = config.replicate_quant_scale if config.replicate_quant_scale else False + max_logging.log("WARNING: AQT quantization is deprecated and will be removed in a future release. 
Please migrate to Qwix by setting use_qwix_quantization=True.") return AqtQuantization(quant_dg=quant_cfg, quant_mode=quant_mode, replicate_scale=replicate_scale) return None From 2834b6ac6c83b5065b5a70a2a916aae2dcaa5c76 Mon Sep 17 00:00:00 2001 From: Sarun Singla Date: Fri, 15 May 2026 02:43:27 +0000 Subject: [PATCH 2/3] Fix pylint line too long error in quantizations.py --- src/maxtext/layers/quantizations.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/maxtext/layers/quantizations.py b/src/maxtext/layers/quantizations.py index faa87f842a..8444ca6b16 100644 --- a/src/maxtext/layers/quantizations.py +++ b/src/maxtext/layers/quantizations.py @@ -653,7 +653,10 @@ def configure_quantization(config: Config, quant_mode_str: str = "train"): return TransformerEngineQuantization(config) quant_mode = get_quant_mode(quant_mode_str) replicate_scale = config.replicate_quant_scale if config.replicate_quant_scale else False - max_logging.log("WARNING: AQT quantization is deprecated and will be removed in a future release. Please migrate to Qwix by setting use_qwix_quantization=True.") + max_logging.log( + "WARNING: AQT quantization is deprecated and will be removed in a future release. " + "Please migrate to Qwix by setting use_qwix_quantization=True." 
+ ) return AqtQuantization(quant_dg=quant_cfg, quant_mode=quant_mode, replicate_scale=replicate_scale) return None From 5867e24226dcba28d804c4a3ba595a15d39f56a2 Mon Sep 17 00:00:00 2001 From: Sarun Singla Date: Fri, 15 May 2026 03:28:45 +0000 Subject: [PATCH 3/3] Add cpu_only marker to test_configure_quantization_is_int8 for Codecov test coverage --- tests/unit/quantizations_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/quantizations_test.py b/tests/unit/quantizations_test.py index 6b4a3c6295..ea032baaf2 100644 --- a/tests/unit/quantizations_test.py +++ b/tests/unit/quantizations_test.py @@ -149,6 +149,7 @@ def test_configure_quantization_replicate_scale(self): quant = _configure_quantization(quant_str="int8", mode_str=quant_mode, replicate_scale=True) self.assertEqual(quant.replicate_scale, True) + @pytest.mark.cpu_only def test_configure_quantization_is_int8(self): for quant_mode in ["train", "serve", "convert"]: quant = _configure_quantization(quant_str="int8", mode_str=quant_mode)