diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index bb18a81a5f..c38f152c7a 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -87,6 +87,8 @@ class QuantizationType(str, Enum): INT4 = "int4" INT8 = "int8" INTMP = "intmp" + FP8_E5M2 = "fp8_e5m2" + FP8_E4M3 = "fp8_e4m3" FP8 = "fp8" NANOO_FP8 = "nanoo_fp8" FP8_NANO_V2 = "fp8_nanoo" diff --git a/src/maxtext/layers/quantizations.py b/src/maxtext/layers/quantizations.py index 57c54cac21..ae98ce8f86 100644 --- a/src/maxtext/layers/quantizations.py +++ b/src/maxtext/layers/quantizations.py @@ -759,64 +759,37 @@ def get_fp8_full_qwix_rule_w_sparsity(config: Config): def get_quantization_rule(config: Config): + + """Returns a list of qwix.QtRule from `dtype`.""" + def make_qt_rule(dtype) -> list[qwix.QtRule]: + return [ + qwix.QtRule( + module_path="decoder/.*layers.*", + weight_qtype=dtype, + act_qtype=dtype, + bwd_qtype=dtype, + bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count, + op_names=("dot_general",), + ) + ] + match config.quantization: case "int4": - return [ - qwix.QtRule( - module_path="decoder/.*layers.*", - weight_qtype=jnp.int4, - act_qtype=jnp.int4, - bwd_qtype=jnp.int4, - bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count, - op_names=("dot_general",), - ) - ] + return make_qt_rule(jnp.int4) + case "int8": - return [ - qwix.QtRule( - module_path="decoder/.*layers.*", - weight_qtype=jnp.int8, - act_qtype=jnp.int8, - bwd_qtype=jnp.int8, - bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count, - op_names=("dot_general",), - ) - ] - case "fp8": - return [ - qwix.QtRule( - module_path="decoder/.*layers.*", - weight_qtype=jnp.float8_e4m3fn, - act_qtype=jnp.float8_e4m3fn, - bwd_qtype=jnp.float8_e4m3fn, - bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count, - op_names=("dot_general",), - ) - ] + return make_qt_rule(jnp.int8) + + case "fp8_e5m2": + return make_qt_rule(jnp.float8_e5m2) + + case "fp8" | "fp8_e4m3" | "fp8_gpu" | "fp8_nanoo": + return make_qt_rule(jnp.float8_e4m3fn) + case "fp8_full": return get_fp8_full_qwix_rule_w_sparsity(config) case "fp8_gpu": - return [ - qwix.QtRule( - module_path="decoder/.*layers.*", - weight_qtype=jnp.float8_e4m3fn, - act_qtype=jnp.float8_e4m3fn, - bwd_qtype=jnp.float8_e4m3fn, - bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count, - op_names=("dot_general",), - ) - ] - case "fp8_nanoo": - return [ - qwix.QtRule( - module_path="decoder/.*layers.*", - weight_qtype=jnp.float8_e4m3fn, - act_qtype=jnp.float8_e4m3fn, - bwd_qtype=jnp.float8_e4m3fn, - bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count, - op_names=("dot_general",), - ) - ] + return make_qt_rule(jnp.float8_e4m3fn) case "": return None @@ -824,13 +797,7 @@ def get_quantization_rule(config: Config): def get_qt_provider(config): """Get quantization rules based on the config.""" match config.quantization: - case "int8": - return qwix.QtProvider(get_quantization_rule(config)) - case "int4": - return qwix.QtProvider(get_quantization_rule(config)) - case "fp8": - return qwix.QtProvider(get_quantization_rule(config)) - case "fp8_full": + case "int4" | "int8" | "fp8" | "fp8_e5m2" | "fp8_e4m3" | "fp8_full": return qwix.QtProvider(get_quantization_rule(config)) case "fp8_gpu": return NvidaFp8Provider(get_quantization_rule(config))