diff --git a/.github/workflows/build_and_test_maxtext.yml b/.github/workflows/build_and_test_maxtext.yml index 18e5463df3..bcda57d48b 100644 --- a/.github/workflows/build_and_test_maxtext.yml +++ b/.github/workflows/build_and_test_maxtext.yml @@ -119,7 +119,7 @@ jobs: needs: analyze_code_changes # Run if either tests or notebooks need to run if: | - needs.analyze_code_changes.outputs.run_tests == 'true' || + needs.analyze_code_changes.outputs.run_tests == 'true' || needs.analyze_code_changes.outputs.run_notebooks == 'true' uses: ./.github/workflows/build_package.yml with: @@ -318,4 +318,4 @@ jobs: actions: 'read' with: failed_run_id: '${{ github.run_id }}' - secrets: inherit \ No newline at end of file + secrets: inherit diff --git a/.github/workflows/gemini-dispatch.yml b/.github/workflows/gemini-dispatch.yml index 09311a93c6..92935df1e7 100644 --- a/.github/workflows/gemini-dispatch.yml +++ b/.github/workflows/gemini-dispatch.yml @@ -24,7 +24,7 @@ defaults: jobs: debugger: # Debug mode: with a repository variable called DEBUG to true - if: |- + if: |- ${{ fromJSON(vars.DEBUG || vars.ACTIONS_STEP_DEBUG || false) }} runs-on: 'ubuntu-latest' permissions: @@ -39,7 +39,7 @@ jobs: DEBUG_event__pull_request__author_association: '${{ github.event.pull_request.author_association }}' DEBUG_event__review__author_association: '${{ github.event.review.author_association }}' DEBUG_event: '${{ toJSON(github.event) }}' - run: |- + run: |- env | grep '^DEBUG_' dispatch: diff --git a/.github/workflows/gemini-investigate.yml b/.github/workflows/gemini-investigate.yml index 0e650b49e8..d8b5f226ed 100644 --- a/.github/workflows/gemini-investigate.yml +++ b/.github/workflows/gemini-investigate.yml @@ -36,7 +36,7 @@ jobs: PR_NUMBER: ${{ github.event.workflow_run.pull_requests[0].number || github.event.pull_request.number || github.event.issue.number }} run: | mkdir -p .gemini - + # Determine target run ID if [ -z "$RUN_ID" ]; then # If SHA/BRANCH are missing (e.g. on issue_comment event), fetch them from PR @@ -45,22 +45,22 @@ jobs: SHA=$(gh pr view "$PR_NUMBER" --repo "$REPO" --json headRefOid --jq '.headRefOid' 2>/dev/null || true) BRANCH=$(gh pr view "$PR_NUMBER" --repo "$REPO" --json headRefName --jq '.headRefName' 2>/dev/null || true) fi - + # Fallback to finding the latest failed run for this PR's specific commit if [ -n "$SHA" ]; then echo "Searching for failed runs for commit: $SHA" RUN_ID=$(gh run list --workflow "MaxText Package Tests" --status failure --commit "$SHA" --limit 1 --json databaseId --jq '.[0].databaseId' --repo "$REPO") fi - + # Fallback to branch if commit-specific run wasn't found if [ -z "$RUN_ID" ] && [ -n "$BRANCH" ]; then echo "Searching for failed runs on branch: $BRANCH" RUN_ID=$(gh run list --workflow "MaxText Package Tests" --status failure --branch "$BRANCH" --limit 1 --json databaseId --jq '.[0].databaseId' --repo "$REPO") fi fi - + echo "Gathering logs for failed run: $RUN_ID" - + if [ -n "$RUN_ID" ]; then # Retrieve only the failing lines/jobs to avoid token limit overhead gh run view "$RUN_ID" --log-failed --repo "$REPO" > .gemini/failed_logs.txt || true diff --git a/.github/workflows/gemini-invoke.yml b/.github/workflows/gemini-invoke.yml index 4a59bc7fcd..0242f1e062 100644 --- a/.github/workflows/gemini-invoke.yml +++ b/.github/workflows/gemini-invoke.yml @@ -39,7 +39,7 @@ jobs: permission-pull-requests: 'write' - name: 'Run Gemini CLI' - # Trigger Gemini with context + # Trigger Gemini with context id: 'run_gemini' uses: 'google-github-actions/run-gemini-cli@main' env: diff --git a/.github/workflows/gemini-review.yml b/.github/workflows/gemini-review.yml index e7d11b7dde..92c777a5ff 100644 --- a/.github/workflows/gemini-review.yml +++ b/.github/workflows/gemini-review.yml @@ -60,7 +60,7 @@ jobs: ADDITIONAL_CONTEXT: '${{ inputs.additional_context }}' - name: 'Run Gemini pull request review' - # reviews code with detailed set of instructions for the Gemini + # reviews code with detailed set of instructions for the Gemini uses: 'google-github-actions/run-gemini-cli@v0' id: 'gemini_pr_review' env: @@ -71,7 +71,7 @@ jobs: PULL_REQUEST_NUMBER: '${{ github.event.pull_request.number || github.event.issue.number }}' REPOSITORY: '${{ github.repository }}' ADDITIONAL_CONTEXT: '${{ inputs.additional_context }}' - with: + with: gcp_location: '${{ vars.GOOGLE_CLOUD_LOCATION }}' gcp_project_id: '${{ vars.GOOGLE_CLOUD_PROJECT }}' gcp_service_account: '${{ vars.SERVICE_ACCOUNT_EMAIL }}' diff --git a/.github/workflows/pypi_release.yml b/.github/workflows/pypi_release.yml index 0eba10c310..33854d2f68 100644 --- a/.github/workflows/pypi_release.yml +++ b/.github/workflows/pypi_release.yml @@ -42,7 +42,7 @@ jobs: needs: [release_approval] uses: ./.github/workflows/build_and_test_maxtext.yml secrets: inherit - + publish_maxtext_package_to_pypi: name: Publish MaxText package to PyPI needs: [build_and_test_maxtext_package] diff --git a/.github/workflows/run_jupyter_notebooks.yml b/.github/workflows/run_jupyter_notebooks.yml index 7d868e8d5c..d0db6f63cb 100644 --- a/.github/workflows/run_jupyter_notebooks.yml +++ b/.github/workflows/run_jupyter_notebooks.yml @@ -73,7 +73,7 @@ jobs: # 2. Install MaxText package and all the post training dependencies uv pip install ${maxtext_wheel}[tpu-post-train] --resolution=lowest install_tpu_post_train_extra_deps - + python3 -m pip freeze - name: Run Post-Training Notebooks shell: bash @@ -81,7 +81,7 @@ jobs: HF_TOKEN: ${{ secrets.HF_TOKEN }} MAXTEXT_INSTALLED: ${{ inputs.maxtext_installed }} # TODO: Fix evaluation in sft_qwen3_demo.ipynb and remove this env variable - RUN_EVALUATION: "False" + RUN_EVALUATION: "false" run: | if [ "${MAXTEXT_INSTALLED}" == "true" ]; then # Move to the directory where code is baked into the image. See the Dockerfile. diff --git a/.github/workflows/run_tests_against_package.yml b/.github/workflows/run_tests_against_package.yml index d737ab5345..1c4b2f4eae 100644 --- a/.github/workflows/run_tests_against_package.yml +++ b/.github/workflows/run_tests_against_package.yml @@ -86,9 +86,9 @@ jobs: TF_FORCE_GPU_ALLOW_GROWTH: ${{ inputs.tf_force_gpu_allow_growth }} TPU_SKIP_MDS_QUERY: ${{ inputs.device_type == 'cpu' && '1' || '' }} MAXTEXT_PACKAGE_EXTRA: >- - ${{ - !contains(inputs.pytest_marker, 'not post_training') && 'tpu-post-train' - || (inputs.device_type == 'cpu' && 'tpu' || inputs.device_type) + ${{ + !contains(inputs.pytest_marker, 'not post_training') && 'tpu-post-train' + || (inputs.device_type == 'cpu' && 'tpu' || inputs.device_type) }} ALLOW_MULTIPLE_LIBTPU_LOAD: ${{ inputs.device_type == 'cpu' && 'true' || '' }} # bypass /tmp/libtpu_lockfile check for cpu tests, which don't actually use accelerators (to allow concurrency) options: ${{ inputs.container_resource_option }} @@ -171,7 +171,7 @@ jobs: else SPLIT_ARGS="" fi - + # Setup substitution: If manually updating HLO, skip tests execution and run only the update script instead! if [ "${INPUTS_IS_UPDATE_HLO}" == "true" ]; then python3 tests/utils/update_hlo_references.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9c5caf6665..8fcbf64255 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -61,3 +61,9 @@ repos: additional_dependencies: [mdformat-myst, mdformat-ruff] files: (docs/.) exclude: docs/guides/checkpointing_solutions.md|docs/guides.md + + - repo: https://github.com/adrienverge/yamllint + rev: v1.35.0 + hooks: + - id: yamllint + types: [yaml] diff --git a/.yamllint b/.yamllint new file mode 100644 index 0000000000..87dbfecade --- /dev/null +++ b/.yamllint @@ -0,0 +1,8 @@ +extends: relaxed +rules: + line-length: disable + comments: disable + indentation: disable + commas: disable + colons: disable + empty-lines: disable diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 670d155974..656dc8ec0c 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -21,10 +21,10 @@ run_name: "" model_name: "default" # override config settings to match a specific model. other than the override, nothing should use this! -override_model_config: False # When set to true allows overriding model parameters via CLI (or kwargs or env vars) for the purpose of debugging/testing. -override_logical_axis_rules: False # When set overrides logical axis rules instead of merging them. +override_model_config: false # When set to true allows overriding model parameters via CLI (or kwargs or env vars) for the purpose of debugging/testing. +override_logical_axis_rules: false # When set overrides logical axis rules instead of merging them. debug: - rl: False # RL-specific debugging + rl: false # RL-specific debugging normalization_layer_epsilon: 1.e-05 # epsilon value for rmsnorm, layernorm. @@ -52,20 +52,20 @@ load_full_state_path: "" # If enable_checkpointing is true, an asynchronous checkpointer will be used if # async_checkpointing is true, else a synchronous one is used. If you have # problems with the checkpointer we recommend trying the synchronous one. -enable_checkpointing: True -save_checkpoint_on_completion: True -async_checkpointing: True +enable_checkpointing: true +save_checkpoint_on_completion: true +async_checkpointing: true checkpoint_period: 10_000 max_num_checkpoints_to_keep: None -enable_continuous_checkpointing: False +enable_continuous_checkpointing: false # enables one replica to read the ckpt then broadcast to the rest -enable_single_replica_ckpt_restoring: False +enable_single_replica_ckpt_restoring: false # Subdirectory to move checkpoints to before deletion. For example: ".todelete" (Ignored if directory is prefixed with gs://) checkpoint_todelete_subdir: None # Full path to move checkpoints to before deletion. checkpoint_todelete_full_path: None -force_unroll: False # during generate_param_only_checkpoint should we unroll the loop? +force_unroll: false # during generate_param_only_checkpoint should we unroll the loop? # checkpointing using orbax has two important parameters: array driver # and its underlying storage - the kvstore (preferably ocdbt) @@ -73,24 +73,24 @@ force_unroll: False # during generate_param_only_checkpoint should we unroll the # large arrays into small physical files (<2GB) can speed up distributed and over # the network loading enormously checkpoint_storage_target_data_file_size_bytes: 2147483648 -checkpoint_storage_use_ocdbt: True -checkpoint_storage_use_zarr3: True +checkpoint_storage_use_ocdbt: true +checkpoint_storage_use_zarr3: true # larger models requires higher concurrent GB for I/O # default concurrent gb for PytreeCheckpointHandler is 96GB checkpoint_storage_concurrent_gb: 96 # Bool flag for enabling Orbax v1. -enable_orbax_v1: False +enable_orbax_v1: false # function for processing loaded checkpoint dict into a format maxtext can understand. (for other formats, i.e. safetensors) checkpoint_conversion_fn: none # optional checkpoint context to use for loading. options: "orbax", "safetensors" source_checkpoint_layout: "orbax" -# Only applicable to Single Controller/Pathways on Cloud. Experimental feature, under testing -colocated_python_checkpointing: False +# Only applicable to Single Controller/Pathways on Cloud. Experimental feature, under testing +colocated_python_checkpointing: false # enables autocheckpoint, which saves a checkpoint at the preemption step. -enable_autocheckpoint: False +enable_autocheckpoint: false ############################### end checkpointing ################################## @@ -144,7 +144,7 @@ save_quantized_params_path: "" # accepted values are "inference" model_call_mode: "" use_qwix_quantization: false # whether to use qwix for quantization. if set to true, the model will be quantized using qwix. -use_manual_quantization: false # a flag if to use manual quantization for batch split. Only used if use_batch_split_schedule is True. +use_manual_quantization: false # a flag if to use manual quantization for batch split. Only used if use_batch_split_schedule is true. # quantization calibration method used for weights and activations. supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#l70-l80 weight_quantization_calibration_method: "absmax" act_quantization_calibration_method: "absmax" @@ -245,7 +245,7 @@ wo_tile_drhs_batch_seq: 512 wo_tile_drhs_embed_dim: 1024 wo_tile_drhs_mlp_dim: 1024 -merge_gating_gmm: False +merge_gating_gmm: false norm_topk_prob: false # boolean to enable the top-k probability normalization. qwen3-specific normalization of router weights. @@ -253,22 +253,22 @@ norm_topk_prob: false # boolean to enable the top-k probability normalization. q moe_fsdp_use_two_stage_all_gather: false # Shard the expert dimension of the MLP weights on the FSDP axis. # This configuration is recommended only when num_experts is a multiple of fsdp_parallelism -shard_exp_on_fsdp: False +shard_exp_on_fsdp: false # deepseek moe first_num_dense_layers: 0 # number of initial dense layers in the model shared_experts: 0 routed_scaling_factor: 1.0 # scaling factor for routing scores routed_score_func: "" # scoring function for routing -routed_bias: False # a flag if a learnable bias is added for routing +routed_bias: false # a flag if a learnable bias is added for routing routed_bias_update_rate: 0.0 # a flag indicate the update rate applied to the router bias term -mlp_bias: False # a flag if a learnable bias is added for MLP matmul, and originally implemented to support the GPT-OSS model architecture. +mlp_bias: false # a flag if a learnable bias is added for MLP matmul, and originally implemented to support the GPT-OSS model architecture. n_routing_groups: -1 # number of groups for routing, disabled by default topk_routing_group: -1 # number of top groups to route inputs. For EP, # Splits the batch to allow for better scheduling when using expert parallelism by overlapping the # all-to-all communication with compute. Currently only implemented with DeepSeek sparse layers. -use_batch_split_schedule: False # a flag if splitting batch into micro-batches to hide communications that yields performance benefits. -batch_split_factor: 1 # the factor by which to split the batch. Only used if use_batch_split_schedule is True. +use_batch_split_schedule: false # a flag if splitting batch into micro-batches to hide communications that yields performance benefits. +batch_split_factor: 1 # the factor by which to split the batch. Only used if use_batch_split_schedule is true. # For complex architectures like llama4 there are repeated sets of # inhomogeneous layers. E.g. maverick uses [dense+rope, moe+rope, dense+rope, moe+nope] @@ -294,16 +294,16 @@ pipeline_parallel_layers: -1 # Pipeline only this number of layers - for the rem # num_pipeline_microbatches must be a multiple of the number of pipeline stages. By default it is set to the number of stages. # Note the microbatch_size is given by global_batch_size / num_pipeline_microbatches, where global_batch_size = per_device_batch_size * num_devices num_pipeline_microbatches: -1 -pipeline_delay_activation_forwarding: False # This delays the activation forwarding one loop iteration simplifying XLA's task of overlapping since +pipeline_delay_activation_forwarding: false # This delays the activation forwarding one loop iteration simplifying XLA's task of overlapping since # the communication and compute in each iteration are now independent. However this comes at the cost of doubling the pipeline bubble, # and you must set the number of microbatches to at least 2 * num_stages (the minimum 2 * num_stages is set by default with this delay). -pipeline_fsdp_ag_once: False # If set to true then all gather all of the weights over FSDP before the first pipeline iteration. +pipeline_fsdp_ag_once: false # If set to true then all gather all of the weights over FSDP before the first pipeline iteration. # This is a memory/time tradeoff - we now have to store the FSDP gathered weights and gradients (typically in bf16), as opposed # to only one stage's worth, however we only execute one all-gather and reduce across per repeat, as opposed # to every microbatch. This is similar to zero-1 sharding, since we also don't need to all gather the FSDP weights in the backward pass. # An alternative to setting this to true may be to replace any FSDP with DP and use optimizer offloading if necessary. -pipeline_fsdp_ag_per_repeat: False +pipeline_fsdp_ag_per_repeat: false # Pipeline weight prefetching per repeat is an advanced SPMD pipeline parallelism improvement technique # When enabled, it prefetches necessary weight gathering ahead of microbatched computation, therefore reducing collectives @@ -314,11 +314,11 @@ pipeline_fsdp_ag_per_repeat: False # settings below of scanning and setting a remat policy only over the pipeline iterations. # It may be useful to do the reverse when the layers_per_stage is very large. # The below settings only have effect when using pipeline parallelism. -scan_pipeline_iterations: True -scan_pipeline_repeats: False -scan_layers_per_stage: False -set_remat_policy_on_pipeline_iterations: True -set_remat_policy_on_layers_per_stage: False +scan_pipeline_iterations: true +scan_pipeline_repeats: false +scan_layers_per_stage: false +set_remat_policy_on_pipeline_iterations: true +set_remat_policy_on_layers_per_stage: false # Choose 'remat_policy' between 'minimal_with_context', 'minimal', 'save_dot_with_context_except_mlp', 'save_dot_except_mlpwi', 'save_dot_except_mlp', @@ -348,46 +348,46 @@ mla_kv: 'remat' attention_out: 'remat' engram: 'remat' -optimizer_memory_host_offload: False -parameter_memory_host_offload: False -scan_layers: True # We recommend setting this to false when using pipeline parallelism, instead scanning the PP iterations. +optimizer_memory_host_offload: false +parameter_memory_host_offload: false +scan_layers: true # We recommend setting this to false when using pipeline parallelism, instead scanning the PP iterations. param_scan_axis: 1 # The attention parameter dictates the specific algorithm/methodology used to compute the attention scores # The attention_type parameter determines the variants of attention, e.g. global or local_sliding attention: 'autoselected' # Supported attention: autoselected, dot_product, flash, cudnn_flash_te attention_type: 'global' # Supported attention_type: global, local_sliding, chunk, mla -share_kv_projections: False # Note: Not compatible with attention_type='mla' -attention_bias: False # If True, adds a learnable bias to the query, key, and value projections -attention_sink: False +share_kv_projections: false # Note: Not compatible with attention_type='mla' +attention_bias: false # If true, adds a learnable bias to the query, key, and value projections +attention_sink: false sliding_window_size: 0 chunk_attn_window_size: 0 attn_logits_soft_cap: 0.0 final_logits_soft_cap: 0.0 z_loss_multiplier: 0.0 -use_post_attn_norm: False -use_post_ffw_norm: False -v_norm_with_scale: True -qk_norm_with_scale: True -mla_naive_kvcache: True +use_post_attn_norm: false +use_post_ffw_norm: false +v_norm_with_scale: true +qk_norm_with_scale: true +mla_naive_kvcache: true # Adding Mixture of Block Attention Support (MoBA): https://github.com/MoonshotAI/MoBA/blob/master/MoBA_Tech_Report.pdf -moba: False +moba: false moba_chunk_size: 1024 moba_topk: 8 # DeepSeek Sparse Attention (DSA) # deepseek3.2 introduces indexer in MLA -use_indexer: False +use_indexer: false indexer_head_dim: 128 indexer_n_heads: 64 indexer_topk: 2048 # Determines the training strategy for the indexer: -# - False (Dense Warm-up): Computes indexer loss over all tokens. Used with `trainable_parameters_mask` to freeze other model parameters. -# - True (Sparse Training): Computes indexer loss over top-k tokens only and detaches the indexer input for independent optimization. +# - false (Dense Warm-up): Computes indexer loss over all tokens. Used with `trainable_parameters_mask` to freeze other model parameters. +# - true (Sparse Training): Computes indexer loss over top-k tokens only and detaches the indexer input for independent optimization. # Note: This is only active when `indexer_loss_scaling_factor` > 0. -indexer_sparse_training: False +indexer_sparse_training: false # Multiplier for the indexer KL divergence loss indexer_loss_scaling_factor: 0.0 @@ -399,12 +399,12 @@ qk_rope_head_dim: 64 v_head_dim: 128 # QK-Clip (Muon Clip) Configuration -use_qk_clip: False # Enable QK-Clip (supported in MLA with DotProduct or Tokamax Splash) +use_qk_clip: false # Enable QK-Clip (supported in MLA with DotProduct or Tokamax Splash) qk_clip_threshold: 100.0 # Threshold for clipping (tau in the paper) # Combine matmuls for QKV and MLP -fused_qkv: False -fused_mlp: False +fused_qkv: false +fused_mlp: false record_internal_nn_metrics: 0 @@ -418,8 +418,8 @@ base_output_directory: "" # During restore, if a local copy is available in any slice, it will be broadcast to other slices without having to fetch from persistent storage. # See more details on https://github.com/google/orbax/tree/main/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing. # Example for enabling multi-tier checkpointing -# enable_multi_tier_checkpointing=True local_checkpoint_directory="/local" local_checkpoint_period=20 multi_tier_checkpointing_backup_interval_minutes=20 -enable_multi_tier_checkpointing: False +# enable_multi_tier_checkpointing=true local_checkpoint_directory="/local" local_checkpoint_period=20 multi_tier_checkpointing_backup_interval_minutes=20 +enable_multi_tier_checkpointing: false # The interval to backup local checkpoints to the persistent storage(GCS bucket) in minutes. # It should be a positive number when enabling multi-tier checkpointing. @@ -430,16 +430,16 @@ multi_tier_checkpointing_backup_interval_minutes: 0 mtc_data_parallelism: 0 -# Whether to enable emergency checkpoint. If True, `local_checkpoint_directory` and a non-zero `local_checkpoint_period` must also be specified. +# Whether to enable emergency checkpoint. If true, `local_checkpoint_directory` and a non-zero `local_checkpoint_period` must also be specified. # Emergency checkpoint is an experimental Orbax feature that: periodically saves to persistent storage and, with a larger invertal, saves to a local directory. # During restore, if a local copy is available in any slice, it will be broadcast to other slices without having to fetch from persistent storage. # See more details on https://github.com/google/orbax/tree/main/checkpoint/orbax/checkpoint/experimental/emergency. -enable_emergency_checkpoint: False +enable_emergency_checkpoint: false -# It should be specified when and only when `enable_emergency_checkpoint` is True. Or when `enable_multi_tier_checkpointing` is True. +# It should be specified when and only when `enable_emergency_checkpoint` is true. Or when `enable_multi_tier_checkpointing` is true. local_checkpoint_directory: "" -# It should be a positive number when and only when `enable_emergency_checkpoint` or `enable_multi_tier_checkpointing` is True. +# It should be a positive number when and only when `enable_emergency_checkpoint` or `enable_multi_tier_checkpointing` is true. local_checkpoint_period: 0 # Jax cache directory @@ -449,9 +449,9 @@ jax_cache_dir: "~/jax_cache" hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess' and 'cpu' # internal_compile allows bypassing open-source topology name mappings when using internal topologies directly via get_topology_desc. -internal_compile: False +internal_compile: false internal_compile_num_devices: -1 # You must specify the number of devices when using internal_compile. -compile_xla_flags: "" # Compiler options e.g. compile_xla_flags="--xla_tpu_num_sparse_cores_for_gather_offloading=1 --xla_tpu_scoped_vmem_limit_kib=65536" +compile_xla_flags: "" # Compiler options e.g. compile_xla_flags="--xla_tpu_num_sparse_cores_for_gather_offloading=1 --xla_tpu_scoped_vmem_limit_kib=65536" # Parallelism shard_mode: "auto" # can be either auto or explicit @@ -564,8 +564,8 @@ logical_axis_rules: [ # ========================================== # Deprecated / Scheduled for Removal # ========================================== - ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']], - ['embed_tensor_transpose', ['tensor_transpose']], + ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']], + ['embed_tensor_transpose', ['tensor_transpose']], ['exp_with_fsdp', 'fsdp'], ] # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details @@ -610,7 +610,7 @@ ici_pipeline_parallelism: 1 ici_expert_parallelism: 1 # Enable ZeRO-1 optimizer sharding over data axis -shard_optimizer_over_data: False +shard_optimizer_over_data: false # Unless explicitly specified, the number of TPU slices is automatically determined. It should only be set for # disaggregated reinforcement learning workloads using multiple slices. For ahead of time compilation, @@ -630,17 +630,17 @@ tokenizer_path: "" # grain and tfds pipeline supports tokenizer_type: sentencepiece, huggingface, tiktoken # hf pipeline only supports huggingface type, and will ignore tokenizer_type flag tokenizer_type: "sentencepiece" # Currently supporting: "tiktoken", "sentencepiece", "huggingface" -use_chat_template: False +use_chat_template: false chat_template_path: "" # path to chat template json file -tokenize_train_data: True # False if the dataset is pre-tokenized -tokenize_eval_data: True # False if the dataset is pre-tokenized -add_bos: True -add_eos: True -# If False, use chunking for long sequences instead of truncation. -# Note: use_truncation=False is only available in grain's pretrain preprocessing pipeline. +tokenize_train_data: true # false if the dataset is pre-tokenized +tokenize_eval_data: true # false if the dataset is pre-tokenized +add_bos: true +add_eos: true +# If false, use chunking for long sequences instead of truncation. +# Note: use_truncation=false is only available in grain's pretrain preprocessing pipeline. # See the TokenizeAndTrim and TokenizeAndChunk classes in # `src/maxtext/input_pipeline/_grain_tokenizer.py` for implementation details. -use_truncation: True +use_truncation: true # Dataset per_device_batch_size: 12.0 @@ -655,10 +655,10 @@ train_data_columns: ['text'] # for DPO dataset containing "chosen" and "rejected train_image_column: 'image' eval_data_columns: ['text'] # for DPO dataset containing "chosen" and "rejected" eval_image_column: 'image' -packing: True +packing: true num_epoch: 1 -generate_padding_batch_train: False -generate_padding_batch_eval: False +generate_padding_batch_train: false +generate_padding_batch_eval: false # Maximum number of segments that can be packed into a single sequence # This needs to be passed to TransformerEngine's DotProductAttention layer for packing # This also affects packing for grain, since TransformerEngine may crash or cause @@ -671,7 +671,7 @@ max_segments_per_seq: -1 # the final `per_device_batch_size`. For a clean ramp-up, the total range # (`per_device_batch_size` - `per_device_batch_size_start`) # should be evenly divisible by batch size increment. -enable_rampup_batch_size: False +enable_rampup_batch_size: false per_device_batch_size_start: 4.0 per_device_batch_size_increment: 2.0 # The target number of training samples to process during the ramp-up phase. @@ -679,14 +679,14 @@ per_device_batch_size_increment: 2.0 global_rampup_samples: 500 # direct preference optimization (DPO) -use_dpo: False +use_dpo: false dpo_label_smoothing: 0.0 dpo_beta: 0.1 # Supervised Fine-Tuning (SFT) -use_sft: False -# sft_train_on_completion_only=False trains on both prompt and completion tokens; trains only on completion tokens otherwise -sft_train_on_completion_only: False +use_sft: false +# sft_train_on_completion_only=false trains on both prompt and completion tokens; trains only on completion tokens otherwise +sft_train_on_completion_only: false # dataset_type must be synthetic, hf, grain, tfds # details in: https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/data_input_pipeline.md @@ -732,9 +732,9 @@ grain_num_threads_eval: 16 grain_prefetch_buffer_size_eval: 500 grain_data_source_max_workers: 16 # Max workers for ThreadPoolExecutor when mixing multiple Grain data sources. grain_shuffle_buffer_size: 100 # shuffle buffer when using sequential access formats such as Parquet, TFRecord. -grain_use_elastic_iterator: False # For elastic training, set to this true and packing=False +grain_use_elastic_iterator: false # For elastic training, set to this true and packing=false # for using pathways -colocated_python_data_input: False # experimental feature, under testing +colocated_python_data_input: false # experimental feature, under testing # OLMo numpy pipeline (dataset_type=olmo_grain). Worker count, buffer size, # and shuffle seed reuse grain_worker_count / grain_per_worker_buffer_size / @@ -742,7 +742,7 @@ colocated_python_data_input: False # experimental feature, under testing olmo_index_path: '' # JSON from tools/data_generation/build_olmo_npy_index.py olmo_path_remap_from: '' # rewrite index paths starting with this prefix... olmo_path_remap_to: '' # ...to this one (e.g. gs://bucket/ -> /mnt/.../ for gcsfuse). -olmo_apply_ngram_filter: True # mask instances with repetitive n-grams (OLMo-core filter) +olmo_apply_ngram_filter: true # mask instances with repetitive n-grams (OLMo-core filter) # Training loop steps: 150_001 # If set to -1 then will inherit value from learning_rate_schedule_steps @@ -752,10 +752,10 @@ jax_distributed_initialization_timeout: 300 # This is the default timeout in htt # Note there are two separate initializations - the jax coordination service (aka jax.distributed.initialize) and the backend (e.g. PjRT), the timeout above refers # only to the jax coordination service. jax_debug_log_modules: "" # Set this to "jax" to enable jax verbose logging such as for the jax coordination service initialization. -skip_jax_distributed_system: False # If True we will not initialize the jax distributed system. +skip_jax_distributed_system: false # If true we will not initialize the jax distributed system. # Currently the jax distributed is needed on cloud TPUs for async checkpointing. # However when run on google internal TPUs the coordination service is started automatically -# and we should set this to True so we won't try to initialize a second time manually. +# and we should set this to true so we won't try to initialize a second time manually. # Learning rate schedule structure depends on lr_schedule_type: # @@ -786,7 +786,7 @@ learning_rate_schedule_steps: -1 # By default the length of the schedule is set max_target_length: 2048 # Maximum sequence length max_prefill_predict_length: 64 # Maximum length for the prefill when doing autoregression prompt: "I love to" # Prompt for language model sampling. -load_from_prefill_dir: False # If true, decode.py doesn't "prefill" but just reads from directory +load_from_prefill_dir: false # If true, decode.py doesn't "prefill" but just reads from directory prefill_cache_dir: "" # If set and load_from_prefill_dir, decode.py reads from directory. If set, decode.py writes to directory autoregressive_decode_assert: "" @@ -794,17 +794,17 @@ autoregressive_decode_assert: "" # e.g. nsys profile -s none --force-overwrite true --capture-range=cudaProfilerApi --capture-range-end=stop {training command} profiler: "" # Supported profiler: '', xplane, nsys # If set to true, upload all profiler results from all hosts. Otherwise, only upload the profiler result from the first host. -upload_all_profiler_results: False +upload_all_profiler_results: false # Skip first n steps for profiling, to omit things like compilation and to give # the iteration time a chance to stabilize. skip_first_n_steps_for_profiler: 1 # Profile for a small number of steps to avoid a large profile file size. profiler_steps: 5 -hide_profiler_step_metric: False -profile_cleanly: True # If set to true, adds a block_until_ready on train state which aligns the profile for each step. +hide_profiler_step_metric: false +profile_cleanly: true # If set to true, adds a block_until_ready on train state which aligns the profile for each step. profile_periodically_period: -1 # If set to a positive integer, profile every profile_periodically_period steps. # This is useful to debug scenarios where performance is changing. -enable_tpu_profiling_options: False +enable_tpu_profiling_options: false tpu_num_chips_to_profile_per_task: 1 tpu_num_sparse_core_tiles_to_trace: 1 tpu_num_sparse_cores_to_trace: 2 @@ -813,35 +813,35 @@ tpu_num_sparse_cores_to_trace: 2 # - create a managed ML diagnostics run with all the MaxText configs # - upload xplane profiling, if it is enabled. # - upload training metrics, at the defined log_period interval. -managed_mldiagnostics: False # Whether to enable the managed diagnostics +managed_mldiagnostics: false # Whether to enable the managed diagnostics managed_mldiagnostics_run_group: "" # Optional. Used to group multiple runs. # Dump HLO and jaxpr options -dump_hlo: False +dump_hlo: false dump_step: -1 # Dump modules at the given step if set to a positive integer. dump_hlo_local_dir: "/tmp/xla_dump/" -dump_hlo_delete_local_after: True # Cleans local directory after its uploaded +dump_hlo_delete_local_after: true # Cleans local directory after its uploaded dump_hlo_gcs_dir: "" # Defaults to {base_output_directory}/{run_name}/xla_dump dump_hlo_local_module_name: "jit_train_step" # Filter saving modules locally by this string. Set to empty string to remove any filter. dump_hlo_module_name: "jit_train_step" # Filter uploading modules by this string. Set to empty string to remove any filter. dump_hlo_xla_flags: "" # Defaults to "--xla_dump_to={dump_hlo_local_dir} --xla_dump_hlo_module_re={dump_hlo_local_module_name} --xla_dump_large_constants" -dump_hlo_upload_all: False # If true all hosts dump HLO, false only jax.process_index()==0 +dump_hlo_upload_all: false # If true all hosts dump HLO, false only jax.process_index()==0 # All hosts should have identical HLO for SPMD programs, however we have encountered some bugs # where this is not the case and it is helpful to compare HLO across hosts. -dump_jaxpr: False +dump_jaxpr: false dump_jaxpr_local_dir: "/tmp/jaxpr_dump/" -dump_jaxpr_delete_local_after: True +dump_jaxpr_delete_local_after: true dump_jaxpr_gcs_dir: "" # Defaults to {base_output_directory}/{run_name}/jaxpr_dump # When dropout is false the model is a deterministic function of the # data_shuffle_seed and init_weights_seed (i.e. reproducible losses) -enable_dropout: True -enable_data_shuffling: True +enable_dropout: true +enable_data_shuffling: true data_shuffle_seed: 0 init_weights_seed: 0 # DiLoCo params. -enable_diloco: False +enable_diloco: false diloco_sync_period: 36 diloco_outer_lr: 0.3 diloco_outer_momentum: 0.9 @@ -855,9 +855,9 @@ gradient_accumulation_steps: 1 opt_type: "adamw" # one of "adamw", "adam_pax", "sgd", or "muon" -# If True, skip the training step when loss or gradient spike is detected +# If true, skip the training step when loss or gradient spike is detected # No updates for both weights and momentums (if applies) -skip_step_on_spikes: False +skip_step_on_spikes: false # The rolling interval to calculate the mean and standard deviation skip_step_interval: 128 # The scaling factor to determine if a spike occurred @@ -889,19 +889,19 @@ muon_weight_decay: 0 # Strength of the weight decay regularization. This is mult muon_consistent_rms: None # If None, apply width scaling to updates. If float, apply consistent rms scaling (recommend 0.2). # Stack trace parameters -collect_stack_trace: False -stack_trace_to_cloud: False # Uploads to cloud logging if True, else to the console if False. +collect_stack_trace: false +stack_trace_to_cloud: false # Uploads to cloud logging if true, else to the console if false. stack_trace_interval_seconds: 600 # Stack trace collection frequency in seconds. # Use iota operator in Embed -use_iota_embed: False +use_iota_embed: false # use positional embedding -use_untrainable_positional_embedding: False +use_untrainable_positional_embedding: false trainable_position_size: -1 # enable gpt3 position embedding with a positive trainable_position_size # RoPE parameters rope_type: "default" # one of "default", "llama3.1" or "yarn" rope_linear_scaling_factor: 1.0 # linear scaling factor for "default" RoPE (see class `RotaryEmbedding` for more) -rope_use_scale: True # apply rope scaling for llama3.1 (see class `LLaMARotaryEmbedding` for more) +rope_use_scale: true # apply rope scaling for llama3.1 (see class `LLaMARotaryEmbedding` for more) rope_min_timescale: 1 rope_max_timescale: 10_000 # Timescale For global Attention local_rope_max_timescale: -1 # If positive used for local window Attention, otherwise `rope_max_timescale` is used for both local and global @@ -916,9 +916,9 @@ rope_factor: 40 beta_fast: 32 beta_slow: 1 mscale: 1.0 -rope_interleave: True # RoPE with sin/cos interleaved vs concatenated -rope_truncate: True # Floor lower bound and ceil upper bound for correction range -rope_attention_scaling: False # Scale the rotary embedding output +rope_interleave: true # RoPE with sin/cos interleaved vs concatenated +rope_truncate: true # Floor lower bound and ceil upper bound for correction range +rope_attention_scaling: false # Scale the rotary embedding output # Ahead of time Compilation (aka AOT) # Only set these arguments if you are running train_compile or loading a compiled train step. @@ -934,29 +934,29 @@ decode_sampling_temperature: 1. eval_interval: -1 # the specific number of train step between eval_step eval_steps: -1 # run this number of steps for eval, recommend setting this to prevent error due to running out of evel data target_eval_loss: 0. # early stop once reaching target eval_loss -abort_on_nan_loss: True # Check for NaN and abort if found in training loss -abort_on_inf_loss: True # Check for Inf and abort if found in training loss +abort_on_nan_loss: true # Check for NaN and abort if found in training loss +abort_on_inf_loss: true # Check for Inf and abort if found in training loss # Goodput parameters -enable_goodput_recording: False -monitor_goodput: False +enable_goodput_recording: false +monitor_goodput: false goodput_upload_interval_seconds: 30 -enable_pathways_goodput: False -monitor_step_time_deviation: True +enable_pathways_goodput: false +monitor_step_time_deviation: true step_deviation_interval_seconds: 30 -enable_gcp_goodput_metrics: True -enable_gcp_step_deviation_metrics: True +enable_gcp_goodput_metrics: true +enable_gcp_step_deviation_metrics: true # GCP workload monitoring -report_heartbeat_metric_for_gcp_monitoring: False +report_heartbeat_metric_for_gcp_monitoring: false heartbeat_reporting_interval_in_seconds: 5 -report_performance_metric_for_gcp_monitoring: False +report_performance_metric_for_gcp_monitoring: false -enable_tensorboard: True +enable_tensorboard: true # Vertex AI Tensorboard Configurations - https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/use_vertex_ai_tensorboard.md -# Set to True for GCE, False if running via XPK -use_vertex_tensorboard: False +# Set to true for GCE, false if running via XPK +use_vertex_tensorboard: false # Project to create Vertex AI Tensorboard in for GCE, blank if project is set using 'gcloud config set project' # Set this to blank if running via XPK vertex_tensorboard_project: "" @@ -964,8 +964,8 @@ vertex_tensorboard_project: "" # Vertex AI supported regions: https://cloud.google.com/vertex-ai/docs/general/locations#available-regions vertex_tensorboard_region: "" -# If set to True, MaxText will perform extra checks using jax.checkify. Note that this will effect performance. -max_checkify: False +# If set to true, MaxText will perform extra checks using jax.checkify. Note that this will effect performance. +max_checkify: false # Inference inference_microbenchmark_prefill_lengths: "64,128,256,512,1024" @@ -977,15 +977,15 @@ inference_metadata_file: "" # path to a json file inference_server: "MaxtextInterleavedServer" # inference server to start prefill_slice: "v5e-16" # slice to use for prefill in disaggregation mode generate_slice: "v5e-16" # slice to use for generatation in disaggregation mode -inference_benchmark_test: False -enable_model_warmup: False -enable_llm_inference_pool: False # Bool to launch inference server for llm_inference_gateway with their specified APIs -multi_sampling: False -return_log_prob: False +inference_benchmark_test: false +enable_model_warmup: false +enable_llm_inference_pool: false # Bool to launch inference server for llm_inference_gateway with their specified APIs +multi_sampling: false +return_log_prob: false # Stack prefill cache across the layer to reduce the # Python layer latency. -stack_prefill_result_cache: False +stack_prefill_result_cache: false # KV Cache layout control # Logical layout: 0,1,2,3 ; CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV @@ -998,42 +998,42 @@ ar_cache_axis_order: "1,2,0,3" # Currently only support compute layout: 0,1,2,3 and 0,2,1,3 compute_axis_order: "0,1,2,3" -reshape_q: False +reshape_q: false # Maxengine Metrics prometheus_port: 0 # Maxengine server -enable_jax_profiler: False +enable_jax_profiler: false jax_profiler_port: 9999 # TPU power trace level for xprof. 0:POWER_TRACE_NONE, 1:POWER_TRACE_NORMAL, or 2:POWER_TRACE_SPI xprof_tpu_power_trace_level: 0 -xprof_e2e_enable_fw_throttle_event: False -xprof_e2e_enable_fw_power_level_event: False -xprof_e2e_enable_fw_thermal_event: False -profile_power_events: False # Set to True to enable TPU-specific power/thermal profiling events. Defaults to False to avoid breaking GPU xplane tracing. +xprof_e2e_enable_fw_throttle_event: false +xprof_e2e_enable_fw_power_level_event: false +xprof_e2e_enable_fw_thermal_event: false +profile_power_events: false # Set to true to enable TPU-specific power/thermal profiling events. Defaults to false to avoid breaking GPU xplane tracing. -log_config: True # Prints the config (after defaults have been set by pyconfig logic) -debug_sharding: False # Prints model weights sharding info +log_config: true # Prints the config (after defaults have been set by pyconfig logic) +debug_sharding: false # Prints model weights sharding info # Checkpoint Structured logging -enable_checkpoint_cloud_logger: False +enable_checkpoint_cloud_logger: false # Single-controller -enable_single_controller: False +enable_single_controller: false custom_mesh: "" # Available options: ['hybrid_ring_64x4', 'hybrid_ring_32x8'] # Split physical axes for https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.mesh_utils.create_device_mesh.html -allow_split_physical_axes: False +allow_split_physical_axes: false # Apply transformations to the mesh to optimize for TPU v6e -optimize_mesh_for_tpu_v6e: False +optimize_mesh_for_tpu_v6e: false -shardy: True # Whether to use shardy XLA backend (default in Jax starting 0.7.0), or GSPMD (to be fully deprecated ~2026) +shardy: true # Whether to use shardy XLA backend (default in Jax starting 0.7.0), or GSPMD (to be fully deprecated ~2026) -remove_size_one_mesh_axis_from_type: True # Whether to remove size one mesh axis from type through jax.config. +remove_size_one_mesh_axis_from_type: true # Whether to remove size one mesh axis from type through jax.config. -use_ragged_attention: False +use_ragged_attention: false ragged_block_size: 256 ### Splash attention block sizes @@ -1047,7 +1047,7 @@ sa_block_kv_dkv: 512 sa_block_kv_dkv_compute: 512 sa_block_q_dq: 512 sa_block_kv_dq: 512 -sa_use_fused_bwd_kernel: False +sa_use_fused_bwd_kernel: false sa_q_layout: "HEAD_DIM_MINOR" sa_k_layout: "HEAD_DIM_MINOR" sa_v_layout: "HEAD_DIM_MINOR" @@ -1055,9 +1055,9 @@ use_max_logit_estimate: -1 # -1 means no estimate, any > 0 value will be used as cost_estimate_flops_fwd: -1 # -1 means using splash default cost estmiation, any >= 0 value will be used as cost estmiation for splash to overlap for communication (forward) cost_estimate_flops_bwd: -1 # -1 means using splash default cost estmiation, any >= 0 value will be used as cost estmiation for splash to overlap for communication (backward) dq_reduction_steps: 0 #the number of reduction steps. For now, only 3 or all the kv steps are supported. -use_splash_scheduler: False # to use tokamax splash attention scheduler. +use_splash_scheduler: false # to use tokamax splash attention scheduler. ### Determine if we want to use load balance for context parallelism -context_parallel_load_balance: True +context_parallel_load_balance: true context_parallel_strategy: "all_gather" # "all_gather" or "ring" context_parallel_reorder_strategy: "auto" # "auto", "dual_chunk_swap", or "striped" @@ -1076,21 +1076,21 @@ pagedattn_head_dim_alignment: 128 # Chunked Prefill Parameters prefill_chunk_size: 256 -use_chunked_prefill: False +use_chunked_prefill: false # Prefix Caching parameters in jetstream -enable_prefix_caching: False +enable_prefix_caching: false prefix_caching_hbm_byte: 10_000_000_000 # 10 GB prefix_caching_dram_byte: 100_000_000_000 # 100 GB # This is a temporary flag that will be removed soon after the fix lands in TE -enable_padding_causal_mask: True +enable_padding_causal_mask: true # Llama4-specific # Whether to apply Query/Key normalization. # NOTE: non-Llama4 models use RMSNorm before RoPE # whereas Llama4 models use L2Norm after RoPE -use_qk_norm: False +use_qk_norm: false # Every `X` layers will NOT use RoPE nope_layer_interval: -1 # Every `X` layers is MoE layer @@ -1098,13 +1098,13 @@ interleave_moe_layer_step: 1 # dynamically scale the attention temperature for each query token based on sequence length # Recommended for long sequences (e.g., >32k tokens) to maintain stable output results # See (https://arxiv.org/abs/2501.19399) for more details -temperature_tuning: False +temperature_tuning: false # Multimodal flags -use_multimodal: False -use_audio: False -freeze_vision_encoder_params: True -freeze_audio_encoder_params: True +use_multimodal: false +use_audio: false +freeze_vision_encoder_params: true +freeze_audio_encoder_params: true dtype_mm: "float32" # Data type for multimodal model's vision encoder remat_policy_for_vit: "minimal" # Remat policy for multimodal model's vision encoder. Check `remat_policy` for options. image_size_for_vit: 896 # Default for Gemma3, and should be overwritten by model's config @@ -1114,7 +1114,7 @@ audio_path: "" # Local audio path used for decoding, can be multiple paths separ image_placeholder: "<|image|>" video_placeholder: "<|video|>" audio_placeholder: "<|audio|>" -use_audio_in_video: False +use_audio_in_video: false posemb_type_for_vit: "learn" # max_num_images_per_example only applies for training when your image column is a list of images. # -1 means no limit, and will pad to the max possible number of images determined by sequence length. @@ -1155,7 +1155,7 @@ activation_dropout_for_audio: 0.0 activation_function_for_audio: "gelu" num_mel_bins_for_audio: 128 max_source_positions_for_audio: 1500 -scale_embedding_for_audio: True +scale_embedding_for_audio: true n_window_for_audio: 50 n_window_infer_for_audio: 800 conv_chunksize_for_audio: 500 @@ -1174,9 +1174,9 @@ position_id_per_seconds: 25 subslice_shape: "" # NNX -enable_nnx: False -pure_nnx_decoder: False -pure_nnx: False +enable_nnx: false +pure_nnx_decoder: false +pure_nnx: false ################################## Qwen3-Next Specific Configs ################################## # Kernel size for the 1D convolution in the Gated Delta Net @@ -1192,7 +1192,7 @@ gdn_num_value_heads: 32 # Chunk size for the parallel scan algorithm in the Gated Delta Net. gdn_chunk_size: 64 # Whether to apply L2 normalization to query and key tensors inside the Gated Delta Rule kernel. -use_qk_norm_in_gdn: True +use_qk_norm_in_gdn: true # The ratio of dimension to apply ROPE on partial_rotary_factor: 1.0 @@ -1206,11 +1206,11 @@ use_jax_splash: false # Path to the HuggingFace-style config directory for the adapter (e.g. src/maxtext/integration/vllm/maxtext_vllm_adapter) vllm_hf_config_path: "" # A JSON string of overrides to apply to the HuggingFace-style config for the vLLM adapter. -# This can be used to override specific settings without modifying the original config file. +# This can be used to override specific settings without modifying the original config file. vllm_hf_overrides: {} # JSON string containing additional configuration for the vLLM model (e.g. '{"maxtext_config": {...}}') vllm_additional_config: {} -# When use_jax_splash=True, force the layout of the query tensor to be [..., NUM_HEADS, HEAD_DIM, SEQ_LENGTH] +# When use_jax_splash=true, force the layout of the query tensor to be [..., NUM_HEADS, HEAD_DIM, SEQ_LENGTH] force_q_layout: false ################################## DeepSeek Manifold-Constrained Hyper Connections (mHC) ################################## @@ -1221,7 +1221,7 @@ sinkhorn_iterations: 20 ################################## DeepSeek Engram ################################## # Indices of transformer layers where Engram are integrated; leave empty [] to disable. -# Example: [1, 4] attaches to the 2nd and 5th layer. +# Example: [1, 4] attaches to the 2nd and 5th layer. engram_layers: [] # The max 'n' in N-gram. Example: n=3 means it covers both 2-grams and 3-grams. engram_max_ngram_size: 3 diff --git a/src/maxtext/configs/custom_mesh_and_rule/ep-as-cp.yml b/src/maxtext/configs/custom_mesh_and_rule/ep-as-cp.yml index 75fa46a8be..a7f5e37281 100644 --- a/src/maxtext/configs/custom_mesh_and_rule/ep-as-cp.yml +++ b/src/maxtext/configs/custom_mesh_and_rule/ep-as-cp.yml @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# This rule uses data, stage, FSDP, and expert. Expert axis acts as context parallelism in +# This rule uses data, stage, FSDP, and expert. Expert axis acts as context parallelism in # components except core dMoE part (between EP all2all). mesh_axes: ['data', 'stage', 'fsdp', 'expert'] data_sharding: [['data', 'stage', 'fsdp', 'expert']] diff --git a/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml b/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml index 6853fd09e3..dc98923c9c 100644 --- a/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml +++ b/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml @@ -12,19 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -# This logical rule is designed to optimize pipeline parallelism for large-scale jobs. -# Key changes include removing expert weight sharding on the `q_lora` dimension, which -# is relatively small (e.g., 512 for DeepSeek), and limiting sharding strategies when -# EP x FSDP > 512. +# This logical rule is designed to optimize pipeline parallelism for large-scale jobs. +# Key changes include removing expert weight sharding on the `q_lora` dimension, which +# is relatively small (e.g., 512 for DeepSeek), and limiting sharding strategies when +# EP x FSDP > 512. +# +# The `data` axis is preserved for two reasons: first, the pipeline stage acts as a +# data parallel (DP) domain externally, making the `data` axis a necessary reference; +# second, it may be required for DCN communication. # -# The `data` axis is preserved for two reasons: first, the pipeline stage acts as a -# data parallel (DP) domain externally, making the `data` axis a necessary reference; -# second, it may be required for DCN communication. -# # The `context` axis is used for supporting fractional per device batch size # -# Finally, the `tensor` axis is used to shard weights when `pipeline_fsdp_ag_once` or -# `pipeline_fsdp_ag_per_repeat` is enabled, ensuring we have sufficient memory to +# Finally, the `tensor` axis is used to shard weights when `pipeline_fsdp_ag_once` or +# `pipeline_fsdp_ag_per_repeat` is enabled, ensuring we have sufficient memory to # store prefetched weights. mesh_axes: ['data', 'stage', 'fsdp', 'context', 'tensor', 'expert'] data_sharding: [['data', 'stage', 'fsdp', 'context', 'tensor', 'expert']] diff --git a/src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml b/src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml index 1d3a5e4cd0..fe05b6c46b 100644 --- a/src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml +++ b/src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# This rule only uses FSDP. Pure FSDP is the go-to sharding strategy +# This rule only uses FSDP. Pure FSDP is the go-to sharding strategy # for small-scale training and this rule simplifies the overall configuration. mesh_axes: ['fsdp'] data_sharding: [['fsdp']] diff --git a/src/maxtext/configs/decoupled_base_test.yml b/src/maxtext/configs/decoupled_base_test.yml index ade834239e..8494011ccc 100644 --- a/src/maxtext/configs/decoupled_base_test.yml +++ b/src/maxtext/configs/decoupled_base_test.yml @@ -1,5 +1,5 @@ # Decoupled base test config: used when DECOUPLE_GCLOUD=TRUE for tests that previously relied on base.yml. -# Inherit all model defaults (PyDantic already does this) but override any cloud-coupled paths and disable +# Inherit all model defaults (PyDantic already does this) but override any cloud-coupled paths and disable # optional cloud features. base_config: base.yml diff --git a/src/maxtext/configs/gpu/models/llama2_70b.yml b/src/maxtext/configs/gpu/models/llama2_70b.yml index 3f5a455c8a..f0d4fc6a8f 100644 --- a/src/maxtext/configs/gpu/models/llama2_70b.yml +++ b/src/maxtext/configs/gpu/models/llama2_70b.yml @@ -5,14 +5,14 @@ run_name: "gpu_train_test" hardware: "gpu" steps: 30 model_name: "llama2-70b" -enable_checkpointing: False +enable_checkpointing: false attention: "cudnn_flash_te" remat_policy: "full" -use_iota_embed: True -scan_layers: True +use_iota_embed: true +scan_layers: true dataset_type: "synthetic" -async_checkpointing: False -logits_dot_in_fp32: False +async_checkpointing: false +logits_dot_in_fp32: false per_device_batch_size: 6 max_target_length: 4096 diff --git a/src/maxtext/configs/gpu/models/llama2_7b.yml b/src/maxtext/configs/gpu/models/llama2_7b.yml index afa401e20c..49079b1eba 100644 --- a/src/maxtext/configs/gpu/models/llama2_7b.yml +++ b/src/maxtext/configs/gpu/models/llama2_7b.yml @@ -6,11 +6,11 @@ steps: 30 per_device_batch_size: 4 max_target_length: 4096 model_name: "llama2-7b" -enable_checkpointing: False +enable_checkpointing: false attention: "cudnn_flash_te" remat_policy: "minimal_with_context" -use_iota_embed: True -scan_layers: False +use_iota_embed: true +scan_layers: false dataset_type: "synthetic" -async_checkpointing: False +async_checkpointing: false max_segments_per_seq: 32 diff --git a/src/maxtext/configs/gpu/models/llama3.1_405b.yml b/src/maxtext/configs/gpu/models/llama3.1_405b.yml index e2f853fdc9..3acd8a0398 100644 --- a/src/maxtext/configs/gpu/models/llama3.1_405b.yml +++ b/src/maxtext/configs/gpu/models/llama3.1_405b.yml @@ -4,14 +4,14 @@ run_name: "gpu_train_test" hardware: "gpu" steps: 10 model_name: "llama3.1-405b" -enable_checkpointing: False +enable_checkpointing: false #attention: "cudnn_flash_te" remat_policy: "full" -use_iota_embed: True -scan_layers: True +use_iota_embed: true +scan_layers: true dataset_type: "synthetic" -async_checkpointing: False -logits_dot_in_fp32: False +async_checkpointing: false +logits_dot_in_fp32: false per_device_batch_size: 1.0 max_target_length: 4096 max_segments_per_seq: 32 diff --git a/src/maxtext/configs/gpu/models/llama3_70b.yml b/src/maxtext/configs/gpu/models/llama3_70b.yml index 7deb603bea..bcbd275b7d 100644 --- a/src/maxtext/configs/gpu/models/llama3_70b.yml +++ b/src/maxtext/configs/gpu/models/llama3_70b.yml @@ -24,8 +24,8 @@ per_device_batch_size: 4 max_target_length: 8192 attention: "cudnn_flash_te" remat_policy: "full" -use_iota_embed: True +use_iota_embed: true dataset_type: "synthetic" reuse_example_batch: 1 -enable_checkpointing: False +enable_checkpointing: false max_segments_per_seq: 32 diff --git a/src/maxtext/configs/gpu/models/llama3_8b.yml b/src/maxtext/configs/gpu/models/llama3_8b.yml index bfd6ffb71e..ebab7cda57 100644 --- a/src/maxtext/configs/gpu/models/llama3_8b.yml +++ b/src/maxtext/configs/gpu/models/llama3_8b.yml @@ -24,8 +24,8 @@ per_device_batch_size: 12 max_target_length: 8192 attention: "cudnn_flash_te" remat_policy: "minimal_with_context" -use_iota_embed: True +use_iota_embed: true dataset_type: "synthetic" reuse_example_batch: 1 -enable_checkpointing: False +enable_checkpointing: false max_segments_per_seq: 32 diff --git a/src/maxtext/configs/gpu/models/mixtral_8x1b.yml b/src/maxtext/configs/gpu/models/mixtral_8x1b.yml index 73c789ba09..4c1f6eae30 100644 --- a/src/maxtext/configs/gpu/models/mixtral_8x1b.yml +++ b/src/maxtext/configs/gpu/models/mixtral_8x1b.yml @@ -27,10 +27,10 @@ per_device_batch_size: 8 max_target_length: 4096 attention: "cudnn_flash_te" remat_policy: "full" -use_iota_embed: True +use_iota_embed: true dataset_type: "synthetic" reuse_example_batch: 1 -enable_checkpointing: False -megablox: False -sparse_matmul: False +enable_checkpointing: false +megablox: false +sparse_matmul: false max_segments_per_seq: 32 diff --git a/src/maxtext/configs/gpu/models/mixtral_8x2b.yml b/src/maxtext/configs/gpu/models/mixtral_8x2b.yml index 22481969c9..206cde12bb 100644 --- a/src/maxtext/configs/gpu/models/mixtral_8x2b.yml +++ b/src/maxtext/configs/gpu/models/mixtral_8x2b.yml @@ -27,10 +27,10 @@ per_device_batch_size: 8 max_target_length: 4096 attention: "cudnn_flash_te" remat_policy: "full" -use_iota_embed: True +use_iota_embed: true dataset_type: "synthetic" reuse_example_batch: 1 -enable_checkpointing: False -megablox: False -sparse_matmul: False +enable_checkpointing: false +megablox: false +sparse_matmul: false max_segments_per_seq: 32 diff --git a/src/maxtext/configs/gpu/models/mixtral_8x7b.yml b/src/maxtext/configs/gpu/models/mixtral_8x7b.yml index 5fa58f066f..f0fbc6968f 100644 --- a/src/maxtext/configs/gpu/models/mixtral_8x7b.yml +++ b/src/maxtext/configs/gpu/models/mixtral_8x7b.yml @@ -24,12 +24,12 @@ per_device_batch_size: 12 max_target_length: 4096 attention: "cudnn_flash_te" remat_policy: "minimal_with_context" -use_iota_embed: True +use_iota_embed: true dataset_type: "synthetic" reuse_example_batch: 1 -enable_checkpointing: False -megablox: False -scan_layers: False +enable_checkpointing: false +megablox: false +scan_layers: false tokenizer_path: "/deps/src/maxtext/assets/tokenizers/tokenizer.mistral-v1" profiler: "nsys" capacity_factor: 1.0 diff --git a/src/maxtext/configs/inference/inference.yml b/src/maxtext/configs/inference/inference.yml index 7a263cc282..a1b9835e28 100644 --- a/src/maxtext/configs/inference/inference.yml +++ b/src/maxtext/configs/inference/inference.yml @@ -62,4 +62,4 @@ logical_axis_rules: [ ['paged_kv_head_dim_size', []], ] # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details -data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']] \ No newline at end of file +data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']] diff --git a/src/maxtext/configs/inference/inference_jetstream.yml b/src/maxtext/configs/inference/inference_jetstream.yml index 1cd3683ea6..1b92796c6c 100644 --- a/src/maxtext/configs/inference/inference_jetstream.yml +++ b/src/maxtext/configs/inference/inference_jetstream.yml @@ -1,6 +1,6 @@ base_config: "base.yml" -enable_jax_profiler: False +enable_jax_profiler: false jax_profiler_port: 9999 -enable_model_warmup: False \ No newline at end of file +enable_model_warmup: false diff --git a/src/maxtext/configs/inference/multihost/disaggregation/llama3_405b_v6e-16-16.yml b/src/maxtext/configs/inference/multihost/disaggregation/llama3_405b_v6e-16-16.yml index aedbc64853..2b83fadc56 100644 --- a/src/maxtext/configs/inference/multihost/disaggregation/llama3_405b_v6e-16-16.yml +++ b/src/maxtext/configs/inference/multihost/disaggregation/llama3_405b_v6e-16-16.yml @@ -3,10 +3,10 @@ base_config: "inference/inference_jetstream.yml" model_name: "llama3.1-405b" sharding_strategy: "experimental" attention: 'dot_product' -allow_split_physical_axes: True +allow_split_physical_axes: true tokenizer_path: "assets/tokenizer_llama3.tiktoken" # Used to replicate the quantization scale to avoid the inefficient XLA fusion. -replicate_quant_scale: True +replicate_quant_scale: true inference_server: "ExperimentalMaxtextDisaggregatedServer" diff --git a/src/maxtext/configs/inference/multihost/interleaved/llama2_70b_v5e-16.yml b/src/maxtext/configs/inference/multihost/interleaved/llama2_70b_v5e-16.yml index 121efe248e..dba4bc03ce 100644 --- a/src/maxtext/configs/inference/multihost/interleaved/llama2_70b_v5e-16.yml +++ b/src/maxtext/configs/inference/multihost/interleaved/llama2_70b_v5e-16.yml @@ -7,9 +7,9 @@ base_config: "inference/inference_jetstream.yml" model_name: "llama2-70b" sharding_strategy: "experimental" attention: 'dot_product' -allow_split_physical_axes: True +allow_split_physical_axes: true # Used to replicate the quantization scale to avoid the inefficient XLA fusion. -replicate_quant_scale: True +replicate_quant_scale: true logical_axis_rules: [ ['embed', []], diff --git a/src/maxtext/configs/inference/multihost/interleaved/llama3_405b_v5e-64.yml b/src/maxtext/configs/inference/multihost/interleaved/llama3_405b_v5e-64.yml index b91bb85fb3..b71b7990f1 100644 --- a/src/maxtext/configs/inference/multihost/interleaved/llama3_405b_v5e-64.yml +++ b/src/maxtext/configs/inference/multihost/interleaved/llama3_405b_v5e-64.yml @@ -8,10 +8,10 @@ base_config: "inference/inference_jetstream.yml" model_name: "llama3.1-405b" sharding_strategy: "experimental" attention: 'dot_product' -allow_split_physical_axes: True +allow_split_physical_axes: true tokenizer_path: "assets/tokenizer_llama3.tiktoken" # Used to replicate the quantization scale to avoid the inefficient XLA fusion. -replicate_quant_scale: True +replicate_quant_scale: true logical_axis_rules: [ ['embed', []], diff --git a/src/maxtext/configs/inference/multihost/interleaved/llama3_70b_v5e-16.yml b/src/maxtext/configs/inference/multihost/interleaved/llama3_70b_v5e-16.yml index b3ca2d1465..525d30e30c 100644 --- a/src/maxtext/configs/inference/multihost/interleaved/llama3_70b_v5e-16.yml +++ b/src/maxtext/configs/inference/multihost/interleaved/llama3_70b_v5e-16.yml @@ -8,9 +8,9 @@ model_name: "llama3-70b" tokenizer_path: "assets/tokenizer_llama3.tiktoken" sharding_strategy: "experimental" attention: 'dot_product' -allow_split_physical_axes: True +allow_split_physical_axes: true # Used to replicate the quantization scale to avoid the inefficient XLA fusion. -replicate_quant_scale: True +replicate_quant_scale: true logical_axis_rules: [ ['embed', []], diff --git a/src/maxtext/configs/inference/vllm.yml b/src/maxtext/configs/inference/vllm.yml index c21df6c70c..3f9e4f5290 100644 --- a/src/maxtext/configs/inference/vllm.yml +++ b/src/maxtext/configs/inference/vllm.yml @@ -17,15 +17,15 @@ attention: "vllm_rpa" model_call_mode: "inference" # NNX required for vLLM integration -enable_nnx: True +enable_nnx: true # Avoid re-initializing JAX distributed system when using vLLM -skip_jax_distributed_system: True +skip_jax_distributed_system: true # Scanned layers are not supported with vLLM integration -scan_layers: False +scan_layers: false # Set weight dtype to bfloat16 as is done in vLLM weight_dtype: bfloat16 # Allow model config to be overridden by CLI kwargs -override_model_config: True +override_model_config: true # -------------- Logical Axis Rules -------------- diff --git a/src/maxtext/configs/models/deepseek-custom.yml b/src/maxtext/configs/models/deepseek-custom.yml index 5e10d50ba2..bab9b20104 100644 --- a/src/maxtext/configs/models/deepseek-custom.yml +++ b/src/maxtext/configs/models/deepseek-custom.yml @@ -16,7 +16,7 @@ # Included modules: DeepSeek Sparse Attention, Engram, mHC # Example command: -# python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=demo model_name=deepseek-custom scan_layers=True attention=flash use_tokamax_splash=True enable_checkpointing=false async_checkpointing=false dataset_type=synthetic steps=5 per_device_batch_size=4 max_target_length=1024 dtype=bfloat16 weight_dtype=bfloat16 tokenizer_type=huggingface tokenizer_path=deepseek-ai/DeepSeek-V3.2 hf_access_token=${HF_TOKEN} +# python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=demo model_name=deepseek-custom scan_layers=true attention=flash use_tokamax_splash=true enable_checkpointing=false async_checkpointing=false dataset_type=synthetic steps=5 per_device_batch_size=4 max_target_length=1024 dtype=bfloat16 weight_dtype=bfloat16 tokenizer_type=huggingface tokenizer_path=deepseek-ai/DeepSeek-V3.2 hf_access_token=${HF_TOKEN} base_emb_dim: 1024 # Reduced from 7168 base_num_query_heads: 16 # Reduced from 128 @@ -27,15 +27,15 @@ base_num_decoder_layers: 6 # Reduced from 61 first_num_dense_layers: 1 # Reduced from 3 mlp_activations: ["silu","linear"] vocab_size: 129280 -enable_dropout: False -logits_via_embedding: False +enable_dropout: false +logits_via_embedding: false normalization_layer_epsilon: 1.0e-6 num_experts: 16 # Reduced from 256 num_experts_per_tok: 2 # Reduced from 8 shared_experts: 1 routed_scaling_factor: 2.5 routed_score_func: "sigmoid" -routed_bias: True +routed_bias: true decoder_block: "deepseek" # MLA attention_type: "mla" @@ -52,11 +52,11 @@ max_position_embeddings: 4096 # Reduced for local testing original_max_position_embeddings: 4096 rope_factor: 1 beta_fast: 32 -rope_interleave: True -rope_truncate: True -rope_attention_scaling: False +rope_interleave: true +rope_truncate: true +rope_attention_scaling: false # Indexer for DeepSeek Sparse Attention -use_indexer: True +use_indexer: true indexer_n_heads: 16 # Reduced from 64 indexer_head_dim: 64 # Reduced from 128 indexer_topk: 256 # Reduced from 2048 diff --git a/src/maxtext/configs/models/deepseek2-16b.yml b/src/maxtext/configs/models/deepseek2-16b.yml index c4e0247cfe..38eb012fb1 100644 --- a/src/maxtext/configs/models/deepseek2-16b.yml +++ b/src/maxtext/configs/models/deepseek2-16b.yml @@ -23,15 +23,15 @@ base_num_decoder_layers: 27 first_num_dense_layers: 1 mlp_activations: ["silu","linear"] vocab_size: 102400 -enable_dropout: False -logits_via_embedding: False +enable_dropout: false +logits_via_embedding: false normalization_layer_epsilon: 1.0e-6 num_experts: 64 num_experts_per_tok: 6 shared_experts: 2 routed_scaling_factor: 1.0 routed_score_func: "softmax" -routed_bias: False +routed_bias: false decoder_block: "deepseek" # MLA attention_type: "mla" @@ -48,6 +48,6 @@ original_max_position_embeddings: 4096 rope_factor: 40 beta_fast: 32 mscale: 0.707 -rope_interleave: True -rope_truncate: True -rope_attention_scaling: False +rope_interleave: true +rope_truncate: true +rope_attention_scaling: false diff --git a/src/maxtext/configs/models/deepseek2-236b.yml b/src/maxtext/configs/models/deepseek2-236b.yml index c3c9880b86..6d839f7477 100644 --- a/src/maxtext/configs/models/deepseek2-236b.yml +++ b/src/maxtext/configs/models/deepseek2-236b.yml @@ -24,15 +24,15 @@ base_num_decoder_layers: 60 first_num_dense_layers: 1 mlp_activations: ["silu","linear"] vocab_size: 102400 -enable_dropout: False -logits_via_embedding: False +enable_dropout: false +logits_via_embedding: false normalization_layer_epsilon: 1.0e-6 num_experts: 160 num_experts_per_tok: 6 shared_experts: 2 routed_scaling_factor: 16.0 routed_score_func: "softmax" -routed_bias: False +routed_bias: false decoder_block: "deepseek" # MLA attention_type: "mla" @@ -41,7 +41,7 @@ kv_lora_rank: 512 qk_nope_head_dim: 128 qk_rope_head_dim: 64 v_head_dim: 128 -# RoPE +# RoPE rope_type: "yarn" rope_max_timescale: 10_000 # DeepSeek uses "rope_theta": 10000 max_position_embeddings: 163840 @@ -49,6 +49,6 @@ original_max_position_embeddings: 4096 rope_factor: 40 beta_fast: 32 mscale: 0.707 -rope_interleave: True -rope_truncate: True -rope_attention_scaling: False +rope_interleave: true +rope_truncate: true +rope_attention_scaling: false diff --git a/src/maxtext/configs/models/deepseek3-671b-batchsplit.yml b/src/maxtext/configs/models/deepseek3-671b-batchsplit.yml index fda64a877a..a24c5af9df 100644 --- a/src/maxtext/configs/models/deepseek3-671b-batchsplit.yml +++ b/src/maxtext/configs/models/deepseek3-671b-batchsplit.yml @@ -14,7 +14,7 @@ # model config for DeepSeek V3 - 671B that uses batch split schedule -# For DeepSeek default device-limited routing, +# For DeepSeek default device-limited routing, # please set n_routing_groups=8 and topk_routing_group=4 in your command-line arguments. base_emb_dim: 7168 @@ -26,8 +26,8 @@ base_num_decoder_layers: 61 first_num_dense_layers: 3 mlp_activations: ["silu","linear"] vocab_size: 129280 -enable_dropout: False -logits_via_embedding: False +enable_dropout: false +logits_via_embedding: false normalization_layer_epsilon: 1.0e-6 num_experts: 256 num_experts_per_tok: 8 @@ -36,7 +36,7 @@ topk_routing_group: 4 shared_experts: 1 routed_scaling_factor: 2.5 routed_score_func: "sigmoid" -routed_bias: True +routed_bias: true decoder_block: "deepseek" # MLA attention_type: "mla" @@ -53,14 +53,14 @@ max_position_embeddings: 163840 original_max_position_embeddings: 4096 rope_factor: 40 beta_fast: 32 -rope_interleave: True -rope_truncate: True -rope_attention_scaling: False +rope_interleave: true +rope_truncate: true +rope_attention_scaling: false -use_batch_split_schedule: True +use_batch_split_schedule: true shard_mode: "explicit" -remove_size_one_mesh_axis_from_type: False -override_logical_axis_rules: True +remove_size_one_mesh_axis_from_type: false +override_logical_axis_rules: true mesh_axes: ['data', 'fsdp', 'expert', 'context'] data_sharding: [['data', 'fsdp', 'expert', 'context']] logical_axis_rules: [ diff --git a/src/maxtext/configs/models/deepseek3-671b.yml b/src/maxtext/configs/models/deepseek3-671b.yml index 18e566cf57..347156fcdb 100644 --- a/src/maxtext/configs/models/deepseek3-671b.yml +++ b/src/maxtext/configs/models/deepseek3-671b.yml @@ -14,7 +14,7 @@ # model config for DeepSeek V3 - 671B -# For DeepSeek default device-limited routing, +# For DeepSeek default device-limited routing, # please set n_routing_groups=8 and topk_routing_group=4 in your command-line arguments. base_emb_dim: 7168 @@ -26,15 +26,15 @@ base_num_decoder_layers: 61 first_num_dense_layers: 3 mlp_activations: ["silu","linear"] vocab_size: 129280 -enable_dropout: False -logits_via_embedding: False +enable_dropout: false +logits_via_embedding: false normalization_layer_epsilon: 1.0e-6 num_experts: 256 num_experts_per_tok: 8 shared_experts: 1 routed_scaling_factor: 2.5 routed_score_func: "sigmoid" -routed_bias: True +routed_bias: true decoder_block: "deepseek" # MLA attention_type: "mla" @@ -51,6 +51,6 @@ max_position_embeddings: 163840 original_max_position_embeddings: 4096 rope_factor: 40 beta_fast: 32 -rope_interleave: True -rope_truncate: True -rope_attention_scaling: False +rope_interleave: true +rope_truncate: true +rope_attention_scaling: false diff --git a/src/maxtext/configs/models/deepseek3-test.yml b/src/maxtext/configs/models/deepseek3-test.yml index 65b0ffc113..aa6694082b 100644 --- a/src/maxtext/configs/models/deepseek3-test.yml +++ b/src/maxtext/configs/models/deepseek3-test.yml @@ -25,15 +25,15 @@ base_num_decoder_layers: 61 first_num_dense_layers: 3 mlp_activations: ["silu","linear"] vocab_size: 129280 -enable_dropout: False -logits_via_embedding: False +enable_dropout: false +logits_via_embedding: false normalization_layer_epsilon: 1.0e-6 num_experts: 256 num_experts_per_tok: 8 shared_experts: 1 routed_scaling_factor: 2.5 routed_score_func: "sigmoid" -routed_bias: True +routed_bias: true decoder_block: "deepseek" # MLA attention_type: "mla" diff --git a/src/maxtext/configs/models/deepseek3-tiny.yml b/src/maxtext/configs/models/deepseek3-tiny.yml index 4448df0693..6b3d3a5711 100644 --- a/src/maxtext/configs/models/deepseek3-tiny.yml +++ b/src/maxtext/configs/models/deepseek3-tiny.yml @@ -23,15 +23,15 @@ base_num_decoder_layers: 61 first_num_dense_layers: 3 mlp_activations: ["silu","linear"] vocab_size: 129280 -enable_dropout: False -logits_via_embedding: False +enable_dropout: false +logits_via_embedding: false normalization_layer_epsilon: 1.0e-6 num_experts: 16 num_experts_per_tok: 8 shared_experts: 1 routed_scaling_factor: 2.5 routed_score_func: "sigmoid" -routed_bias: True +routed_bias: true decoder_block: "deepseek" # MLA attention_type: "mla" diff --git a/src/maxtext/configs/models/deepseek3.2-671b.yml b/src/maxtext/configs/models/deepseek3.2-671b.yml index 69aacc9517..2addc49357 100644 --- a/src/maxtext/configs/models/deepseek3.2-671b.yml +++ b/src/maxtext/configs/models/deepseek3.2-671b.yml @@ -24,15 +24,15 @@ base_num_decoder_layers: 61 first_num_dense_layers: 3 mlp_activations: ["silu","linear"] vocab_size: 129280 -enable_dropout: False -logits_via_embedding: False +enable_dropout: false +logits_via_embedding: false normalization_layer_epsilon: 1.0e-6 num_experts: 256 num_experts_per_tok: 8 shared_experts: 1 routed_scaling_factor: 2.5 routed_score_func: "sigmoid" -routed_bias: True +routed_bias: true decoder_block: "deepseek" # MLA attention_type: "mla" @@ -49,11 +49,11 @@ max_position_embeddings: 163840 original_max_position_embeddings: 4096 rope_factor: 40 beta_fast: 32 -rope_interleave: True -rope_truncate: True -rope_attention_scaling: False +rope_interleave: true +rope_truncate: true +rope_attention_scaling: false # Indexer for DeepSeek Sparse Attention -use_indexer: True +use_indexer: true indexer_n_heads: 64 indexer_head_dim: 128 indexer_topk: 2048 diff --git a/src/maxtext/configs/models/gemma-2b.yml b/src/maxtext/configs/models/gemma-2b.yml index 4aa8c155e4..c268866388 100644 --- a/src/maxtext/configs/models/gemma-2b.yml +++ b/src/maxtext/configs/models/gemma-2b.yml @@ -24,4 +24,4 @@ mlp_activations: ["gelu","linear"] vocab_size: 256128 decoder_block: "gemma" normalization_layer_epsilon: 1.e-06 -logits_via_embedding: True \ No newline at end of file +logits_via_embedding: true diff --git a/src/maxtext/configs/models/gemma-7b.yml b/src/maxtext/configs/models/gemma-7b.yml index 21df36235f..5852c44bb3 100644 --- a/src/maxtext/configs/models/gemma-7b.yml +++ b/src/maxtext/configs/models/gemma-7b.yml @@ -24,4 +24,4 @@ mlp_activations: ["gelu","linear"] vocab_size: 256128 decoder_block: "gemma" normalization_layer_epsilon: 1.e-06 -logits_via_embedding: True \ No newline at end of file +logits_via_embedding: true diff --git a/src/maxtext/configs/models/gemma2-27b.yml b/src/maxtext/configs/models/gemma2-27b.yml index 76ca8883a5..3fec354e23 100644 --- a/src/maxtext/configs/models/gemma2-27b.yml +++ b/src/maxtext/configs/models/gemma2-27b.yml @@ -18,15 +18,15 @@ base_emb_dim: 4608 base_num_query_heads: 32 base_num_kv_heads: 16 base_mlp_dim: 36864 -base_num_decoder_layers: 23 # half of the real number of layers because we merge [local_attention, global_attention] into one layer +base_num_decoder_layers: 23 # half of the real number of layers because we merge [local_attention, global_attention] into one layer head_dim: 128 mlp_activations: ["gelu","linear"] vocab_size: 256128 decoder_block: "gemma2" normalization_layer_epsilon: 1.e-06 -logits_via_embedding: True +logits_via_embedding: true final_logits_soft_cap: 30.0 attn_logits_soft_cap: 50.0 sliding_window_size: 4096 -use_post_attn_norm: True -use_post_ffw_norm: True +use_post_attn_norm: true +use_post_ffw_norm: true diff --git a/src/maxtext/configs/models/gemma2-2b.yml b/src/maxtext/configs/models/gemma2-2b.yml index cb8ce5a865..23ce70a708 100644 --- a/src/maxtext/configs/models/gemma2-2b.yml +++ b/src/maxtext/configs/models/gemma2-2b.yml @@ -18,15 +18,15 @@ base_emb_dim: 2304 base_num_query_heads: 8 base_num_kv_heads: 4 base_mlp_dim: 9216 -base_num_decoder_layers: 13 # half of the real number of layers because we merge [local_attention, global_attention] into one layer +base_num_decoder_layers: 13 # half of the real number of layers because we merge [local_attention, global_attention] into one layer head_dim: 256 mlp_activations: ["gelu","linear"] vocab_size: 256128 decoder_block: "gemma2" normalization_layer_epsilon: 1.e-06 -logits_via_embedding: True +logits_via_embedding: true final_logits_soft_cap: 30.0 attn_logits_soft_cap: 50.0 sliding_window_size: 4096 -use_post_attn_norm: True -use_post_ffw_norm: True +use_post_attn_norm: true +use_post_ffw_norm: true diff --git a/src/maxtext/configs/models/gemma2-9b.yml b/src/maxtext/configs/models/gemma2-9b.yml index da75eaca22..6cf082a6e9 100644 --- a/src/maxtext/configs/models/gemma2-9b.yml +++ b/src/maxtext/configs/models/gemma2-9b.yml @@ -18,15 +18,15 @@ base_emb_dim: 3584 base_num_query_heads: 16 base_num_kv_heads: 8 base_mlp_dim: 14336 -base_num_decoder_layers: 21 # half of the real number of layers because we merge [local_attention, global_attention] into one layer +base_num_decoder_layers: 21 # half of the real number of layers because we merge [local_attention, global_attention] into one layer head_dim: 256 mlp_activations: ["gelu","linear"] vocab_size: 256128 decoder_block: "gemma2" normalization_layer_epsilon: 1.e-06 -logits_via_embedding: True +logits_via_embedding: true final_logits_soft_cap: 30.0 attn_logits_soft_cap: 50.0 sliding_window_size: 4096 -use_post_attn_norm: True -use_post_ffw_norm: True +use_post_attn_norm: true +use_post_ffw_norm: true diff --git a/src/maxtext/configs/models/gemma3-12b.yml b/src/maxtext/configs/models/gemma3-12b.yml index c2b9aae6a5..a15e49fe86 100644 --- a/src/maxtext/configs/models/gemma3-12b.yml +++ b/src/maxtext/configs/models/gemma3-12b.yml @@ -24,7 +24,7 @@ mlp_activations: ["gelu","linear"] vocab_size: 262_144 decoder_block: "gemma3" normalization_layer_epsilon: 1e-6 -logits_via_embedding: True +logits_via_embedding: true sliding_window_size: 1024 use_post_attn_norm: true use_post_ffw_norm: true diff --git a/src/maxtext/configs/models/gemma3-27b.yml b/src/maxtext/configs/models/gemma3-27b.yml index dbff83e8f4..5d3b70a3a9 100644 --- a/src/maxtext/configs/models/gemma3-27b.yml +++ b/src/maxtext/configs/models/gemma3-27b.yml @@ -24,7 +24,7 @@ mlp_activations: ["gelu","linear"] vocab_size: 262_144 decoder_block: "gemma3" normalization_layer_epsilon: 1e-6 -logits_via_embedding: True +logits_via_embedding: true sliding_window_size: 1024 use_post_attn_norm: true use_post_ffw_norm: true diff --git a/src/maxtext/configs/models/gemma3-4b.yml b/src/maxtext/configs/models/gemma3-4b.yml index f22dadef7e..ef01a3e762 100644 --- a/src/maxtext/configs/models/gemma3-4b.yml +++ b/src/maxtext/configs/models/gemma3-4b.yml @@ -24,7 +24,7 @@ mlp_activations: ["gelu","linear"] vocab_size: 262_144 decoder_block: "gemma3" normalization_layer_epsilon: 1e-6 -logits_via_embedding: True +logits_via_embedding: true sliding_window_size: 1024 use_post_attn_norm: true use_post_ffw_norm: true diff --git a/src/maxtext/configs/models/gemma4-26b.yml b/src/maxtext/configs/models/gemma4-26b.yml index d33c0cd90f..d32149bccb 100644 --- a/src/maxtext/configs/models/gemma4-26b.yml +++ b/src/maxtext/configs/models/gemma4-26b.yml @@ -27,7 +27,7 @@ global_num_kv_heads: 2 vocab_size: 262144 decoder_block: "gemma4" normalization_layer_epsilon: 1e-6 -logits_via_embedding: True +logits_via_embedding: true sliding_window_size: 1024 use_post_attn_norm: true use_post_ffw_norm: true diff --git a/src/maxtext/configs/models/gemma4-31b.yml b/src/maxtext/configs/models/gemma4-31b.yml index e055be25c5..9dec302bc1 100644 --- a/src/maxtext/configs/models/gemma4-31b.yml +++ b/src/maxtext/configs/models/gemma4-31b.yml @@ -26,7 +26,7 @@ global_num_kv_heads: 4 vocab_size: 262144 decoder_block: "gemma4" normalization_layer_epsilon: 1e-6 -logits_via_embedding: True +logits_via_embedding: true use_post_attn_norm: true use_post_ffw_norm: true diff --git a/src/maxtext/configs/models/gpt-oss-120b.yml b/src/maxtext/configs/models/gpt-oss-120b.yml index f7db3f681e..e6a624c7ff 100644 --- a/src/maxtext/configs/models/gpt-oss-120b.yml +++ b/src/maxtext/configs/models/gpt-oss-120b.yml @@ -23,8 +23,8 @@ base_num_query_heads: 64 base_num_kv_heads: 8 head_dim: 64 sliding_window_size: 128 -attention_bias: True -attention_sink: True +attention_bias: true +attention_sink: true # RoPE rope_type: "yarn" @@ -34,17 +34,17 @@ original_max_position_embeddings: 4096 rope_factor: 32 beta_fast: 32 beta_slow: 1 -rope_interleave: False -rope_truncate: False -rope_attention_scaling: True +rope_interleave: false +rope_truncate: false +rope_attention_scaling: true # MLP base_mlp_dim: 2880 base_moe_mlp_dim: 2880 mlp_activations: ["sigmoid","linear"] mlp_activations_limit: 7.0 -routed_bias: True -mlp_bias: True +routed_bias: true +mlp_bias: true num_experts: 128 num_experts_per_tok: 4 @@ -52,7 +52,7 @@ num_experts_per_tok: 4 base_num_decoder_layers: 36 vocab_size: 201088 normalization_layer_epsilon: 1.0e-5 -enable_dropout: False -logits_via_embedding: False +enable_dropout: false +logits_via_embedding: false decoder_block: "gpt_oss" inhomogeneous_layer_cycle_interval: 2 diff --git a/src/maxtext/configs/models/gpt-oss-20b.yml b/src/maxtext/configs/models/gpt-oss-20b.yml index 65ec386429..313c7cfe67 100644 --- a/src/maxtext/configs/models/gpt-oss-20b.yml +++ b/src/maxtext/configs/models/gpt-oss-20b.yml @@ -23,8 +23,8 @@ base_num_query_heads: 64 base_num_kv_heads: 8 head_dim: 64 sliding_window_size: 128 -attention_bias: True -attention_sink: True +attention_bias: true +attention_sink: true # RoPE rope_type: "yarn" @@ -34,17 +34,17 @@ original_max_position_embeddings: 4096 rope_factor: 32 beta_fast: 32 beta_slow: 1 -rope_interleave: False -rope_truncate: False -rope_attention_scaling: True +rope_interleave: false +rope_truncate: false +rope_attention_scaling: true # MLP base_mlp_dim: 2880 base_moe_mlp_dim: 2880 mlp_activations: ["sigmoid","linear"] mlp_activations_limit: 7.0 -routed_bias: True -mlp_bias: True +routed_bias: true +mlp_bias: true num_experts: 32 num_experts_per_tok: 4 @@ -52,7 +52,7 @@ num_experts_per_tok: 4 base_num_decoder_layers: 24 vocab_size: 201088 normalization_layer_epsilon: 1.0e-5 -enable_dropout: False -logits_via_embedding: False +enable_dropout: false +logits_via_embedding: false decoder_block: "gpt_oss" inhomogeneous_layer_cycle_interval: 2 diff --git a/src/maxtext/configs/models/gpt3-175b.yml b/src/maxtext/configs/models/gpt3-175b.yml index 5e24cf4268..a7f1e439b2 100644 --- a/src/maxtext/configs/models/gpt3-175b.yml +++ b/src/maxtext/configs/models/gpt3-175b.yml @@ -23,13 +23,13 @@ head_dim: 128 trainable_position_size: 16384 mlp_activations: ["gelu"] vocab_size: 50304 -enable_dropout: False -logits_via_embedding: True -normalize_embedding_logits: False -logits_dot_in_fp32: False +enable_dropout: false +logits_via_embedding: true +normalize_embedding_logits: false +logits_dot_in_fp32: false normalization_layer_epsilon: 1.e-05 -use_iota_embed: True -fused_qkv: True +use_iota_embed: true +fused_qkv: true opt_type: "adam_pax" decoder_block: "gpt3" dataset_path: "gs://mlperf-llm-public2" diff --git a/src/maxtext/configs/models/gpt3-22b.yml b/src/maxtext/configs/models/gpt3-22b.yml index 0e0905442c..5591ddbada 100644 --- a/src/maxtext/configs/models/gpt3-22b.yml +++ b/src/maxtext/configs/models/gpt3-22b.yml @@ -24,13 +24,13 @@ max_target_length: 1024 trainable_position_size: 16384 mlp_activations: ["gelu"] vocab_size: 32768 -enable_dropout: False -logits_via_embedding: True -normalize_embedding_logits: False -logits_dot_in_fp32: False +enable_dropout: false +logits_via_embedding: true +normalize_embedding_logits: false +logits_dot_in_fp32: false normalization_layer_epsilon: 1.e-05 -use_iota_embed: True -fused_qkv: True +use_iota_embed: true +fused_qkv: true opt_type: "adam_pax" decoder_block: "gpt3" gradient_clipping_threshold: 1. diff --git a/src/maxtext/configs/models/gpt3-52k.yml b/src/maxtext/configs/models/gpt3-52k.yml index 5513663f82..39ba72ecc0 100644 --- a/src/maxtext/configs/models/gpt3-52k.yml +++ b/src/maxtext/configs/models/gpt3-52k.yml @@ -23,13 +23,13 @@ head_dim: 8 trainable_position_size: 2048 mlp_activations: ["gelu"] vocab_size: 1024 -enable_dropout: False -logits_via_embedding: True -normalize_embedding_logits: False -logits_dot_in_fp32: False +enable_dropout: false +logits_via_embedding: true +normalize_embedding_logits: false +logits_dot_in_fp32: false normalization_layer_epsilon: 1.e-05 -use_iota_embed: True -fused_qkv: True +use_iota_embed: true +fused_qkv: true opt_type: "adam_pax" decoder_block: "gpt3" gradient_clipping_threshold: 1. diff --git a/src/maxtext/configs/models/gpt3-6b.yml b/src/maxtext/configs/models/gpt3-6b.yml index 7ee0766ec2..2ac6fb0f5b 100644 --- a/src/maxtext/configs/models/gpt3-6b.yml +++ b/src/maxtext/configs/models/gpt3-6b.yml @@ -24,13 +24,13 @@ max_target_length: 1024 trainable_position_size: 16384 mlp_activations: ["gelu"] vocab_size: 32768 -enable_dropout: False -logits_via_embedding: True -normalize_embedding_logits: False -logits_dot_in_fp32: False +enable_dropout: false +logits_via_embedding: true +normalize_embedding_logits: false +logits_dot_in_fp32: false normalization_layer_epsilon: 1.e-05 -use_iota_embed: True -fused_qkv: True +use_iota_embed: true +fused_qkv: true opt_type: "adam_pax" decoder_block: "gpt3" gradient_clipping_threshold: 1. diff --git a/src/maxtext/configs/models/kimi-k2-1t.yml b/src/maxtext/configs/models/kimi-k2-1t.yml index f27ce8c946..4e28d1b90c 100644 --- a/src/maxtext/configs/models/kimi-k2-1t.yml +++ b/src/maxtext/configs/models/kimi-k2-1t.yml @@ -25,15 +25,15 @@ base_num_decoder_layers: 61 first_num_dense_layers: 1 mlp_activations: ["silu", "linear"] vocab_size: 163840 -enable_dropout: False -logits_via_embedding: False +enable_dropout: false +logits_via_embedding: false normalization_layer_epsilon: 1.0e-6 num_experts: 384 num_experts_per_tok: 8 shared_experts: 1 routed_scaling_factor: 2.827 routed_score_func: "sigmoid" -routed_bias: True +routed_bias: true decoder_block: "deepseek" # MLA attention_type: "mla" diff --git a/src/maxtext/configs/models/llama2-13b.yml b/src/maxtext/configs/models/llama2-13b.yml index faf85c841f..1381833b2a 100644 --- a/src/maxtext/configs/models/llama2-13b.yml +++ b/src/maxtext/configs/models/llama2-13b.yml @@ -22,8 +22,8 @@ base_num_decoder_layers: 40 head_dim: 128 mlp_activations: ["silu","linear"] vocab_size: 32000 -enable_dropout: False -logits_via_embedding: False +enable_dropout: false +logits_via_embedding: false normalization_layer_epsilon: 1.0e-5 decoder_block: "llama2" logical_axis_rules: [['norm', 'fsdp']] diff --git a/src/maxtext/configs/models/llama2-70b.yml b/src/maxtext/configs/models/llama2-70b.yml index 67dd87f68f..314ce6907b 100644 --- a/src/maxtext/configs/models/llama2-70b.yml +++ b/src/maxtext/configs/models/llama2-70b.yml @@ -22,7 +22,7 @@ base_num_decoder_layers: 80 head_dim: 128 mlp_activations: ["silu","linear"] vocab_size: 32000 -logits_via_embedding: False +logits_via_embedding: false normalization_layer_epsilon: 1.0e-5 decoder_block: "llama2" logical_axis_rules: [['norm', 'fsdp']] diff --git a/src/maxtext/configs/models/llama2-7b.yml b/src/maxtext/configs/models/llama2-7b.yml index 3ac04ef0f2..0be31ce536 100644 --- a/src/maxtext/configs/models/llama2-7b.yml +++ b/src/maxtext/configs/models/llama2-7b.yml @@ -22,8 +22,8 @@ base_num_decoder_layers: 32 head_dim: 128 mlp_activations: ["silu","linear"] vocab_size: 32000 -enable_dropout: False -logits_via_embedding: False +enable_dropout: false +logits_via_embedding: false normalization_layer_epsilon: 1.0e-5 decoder_block: "llama2" logical_axis_rules: [['norm', 'fsdp']] diff --git a/src/maxtext/configs/models/llama3-405b.yml b/src/maxtext/configs/models/llama3-405b.yml index c3f064f693..0d11fdd16c 100644 --- a/src/maxtext/configs/models/llama3-405b.yml +++ b/src/maxtext/configs/models/llama3-405b.yml @@ -23,8 +23,8 @@ base_mlp_dim: 53248 head_dim: 128 mlp_activations: ["silu","linear"] vocab_size: 128256 -enable_dropout: False -logits_via_embedding: False +enable_dropout: false +logits_via_embedding: false normalization_layer_epsilon: 1.0e-5 rope_max_timescale: 500_000 decoder_block: "llama2" # Uses the same decoder block as llama2 diff --git a/src/maxtext/configs/models/llama3-70b.yml b/src/maxtext/configs/models/llama3-70b.yml index 9f9fc7f973..90c4e3a6fb 100644 --- a/src/maxtext/configs/models/llama3-70b.yml +++ b/src/maxtext/configs/models/llama3-70b.yml @@ -22,8 +22,8 @@ base_mlp_dim: 28672 head_dim: 128 mlp_activations: ["silu","linear"] vocab_size: 128256 -enable_dropout: False -logits_via_embedding: False +enable_dropout: false +logits_via_embedding: false normalization_layer_epsilon: 1.0e-5 rope_max_timescale: 500_000 decoder_block: "llama2" # Uses the same decoder block as llama2 diff --git a/src/maxtext/configs/models/llama3-8b.yml b/src/maxtext/configs/models/llama3-8b.yml index 365738e176..777b5d5868 100644 --- a/src/maxtext/configs/models/llama3-8b.yml +++ b/src/maxtext/configs/models/llama3-8b.yml @@ -22,8 +22,8 @@ base_mlp_dim: 14336 head_dim: 128 mlp_activations: ["silu","linear"] vocab_size: 128256 -enable_dropout: False -logits_via_embedding: False +enable_dropout: false +logits_via_embedding: false normalization_layer_epsilon: 1.0e-5 rope_max_timescale: 500_000 decoder_block: "llama2" # Uses the same decoder block as llama2 diff --git a/src/maxtext/configs/models/llama3.1-405b.yml b/src/maxtext/configs/models/llama3.1-405b.yml index 73013c098e..7a9b1b17ad 100644 --- a/src/maxtext/configs/models/llama3.1-405b.yml +++ b/src/maxtext/configs/models/llama3.1-405b.yml @@ -22,8 +22,8 @@ base_mlp_dim: 53248 head_dim: 128 mlp_activations: ["silu","linear"] vocab_size: 128256 -enable_dropout: False -logits_via_embedding: False +enable_dropout: false +logits_via_embedding: false normalization_layer_epsilon: 1.0e-5 rope_max_timescale: 500_000 decoder_block: "llama2" # Uses the same decoder block as llama2 diff --git a/src/maxtext/configs/models/llama3.1-70b.yml b/src/maxtext/configs/models/llama3.1-70b.yml index 429803c751..2d31a4986c 100644 --- a/src/maxtext/configs/models/llama3.1-70b.yml +++ b/src/maxtext/configs/models/llama3.1-70b.yml @@ -22,8 +22,8 @@ base_mlp_dim: 28672 head_dim: 128 mlp_activations: ["silu","linear"] vocab_size: 128256 -enable_dropout: False -logits_via_embedding: False +enable_dropout: false +logits_via_embedding: false normalization_layer_epsilon: 1.0e-5 rope_max_timescale: 500_000 decoder_block: "llama2" # Uses the same decoder block as llama2 diff --git a/src/maxtext/configs/models/llama3.1-8b.yml b/src/maxtext/configs/models/llama3.1-8b.yml index 46e54447fa..1d3344612b 100644 --- a/src/maxtext/configs/models/llama3.1-8b.yml +++ b/src/maxtext/configs/models/llama3.1-8b.yml @@ -22,8 +22,8 @@ base_mlp_dim: 14336 head_dim: 128 mlp_activations: ["silu","linear"] vocab_size: 128256 -enable_dropout: False -logits_via_embedding: False +enable_dropout: false +logits_via_embedding: false normalization_layer_epsilon: 1.0e-5 rope_max_timescale: 500_000 decoder_block: "llama2" # Uses the same decoder block as llama2 diff --git a/src/maxtext/configs/models/llama3.3-70b.yml b/src/maxtext/configs/models/llama3.3-70b.yml index 5bbfba3b6c..481e767561 100644 --- a/src/maxtext/configs/models/llama3.3-70b.yml +++ b/src/maxtext/configs/models/llama3.3-70b.yml @@ -22,8 +22,8 @@ base_mlp_dim: 28672 head_dim: 128 mlp_activations: ["silu","linear"] vocab_size: 128256 -enable_dropout: False -logits_via_embedding: False +enable_dropout: false +logits_via_embedding: false normalization_layer_epsilon: 1.0e-5 rope_max_timescale: 500_000 decoder_block: "llama2" # Uses the same decoder block as llama2 diff --git a/src/maxtext/configs/models/llama4-17b-128e.yml b/src/maxtext/configs/models/llama4-17b-128e.yml index bc919dd1d0..0f0e4a71fe 100644 --- a/src/maxtext/configs/models/llama4-17b-128e.yml +++ b/src/maxtext/configs/models/llama4-17b-128e.yml @@ -15,7 +15,7 @@ decoder_block: "llama4" mlp_activations: ["silu","linear"] -enable_dropout: False +enable_dropout: false tokenizer_type: "huggingface" base_emb_dim: 5120 @@ -28,16 +28,16 @@ vocab_size: 202048 normalization_layer_epsilon: 1e-05 rope_max_timescale: 500000 rope_type: "llama3.1" -rope_use_scale: False +rope_use_scale: false num_experts: 128 shared_experts: 1 num_experts_per_tok: 1 -use_qk_norm: False +use_qk_norm: false nope_layer_interval: 4 # Every fourth layer should NOT use RoPE interleave_moe_layer_step: 2 # Every 2nd layer is MoE layer, and 1st layer is dense layer inhomogeneous_layer_cycle_interval: 4 # Every four layers the pattern of nope and moe repeats (least common multiple of nope interval and moe interval) -temperature_tuning: True +temperature_tuning: true # Chunk attention is used on all RoPE layers # otherwise, on NoPE layers, use global attention chunk_attn_window_size: 8192 diff --git a/src/maxtext/configs/models/llama4-17b-16e.yml b/src/maxtext/configs/models/llama4-17b-16e.yml index 15c17aa79f..9fc48b1803 100644 --- a/src/maxtext/configs/models/llama4-17b-16e.yml +++ b/src/maxtext/configs/models/llama4-17b-16e.yml @@ -15,8 +15,8 @@ decoder_block: "llama4" mlp_activations: ["silu","linear"] -enable_dropout: False -logits_via_embedding: False +enable_dropout: false +logits_via_embedding: false tokenizer_type: "huggingface" base_emb_dim: 5120 @@ -32,12 +32,12 @@ rope_type: "llama3.1" num_experts: 16 shared_experts: 1 num_experts_per_tok: 1 -use_qk_norm: True # Llama4 models apply an L2Norm to the Query and Keys after RoPE +use_qk_norm: true # Llama4 models apply an L2Norm to the Query and Keys after RoPE nope_layer_interval: 4 # Every fourth layer should NOT use RoPE interleave_moe_layer_step: 1 # Every layer is MoE layer inhomogeneous_layer_cycle_interval: 4 # Every four layers the pattern of nope and moe repeats (least common multiple of nope interval and moe interval) -temperature_tuning: True +temperature_tuning: true # Chunk attention is used on all RoPE layers # otherwise, on NoPE layers, use global attention chunk_attn_window_size: 8192 diff --git a/src/maxtext/configs/models/mistral-7b.yml b/src/maxtext/configs/models/mistral-7b.yml index 97f4960f73..07d59fa5af 100644 --- a/src/maxtext/configs/models/mistral-7b.yml +++ b/src/maxtext/configs/models/mistral-7b.yml @@ -22,8 +22,8 @@ base_num_decoder_layers: 32 head_dim: 128 mlp_activations: ["silu","linear"] vocab_size: 32000 -enable_dropout: False -logits_via_embedding: False +enable_dropout: false +logits_via_embedding: false normalization_layer_epsilon: 1.0e-5 rope_max_timescale: 1_000_000 decoder_block: "mistral" diff --git a/src/maxtext/configs/models/mixtral-8x22b.yml b/src/maxtext/configs/models/mixtral-8x22b.yml index 0d040bf48a..b6d1fc71c0 100644 --- a/src/maxtext/configs/models/mixtral-8x22b.yml +++ b/src/maxtext/configs/models/mixtral-8x22b.yml @@ -24,8 +24,8 @@ base_num_decoder_layers: 56 head_dim: 128 mlp_activations: ["silu","linear"] vocab_size: 32768 -enable_dropout: False -logits_via_embedding: False +enable_dropout: false +logits_via_embedding: false normalization_layer_epsilon: 1.0e-5 num_experts: 8 num_experts_per_tok: 2 diff --git a/src/maxtext/configs/models/mixtral-8x7b.yml b/src/maxtext/configs/models/mixtral-8x7b.yml index 91a7ab50bc..9528a667c5 100644 --- a/src/maxtext/configs/models/mixtral-8x7b.yml +++ b/src/maxtext/configs/models/mixtral-8x7b.yml @@ -24,8 +24,8 @@ base_num_decoder_layers: 32 head_dim: 128 mlp_activations: ["silu","linear"] vocab_size: 32000 -enable_dropout: False -logits_via_embedding: False +enable_dropout: false +logits_via_embedding: false normalization_layer_epsilon: 1.0e-5 num_experts: 8 num_experts_per_tok: 2 diff --git a/src/maxtext/configs/models/olmo3-32b.yml b/src/maxtext/configs/models/olmo3-32b.yml index e9a8b82160..0c6dcabdf1 100644 --- a/src/maxtext/configs/models/olmo3-32b.yml +++ b/src/maxtext/configs/models/olmo3-32b.yml @@ -28,7 +28,7 @@ head_dim: 128 # Activations & Normalization mlp_activations: ["silu", "linear"] normalization_layer_epsilon: 1.e-6 -use_qk_norm: True +use_qk_norm: true # Attention # Layers 0,1,2 use sliding window 4096. Layer 3 uses global. Repeats. @@ -38,17 +38,17 @@ inhomogeneous_layer_cycle_interval: 4 # RoPE # Yarn RoPE for global attention, default RoPE for sliding window attention (set from olmo3.py) rope_type: "yarn" -rope_interleave: False +rope_interleave: false rope_max_timescale: 500000 # rope_theta rope_factor: 8.0 # factor so 0.1 * ln(rope_factor) + 1.0 = 1.2079441541679836 original_max_position_embeddings: 8192 beta_fast: 32.0 beta_slow: 1.0 max_position_embeddings: 65536 -rope_attention_scaling: True -rope_truncate: True # HF transformers defaults truncate=True for YaRN; matches HF correction-range floor/ceil +rope_attention_scaling: true +rope_truncate: true # HF transformers defaults truncate=true for YaRN; matches HF correction-range floor/ceil # Embeddings vocab_size: 100278 -logits_via_embedding: False -normalize_embedding_logits: False +logits_via_embedding: false +normalize_embedding_logits: false diff --git a/src/maxtext/configs/models/olmo3-7b-pt.yml b/src/maxtext/configs/models/olmo3-7b-pt.yml index 6520ad18be..bd0efa3579 100644 --- a/src/maxtext/configs/models/olmo3-7b-pt.yml +++ b/src/maxtext/configs/models/olmo3-7b-pt.yml @@ -33,7 +33,7 @@ head_dim: 128 # Activations & Normalization mlp_activations: ["silu", "linear"] # SwiGLU normalization_layer_epsilon: 1.e-6 -use_qk_norm: True +use_qk_norm: true # Attention # Layers 0,1,2 use sliding window 4096. Layer 3 uses global. Repeats. @@ -42,14 +42,14 @@ inhomogeneous_layer_cycle_interval: 4 # RoPE # Default RoPE for both global and sliding window attention. -rope_type: "default" -rope_interleave: False +rope_type: "default" +rope_interleave: false rope_max_timescale: 500000 # rope_theta max_position_embeddings: 8192 -rope_attention_scaling: True -rope_truncate: False +rope_attention_scaling: true +rope_truncate: false # Embeddings vocab_size: 100278 -logits_via_embedding: False -normalize_embedding_logits: False +logits_via_embedding: false +normalize_embedding_logits: false diff --git a/src/maxtext/configs/models/olmo3-7b.yml b/src/maxtext/configs/models/olmo3-7b.yml index 16d3a469d3..ea1e4de5f5 100644 --- a/src/maxtext/configs/models/olmo3-7b.yml +++ b/src/maxtext/configs/models/olmo3-7b.yml @@ -28,7 +28,7 @@ head_dim: 128 # Activations & Normalization mlp_activations: ["silu", "linear"] # SwiGLU normalization_layer_epsilon: 1.e-6 -use_qk_norm: True +use_qk_norm: true # Attention # Layers 0,1,2 use sliding window 4096. Layer 3 uses global. Repeats. @@ -38,17 +38,17 @@ inhomogeneous_layer_cycle_interval: 4 # RoPE # Yarn RoPE for global attention, default RoPE for sliding window attention (set from olmo3.py) rope_type: "yarn" -rope_interleave: False +rope_interleave: false rope_max_timescale: 500000 # rope_theta rope_factor: 8.0 # factor so 0.1 * ln(rope_factor) + 1.0 = 1.2079441541679836 original_max_position_embeddings: 8192 beta_fast: 32.0 beta_slow: 1.0 max_position_embeddings: 65536 -rope_attention_scaling: True -rope_truncate: True # HF transformers defaults truncate=True for YaRN; matches HF correction-range floor/ceil +rope_attention_scaling: true +rope_truncate: true # HF transformers defaults truncate=true for YaRN; matches HF correction-range floor/ceil # Embeddings vocab_size: 100278 -logits_via_embedding: False -normalize_embedding_logits: False +logits_via_embedding: false +normalize_embedding_logits: false diff --git a/src/maxtext/configs/models/qwen2.5-1.5b.yml b/src/maxtext/configs/models/qwen2.5-1.5b.yml index 1ce9a8924d..2aec43b3bb 100644 --- a/src/maxtext/configs/models/qwen2.5-1.5b.yml +++ b/src/maxtext/configs/models/qwen2.5-1.5b.yml @@ -26,9 +26,9 @@ vocab_size: 151936 decoder_block: "qwen2" normalization_layer_epsilon: 1e-06 rope_max_timescale: 1000000.0 -use_qk_norm: False +use_qk_norm: false # Bias for q, k, v proj. -attention_bias: True -logits_via_embedding: True -normalize_embedding_logits: False +attention_bias: true +logits_via_embedding: true +normalize_embedding_logits: false tokenizer_type: "huggingface" diff --git a/src/maxtext/configs/models/qwen2.5-14b.yml b/src/maxtext/configs/models/qwen2.5-14b.yml index 0450fb24e4..aaed7e96bc 100644 --- a/src/maxtext/configs/models/qwen2.5-14b.yml +++ b/src/maxtext/configs/models/qwen2.5-14b.yml @@ -26,9 +26,9 @@ vocab_size: 152064 decoder_block: "qwen2" normalization_layer_epsilon: 1.0e-6 rope_max_timescale: 1000000.0 -use_qk_norm: False +use_qk_norm: false # Bias for q, k, v proj. -attention_bias: True -logits_via_embedding: False -normalize_embedding_logits: False +attention_bias: true +logits_via_embedding: false +normalize_embedding_logits: false tokenizer_type: "huggingface" diff --git a/src/maxtext/configs/models/qwen2.5-7b.yml b/src/maxtext/configs/models/qwen2.5-7b.yml index e267863619..1a1566e9ad 100644 --- a/src/maxtext/configs/models/qwen2.5-7b.yml +++ b/src/maxtext/configs/models/qwen2.5-7b.yml @@ -26,9 +26,9 @@ vocab_size: 152064 decoder_block: "qwen2" normalization_layer_epsilon: 1e-06 rope_max_timescale: 1000000.0 -use_qk_norm: False +use_qk_norm: false # Bias for q, k, v proj. -attention_bias: True -logits_via_embedding: False -normalize_embedding_logits: False +attention_bias: true +logits_via_embedding: false +normalize_embedding_logits: false tokenizer_type: "huggingface" diff --git a/src/maxtext/configs/models/qwen3-0.6b.yml b/src/maxtext/configs/models/qwen3-0.6b.yml index e647d00352..b254d9e639 100644 --- a/src/maxtext/configs/models/qwen3-0.6b.yml +++ b/src/maxtext/configs/models/qwen3-0.6b.yml @@ -28,10 +28,10 @@ decoder_block: "qwen3" normalization_layer_epsilon: 1.0e-6 rope_max_timescale: 1000000 -use_qk_norm: True +use_qk_norm: true -logits_via_embedding: True # from "tie_word_embeddings": true -normalize_embedding_logits: False -enable_dropout: False # deterministic for testing +logits_via_embedding: true # from "tie_word_embeddings": true +normalize_embedding_logits: false +enable_dropout: false # deterministic for testing tokenizer_type: "huggingface" diff --git a/src/maxtext/configs/models/qwen3-1.7b-base.yml b/src/maxtext/configs/models/qwen3-1.7b-base.yml index 9d1cd2358f..8024ec6f81 100644 --- a/src/maxtext/configs/models/qwen3-1.7b-base.yml +++ b/src/maxtext/configs/models/qwen3-1.7b-base.yml @@ -28,10 +28,10 @@ decoder_block: "qwen3" normalization_layer_epsilon: 1.0e-6 rope_max_timescale: 1000000 -use_qk_norm: True +use_qk_norm: true -logits_via_embedding: True # from "tie_word_embeddings": true -normalize_embedding_logits: False -enable_dropout: False # deterministic for testing +logits_via_embedding: true # from "tie_word_embeddings": true +normalize_embedding_logits: false +enable_dropout: false # deterministic for testing tokenizer_type: "huggingface" diff --git a/src/maxtext/configs/models/qwen3-1.7b.yml b/src/maxtext/configs/models/qwen3-1.7b.yml index e3c22a10b6..40afcae767 100644 --- a/src/maxtext/configs/models/qwen3-1.7b.yml +++ b/src/maxtext/configs/models/qwen3-1.7b.yml @@ -28,10 +28,10 @@ decoder_block: "qwen3" normalization_layer_epsilon: 1.0e-6 rope_max_timescale: 1000000 -use_qk_norm: True +use_qk_norm: true -logits_via_embedding: True # from "tie_word_embeddings": true -normalize_embedding_logits: False -enable_dropout: False # deterministic for testing +logits_via_embedding: true # from "tie_word_embeddings": true +normalize_embedding_logits: false +enable_dropout: false # deterministic for testing tokenizer_type: "huggingface" diff --git a/src/maxtext/configs/models/qwen3-14b-base.yml b/src/maxtext/configs/models/qwen3-14b-base.yml index aa2e007448..fb7f811f25 100644 --- a/src/maxtext/configs/models/qwen3-14b-base.yml +++ b/src/maxtext/configs/models/qwen3-14b-base.yml @@ -28,10 +28,10 @@ decoder_block: "qwen3" normalization_layer_epsilon: 1.0e-6 rope_max_timescale: 1000000 -use_qk_norm: True +use_qk_norm: true -logits_via_embedding: False # different from 0.6 and 4B variants, "tie_word_embeddings": false -normalize_embedding_logits: False +logits_via_embedding: false # different from 0.6 and 4B variants, "tie_word_embeddings": false +normalize_embedding_logits: false tokenizer_type: "huggingface" diff --git a/src/maxtext/configs/models/qwen3-14b.yml b/src/maxtext/configs/models/qwen3-14b.yml index 3acd3d0f6d..25ecd03361 100644 --- a/src/maxtext/configs/models/qwen3-14b.yml +++ b/src/maxtext/configs/models/qwen3-14b.yml @@ -28,10 +28,10 @@ decoder_block: "qwen3" normalization_layer_epsilon: 1.0e-6 rope_max_timescale: 1000000 -use_qk_norm: True +use_qk_norm: true -logits_via_embedding: False # different from 0.6 and 4B variants, "tie_word_embeddings": false -normalize_embedding_logits: False +logits_via_embedding: false # different from 0.6 and 4B variants, "tie_word_embeddings": false +normalize_embedding_logits: false tokenizer_type: "huggingface" diff --git a/src/maxtext/configs/models/qwen3-235b-a22b.yml b/src/maxtext/configs/models/qwen3-235b-a22b.yml index ef854d6679..5c02c85733 100644 --- a/src/maxtext/configs/models/qwen3-235b-a22b.yml +++ b/src/maxtext/configs/models/qwen3-235b-a22b.yml @@ -25,7 +25,7 @@ head_dim: 128 mlp_activations: ["silu", "linear"] vocab_size: 151936 normalization_layer_epsilon: 1.0e-6 -use_qk_norm: True +use_qk_norm: true # MoE Specific Parameters num_experts: 128 @@ -38,4 +38,4 @@ norm_topk_prob: true rope_max_timescale: 5000000 # General Model Settings -enable_dropout: False +enable_dropout: false diff --git a/src/maxtext/configs/models/qwen3-30b-a3b-base.yml b/src/maxtext/configs/models/qwen3-30b-a3b-base.yml index 67c6a23ac4..723e4ed6d2 100644 --- a/src/maxtext/configs/models/qwen3-30b-a3b-base.yml +++ b/src/maxtext/configs/models/qwen3-30b-a3b-base.yml @@ -25,7 +25,7 @@ head_dim: 128 mlp_activations: ["silu", "linear"] vocab_size: 151936 normalization_layer_epsilon: 1.0e-6 -use_qk_norm: True +use_qk_norm: true # MoE Specific Parameters num_experts: 128 @@ -37,4 +37,4 @@ norm_topk_prob: true rope_max_timescale: 10_000_000 # General Model Settings -enable_dropout: False +enable_dropout: false diff --git a/src/maxtext/configs/models/qwen3-30b-a3b.yml b/src/maxtext/configs/models/qwen3-30b-a3b.yml index 489e2fc9c4..d6282744d4 100644 --- a/src/maxtext/configs/models/qwen3-30b-a3b.yml +++ b/src/maxtext/configs/models/qwen3-30b-a3b.yml @@ -25,7 +25,7 @@ head_dim: 128 mlp_activations: ["silu", "linear"] vocab_size: 151936 normalization_layer_epsilon: 1.0e-6 -use_qk_norm: True +use_qk_norm: true # MoE Specific Parameters num_experts: 128 @@ -37,4 +37,4 @@ norm_topk_prob: true rope_max_timescale: 10_000_000 # General Model Settings -enable_dropout: False +enable_dropout: false diff --git a/src/maxtext/configs/models/qwen3-32b.yml b/src/maxtext/configs/models/qwen3-32b.yml index ff792820f2..9d7403d694 100644 --- a/src/maxtext/configs/models/qwen3-32b.yml +++ b/src/maxtext/configs/models/qwen3-32b.yml @@ -28,10 +28,10 @@ decoder_block: "qwen3" normalization_layer_epsilon: 1.0e-6 rope_max_timescale: 1000000 -use_qk_norm: True +use_qk_norm: true -logits_via_embedding: False # different from 0.6 and 4B variants, "tie_word_embeddings": false -normalize_embedding_logits: False +logits_via_embedding: false # different from 0.6 and 4B variants, "tie_word_embeddings": false +normalize_embedding_logits: false tokenizer_type: "huggingface" diff --git a/src/maxtext/configs/models/qwen3-4b-base.yml b/src/maxtext/configs/models/qwen3-4b-base.yml index 0ba84bc6c7..691f44312f 100644 --- a/src/maxtext/configs/models/qwen3-4b-base.yml +++ b/src/maxtext/configs/models/qwen3-4b-base.yml @@ -28,10 +28,10 @@ decoder_block: "qwen3" normalization_layer_epsilon: 1.0e-6 rope_max_timescale: 1000000 -use_qk_norm: True +use_qk_norm: true -logits_via_embedding: True # from "tie_word_embeddings": true -normalize_embedding_logits: False -enable_dropout: False # deterministic for testing +logits_via_embedding: true # from "tie_word_embeddings": true +normalize_embedding_logits: false +enable_dropout: false # deterministic for testing tokenizer_type: "huggingface" diff --git a/src/maxtext/configs/models/qwen3-4b-thinking-2507.yml b/src/maxtext/configs/models/qwen3-4b-thinking-2507.yml index 214ca31c77..866bcccc94 100644 --- a/src/maxtext/configs/models/qwen3-4b-thinking-2507.yml +++ b/src/maxtext/configs/models/qwen3-4b-thinking-2507.yml @@ -28,9 +28,9 @@ decoder_block: "qwen3" normalization_layer_epsilon: 1.0e-6 rope_max_timescale: 5000000 -use_qk_norm: True +use_qk_norm: true -logits_via_embedding: True # from "tie_word_embeddings": true -normalize_embedding_logits: False -enable_dropout: False # deterministic for testing +logits_via_embedding: true # from "tie_word_embeddings": true +normalize_embedding_logits: false +enable_dropout: false # deterministic for testing diff --git a/src/maxtext/configs/models/qwen3-4b.yml b/src/maxtext/configs/models/qwen3-4b.yml index cde5b670a5..730e0145cf 100644 --- a/src/maxtext/configs/models/qwen3-4b.yml +++ b/src/maxtext/configs/models/qwen3-4b.yml @@ -28,10 +28,10 @@ decoder_block: "qwen3" normalization_layer_epsilon: 1.0e-6 rope_max_timescale: 1000000 -use_qk_norm: True +use_qk_norm: true -logits_via_embedding: True # from "tie_word_embeddings": true -normalize_embedding_logits: False -enable_dropout: False # deterministic for testing +logits_via_embedding: true # from "tie_word_embeddings": true +normalize_embedding_logits: false +enable_dropout: false # deterministic for testing tokenizer_type: "huggingface" diff --git a/src/maxtext/configs/models/qwen3-8b-base.yml b/src/maxtext/configs/models/qwen3-8b-base.yml index 6e3d8ae6b8..420297e104 100644 --- a/src/maxtext/configs/models/qwen3-8b-base.yml +++ b/src/maxtext/configs/models/qwen3-8b-base.yml @@ -28,11 +28,11 @@ decoder_block: "qwen3" normalization_layer_epsilon: 1.0e-6 rope_max_timescale: 1000000 -use_qk_norm: True +use_qk_norm: true -logits_via_embedding: False # different from smaller variants, "tie_word_embeddings": false -normalize_embedding_logits: False -enable_dropout: False # deterministic for testing +logits_via_embedding: false # different from smaller variants, "tie_word_embeddings": false +normalize_embedding_logits: false +enable_dropout: false # deterministic for testing tokenizer_type: "huggingface" diff --git a/src/maxtext/configs/models/qwen3-8b.yml b/src/maxtext/configs/models/qwen3-8b.yml index eb60104607..89b7422f24 100644 --- a/src/maxtext/configs/models/qwen3-8b.yml +++ b/src/maxtext/configs/models/qwen3-8b.yml @@ -28,11 +28,11 @@ decoder_block: "qwen3" normalization_layer_epsilon: 1.0e-6 rope_max_timescale: 1000000 -use_qk_norm: True +use_qk_norm: true -logits_via_embedding: False # different from smaller variants, "tie_word_embeddings": false -normalize_embedding_logits: False -enable_dropout: False # deterministic for testing +logits_via_embedding: false # different from smaller variants, "tie_word_embeddings": false +normalize_embedding_logits: false +enable_dropout: false # deterministic for testing tokenizer_type: "huggingface" diff --git a/src/maxtext/configs/models/qwen3-custom-30b-a3b.yml b/src/maxtext/configs/models/qwen3-custom-30b-a3b.yml index 134d7b9ea0..8229731a06 100644 --- a/src/maxtext/configs/models/qwen3-custom-30b-a3b.yml +++ b/src/maxtext/configs/models/qwen3-custom-30b-a3b.yml @@ -24,7 +24,7 @@ head_dim: 256 mlp_activations: ["silu", "linear"] vocab_size: 151936 normalization_layer_epsilon: 1.0e-6 -use_qk_norm: True +use_qk_norm: true attention_output_dim: 768 moe_expert_input_dim: 768 @@ -39,4 +39,4 @@ norm_topk_prob: true rope_max_timescale: 10_000_000 # General Model Settings -enable_dropout: False +enable_dropout: false diff --git a/src/maxtext/configs/models/qwen3-next-80b-a3b.yml b/src/maxtext/configs/models/qwen3-next-80b-a3b.yml index ecdd9cceda..765977f1b5 100644 --- a/src/maxtext/configs/models/qwen3-next-80b-a3b.yml +++ b/src/maxtext/configs/models/qwen3-next-80b-a3b.yml @@ -33,7 +33,7 @@ base_moe_mlp_dim: 512 num_experts: 512 shared_experts: 1 num_experts_per_tok: 10 -norm_topk_prob: True +norm_topk_prob: true # Qwen3-Next Specific Parameters for Linear Attention (Gated Delta Net) inhomogeneous_layer_cycle_interval: 4 @@ -49,4 +49,4 @@ rope_max_timescale: 10000000 partial_rotary_factor: 0.25 # General Model Settings -enable_dropout: False +enable_dropout: false diff --git a/src/maxtext/configs/models/qwen3-omni-30b-a3b.yml b/src/maxtext/configs/models/qwen3-omni-30b-a3b.yml index 050df6909e..0d8b4a2fd6 100644 --- a/src/maxtext/configs/models/qwen3-omni-30b-a3b.yml +++ b/src/maxtext/configs/models/qwen3-omni-30b-a3b.yml @@ -25,7 +25,7 @@ head_dim: 128 mlp_activations: ["silu", "linear"] vocab_size: 152064 normalization_layer_epsilon: 1.0e-6 -use_qk_norm: True +use_qk_norm: true # MoE Specific Parameters num_experts: 128 @@ -38,8 +38,8 @@ rope_max_timescale: 1_000_000 max_position_embeddings: 65536 # General Model Settings -enable_dropout: False -scan_layers: False # deepstack does not support scan_layers +enable_dropout: false +scan_layers: false # deepstack does not support scan_layers # Vision Encoder Configuration # Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py diff --git a/src/maxtext/configs/models/qwen3.5-397b-a17b.yml b/src/maxtext/configs/models/qwen3.5-397b-a17b.yml index 1613286b4b..e22f966eb5 100644 --- a/src/maxtext/configs/models/qwen3.5-397b-a17b.yml +++ b/src/maxtext/configs/models/qwen3.5-397b-a17b.yml @@ -32,7 +32,7 @@ base_moe_mlp_dim: 1024 num_experts: 512 shared_experts: 1 num_experts_per_tok: 10 -norm_topk_prob: True +norm_topk_prob: true # GatedDeltaNet Specific Parameters for Linear Attention (GDN) inhomogeneous_layer_cycle_interval: 4 @@ -48,4 +48,4 @@ rope_max_timescale: 10000000 partial_rotary_factor: 0.25 # General Model Settings -enable_dropout: False +enable_dropout: false diff --git a/src/maxtext/configs/post_train/distillation-sft.yml b/src/maxtext/configs/post_train/distillation-sft.yml index c56c716488..774adc2bd5 100644 --- a/src/maxtext/configs/post_train/distillation-sft.yml +++ b/src/maxtext/configs/post_train/distillation-sft.yml @@ -17,8 +17,8 @@ # Inherit MaxText defaults base_config: "post_train/distillation.yml" -use_sft: True -sft_train_on_completion_only: True +use_sft: true +sft_train_on_completion_only: true # --- Dataset & Tokenizer --- hf_path: "HuggingFaceH4/ultrachat_200k" @@ -30,4 +30,4 @@ chat_template: "{% set loop_messages = messages %}{% for message in loop_message train_split: "train_sft" eval_split: "test_sft" train_data_columns: ["messages"] -eval_data_columns: ["messages"] \ No newline at end of file +eval_data_columns: ["messages"] diff --git a/src/maxtext/configs/post_train/distillation.yml b/src/maxtext/configs/post_train/distillation.yml index d2741f0742..0b2be0f84c 100644 --- a/src/maxtext/configs/post_train/distillation.yml +++ b/src/maxtext/configs/post_train/distillation.yml @@ -38,13 +38,13 @@ tokenizer_path: "meta-llama/Llama-3.1-8B" tokenizer_type: "huggingface" max_target_length: 2048 -packing: True +packing: true # --- Training Loop --- steps: 200000 checkpoint_period: 2000 log_period: 10 -save_checkpoint_on_completion: True +save_checkpoint_on_completion: true # --- Batch Size Strategy --- # Global Batch Size = per_device_batch_size * num_devices * gradient_accumulation_steps @@ -55,4 +55,4 @@ gradient_accumulation_steps: 1 learning_rate: 2.0e-4 learning_rate_schedule_steps: 200000 warmup_steps_fraction: 0.1 -learning_rate_final_fraction: 0.1 \ No newline at end of file +learning_rate_final_fraction: 0.1 diff --git a/src/maxtext/configs/post_train/rl.yml b/src/maxtext/configs/post_train/rl.yml index 10582e5cfd..04920995bf 100644 --- a/src/maxtext/configs/post_train/rl.yml +++ b/src/maxtext/configs/post_train/rl.yml @@ -55,17 +55,17 @@ rl: loss_algo: 'grpo' # grpo or gspo-token # ====== Agentic Rollout ====== - # If True, uses the async AgenticGRPOLearner, which overlaps rollout generation + # If true, uses the async AgenticGRPOLearner, which overlaps rollout generation # with training for faster throughput via online vLLM inference. - use_agentic_rollout: False + use_agentic_rollout: false # Max concurrent rollout requests when using agentic rollout. max_concurrency: 256 # Number of off-policy steps tolerated before requiring a policy update. off_policy_steps: 0 # System prompt injected into the agent at rollout time. system_prompt: '' - # If True, mask degenerate groups (all-zero advantages) from contributing to the loss. - degenerate_group_masking: True + # If true, mask degenerate groups (all-zero advantages) from contributing to the loss. + degenerate_group_masking: true # Upper-bound clipping epsilon for GRPO loss; defaults to grpo_epsilon when null. epsilon_high: null # Number of model keys to chunk for resharding tensors between trainer and rollout devices. @@ -95,18 +95,18 @@ decoder_layer_input: 'offload' query_proj: 'offload' key_proj: 'offload' value_proj: 'offload' -checkpoint_storage_use_ocdbt: False # For Pathways -checkpoint_storage_use_zarr3: False # For Pathways -use_pathways: True +checkpoint_storage_use_ocdbt: false # For Pathways +checkpoint_storage_use_zarr3: false # For Pathways +use_pathways: true log_period: 20 -convert_checkpoint_if_possible: True +convert_checkpoint_if_possible: true # ====== Debugging ====== debug: - rl: True -# If True, Tunix-managed metrics measurement will be enabled. The metrics will be + rl: true +# If true, Tunix-managed metrics measurement will be enabled. The metrics will be # uploaded to tensorboard. -enable_tunix_perf_metrics: False +enable_tunix_perf_metrics: false # ====== Training ====== batch_size: 1 @@ -149,8 +149,8 @@ generation_configs: eval_top_p: 1.0 num_eval_passes: 1 # Number of generation passes during evaluation -eval_corr_lst: False # If True, only include correct responses in the list during evaluation -eval_make_lst: False # If True, return a list of (question, answer, responses) during evaluation +eval_corr_lst: false # If true, only include correct responses in the list during evaluation +eval_make_lst: false # If true, return a list of (question, answer, responses) during evaluation eval_mode: "pass" # Evaluation mode ("pass" for pass@K, "maj" for majority voting maj@K, "pass_at_1" for pass@1 estimation) # ====== Inference ====== @@ -168,18 +168,18 @@ decode_sampling_temperature: 0.9 decode_sampling_top_k: 50 decode_sampling_nucleus_p: 1.0 # Optional sharding configuration for samplers -enable_dp_attention: False +enable_dp_attention: false # Performance tuning for samplers max_num_batched_tokens: null max_num_seqs: null -# If True, enables asynchronous scheduling in vLLM for faster generation -async_scheduling: True +# If true, enables asynchronous scheduling in vLLM for faster generation +async_scheduling: true # stop generation when any of these strings is generated stop_strings: null # ====== Checkpoint Configuration ====== -enable_checkpointing: True -async_checkpointing: False +enable_checkpointing: true +async_checkpointing: false checkpoint_period: 50 max_num_checkpoints_to_keep: 10 @@ -206,7 +206,7 @@ reasoning_end_token: '' solution_start_token: '' solution_end_token: '' chat_template_path: 'maxtext/examples/chat_templates/gsm8k_rl.json' -skip_jax_distributed_system: True +skip_jax_distributed_system: true # ====== Dataset Configuration ====== # Supported values for dataset_name: @@ -231,5 +231,5 @@ tokenizer_type: 'huggingface' ##### MaxText to VLLM Converter validation parameters vllm_load_format: dummy # Format to load the model for conversion. Options are "auto", "dummy" -debug_converter: False # If True, run key coverage check, weight stats, and GCS upload then exit without generation. -gcs_debug_path: "" # If set and debug_converter=True, upload converted layer-0 and global tensors as .npy to this GCS prefix. +debug_converter: false # If true, run key coverage check, weight stats, and GCS upload then exit without generation. +gcs_debug_path: "" # If set and debug_converter=true, upload converted layer-0 and global tensors as .npy to this GCS prefix. diff --git a/src/maxtext/configs/post_train/rl_mt_jt.yml b/src/maxtext/configs/post_train/rl_mt_jt.yml index e9d5108e23..8e7c0098d1 100644 --- a/src/maxtext/configs/post_train/rl_mt_jt.yml +++ b/src/maxtext/configs/post_train/rl_mt_jt.yml @@ -74,4 +74,4 @@ logical_axis_rules: [ # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']] -return_log_prob: True \ No newline at end of file +return_log_prob: true diff --git a/src/maxtext/configs/post_train/sft-vision-chartqa.yml b/src/maxtext/configs/post_train/sft-vision-chartqa.yml index 7dfb5cc51d..e4e32eb539 100644 --- a/src/maxtext/configs/post_train/sft-vision-chartqa.yml +++ b/src/maxtext/configs/post_train/sft-vision-chartqa.yml @@ -14,13 +14,13 @@ base_config: "base.yml" -use_sft: True -use_tunix_gradient_accumulation: True -use_multimodal: True +use_sft: true +use_tunix_gradient_accumulation: true +use_multimodal: true # For vision, the prompt contains image, we only train on completion tokens -sft_train_on_completion_only: True -packing: False # packing is not supported yet -freeze_vision_encoder_params: True +sft_train_on_completion_only: true +packing: false # packing is not supported yet +freeze_vision_encoder_params: true learning_rate: 2.e-5 # -------------- HF pipeline -------------- diff --git a/src/maxtext/configs/post_train/sft-vision-slidevqa.yml b/src/maxtext/configs/post_train/sft-vision-slidevqa.yml index e2eaa7af17..f77d5910e6 100644 --- a/src/maxtext/configs/post_train/sft-vision-slidevqa.yml +++ b/src/maxtext/configs/post_train/sft-vision-slidevqa.yml @@ -14,13 +14,13 @@ base_config: "base.yml" -use_sft: True -use_tunix_gradient_accumulation: True -use_multimodal: True +use_sft: true +use_tunix_gradient_accumulation: true +use_multimodal: true # For vision, the prompt contains image, we only train on completion tokens -sft_train_on_completion_only: True -packing: False # packing is not supported yet -freeze_vision_encoder_params: True +sft_train_on_completion_only: true +packing: false # packing is not supported yet +freeze_vision_encoder_params: true learning_rate: 2.e-5 # -------------- HF pipeline -------------- diff --git a/src/maxtext/configs/post_train/sft.yml b/src/maxtext/configs/post_train/sft.yml index 3ba5cf2161..1188e5ea84 100644 --- a/src/maxtext/configs/post_train/sft.yml +++ b/src/maxtext/configs/post_train/sft.yml @@ -14,16 +14,16 @@ base_config: "base.yml" -use_sft: True -use_tunix_gradient_accumulation: True -# sft_train_on_completion_only=False trains on both prompt and completion tokens; trains only on completion tokens otherwise -sft_train_on_completion_only: True -packing: True +use_sft: true +use_tunix_gradient_accumulation: true +# sft_train_on_completion_only=false trains on both prompt and completion tokens; trains only on completion tokens otherwise +sft_train_on_completion_only: true +packing: true learning_rate: 2.e-5 # -------------- LoRA / QLoRA -------------- lora: - enable_lora: False + enable_lora: false lora_rank: 0 lora_alpha: 0.0 lora_module_path: "" diff --git a/src/maxtext/configs/tpu/v5e/llama2_70b_v5e-16.yml b/src/maxtext/configs/tpu/v5e/llama2_70b_v5e-16.yml index 121efe248e..dba4bc03ce 100644 --- a/src/maxtext/configs/tpu/v5e/llama2_70b_v5e-16.yml +++ b/src/maxtext/configs/tpu/v5e/llama2_70b_v5e-16.yml @@ -7,9 +7,9 @@ base_config: "inference/inference_jetstream.yml" model_name: "llama2-70b" sharding_strategy: "experimental" attention: 'dot_product' -allow_split_physical_axes: True +allow_split_physical_axes: true # Used to replicate the quantization scale to avoid the inefficient XLA fusion. -replicate_quant_scale: True +replicate_quant_scale: true logical_axis_rules: [ ['embed', []], diff --git a/src/maxtext/configs/tpu/v5e/llama3_405b_v5e-64.yml b/src/maxtext/configs/tpu/v5e/llama3_405b_v5e-64.yml index b91bb85fb3..b71b7990f1 100644 --- a/src/maxtext/configs/tpu/v5e/llama3_405b_v5e-64.yml +++ b/src/maxtext/configs/tpu/v5e/llama3_405b_v5e-64.yml @@ -8,10 +8,10 @@ base_config: "inference/inference_jetstream.yml" model_name: "llama3.1-405b" sharding_strategy: "experimental" attention: 'dot_product' -allow_split_physical_axes: True +allow_split_physical_axes: true tokenizer_path: "assets/tokenizer_llama3.tiktoken" # Used to replicate the quantization scale to avoid the inefficient XLA fusion. -replicate_quant_scale: True +replicate_quant_scale: true logical_axis_rules: [ ['embed', []], diff --git a/src/maxtext/configs/tpu/v5e/llama3_70b_v5e-16.yml b/src/maxtext/configs/tpu/v5e/llama3_70b_v5e-16.yml index b3ca2d1465..525d30e30c 100644 --- a/src/maxtext/configs/tpu/v5e/llama3_70b_v5e-16.yml +++ b/src/maxtext/configs/tpu/v5e/llama3_70b_v5e-16.yml @@ -8,9 +8,9 @@ model_name: "llama3-70b" tokenizer_path: "assets/tokenizer_llama3.tiktoken" sharding_strategy: "experimental" attention: 'dot_product' -allow_split_physical_axes: True +allow_split_physical_axes: true # Used to replicate the quantization scale to avoid the inefficient XLA fusion. -replicate_quant_scale: True +replicate_quant_scale: true logical_axis_rules: [ ['embed', []], diff --git a/src/maxtext/configs/tpu/v6e/inference/llama4_maverick_v6e-64.yml b/src/maxtext/configs/tpu/v6e/inference/llama4_maverick_v6e-64.yml index 165500da8d..68f6c839e8 100644 --- a/src/maxtext/configs/tpu/v6e/inference/llama4_maverick_v6e-64.yml +++ b/src/maxtext/configs/tpu/v6e/inference/llama4_maverick_v6e-64.yml @@ -8,9 +8,9 @@ base_config: "inference/inference_jetstream.yml" sharding_strategy: "experimental" attention: 'dot_product' -allow_split_physical_axes: True +allow_split_physical_axes: true # Used to replicate the quantization scale to avoid the inefficient XLA fusion. -replicate_quant_scale: True +replicate_quant_scale: true logical_axis_rules: [ ['embed', []], @@ -42,8 +42,8 @@ logical_axis_rules: [ decoder_block: "llama4" mlp_activations: ["silu","linear"] -enable_dropout: False -logits_via_embedding: False +enable_dropout: false +logits_via_embedding: false tokenizer_type: "huggingface" tokenizer_path: "meta-llama/Llama-4-Maverick-17B-128E" @@ -57,15 +57,15 @@ vocab_size: 202048 normalization_layer_epsilon: 1e-05 rope_max_timescale: 500000 rope_type: "llama3.1" -rope_use_scale: False +rope_use_scale: false num_experts: 128 capacity_factor: -1.0 # TODO: this will be removed once we support dropless with megablox/ragged_dot shared_experts: 1 num_experts_per_tok: 1 -use_qk_norm: False +use_qk_norm: false nope_layer_interval: 4 # Every fourth layer should NOT use RoPE interleave_moe_layer_step: 2 # Every 2nd layer is MoE layer, and 1st layer is dense layer # TODO: delete the following variables once we add support for dropless with megablox/ragged_dot -sparse_matmul: False -megablox: False +sparse_matmul: false +megablox: false diff --git a/src/maxtext/experimental/rl/grpo.yml b/src/maxtext/experimental/rl/grpo.yml index d1ba5d22fd..ab2c122d5d 100644 --- a/src/maxtext/experimental/rl/grpo.yml +++ b/src/maxtext/experimental/rl/grpo.yml @@ -1,6 +1,6 @@ base_config: "base.yml" -use_grpo: True +use_grpo: true train_data_columns: 'prompt' learning_rate: 1.e-6 @@ -26,12 +26,12 @@ async_checkpointing: false # Pathways inference inference_devices_per_replica: 4 inference_replicas: 1 -use_pathways_reshard: True +use_pathways_reshard: true -return_log_prob: True +return_log_prob: true -add_bos: False -add_eos: False +add_bos: false +add_eos: false ### Splash attention block sizes # These values are tuned for small sequence lengths used in the grpo test script. @@ -43,7 +43,7 @@ sa_block_kv_dkv: 128 sa_block_kv_dkv_compute: 128 sa_block_q_dq: 128 sa_block_kv_dq: 128 -sa_use_fused_bwd_kernel: False +sa_use_fused_bwd_kernel: false sa_q_layout: "HEAD_DIM_MINOR" sa_k_layout: "HEAD_DIM_MINOR" sa_v_layout: "HEAD_DIM_MINOR" diff --git a/src/maxtext/experimental/rl/grpo_inference.yml b/src/maxtext/experimental/rl/grpo_inference.yml index de7f67ae33..cc64ce4ed1 100644 --- a/src/maxtext/experimental/rl/grpo_inference.yml +++ b/src/maxtext/experimental/rl/grpo_inference.yml @@ -1,8 +1,8 @@ -# This config is used for loading the inference model for GRPO. Note that base_config is set to "rl.yml" +# This config is used for loading the inference model for GRPO. Note that base_config is set to "rl.yml" # to inherit the necessary sharding rules for GRPO inference. base_config: "rl_mt_jt.yml" -use_grpo: True +use_grpo: true train_data_columns: 'prompt' attention: 'dot_product' @@ -20,10 +20,10 @@ decode_sampling_strategy: "weighted" decode_sampling_temperature: 0.9 async_checkpointing: false -return_log_prob: True +return_log_prob: true -add_bos: False -add_eos: False +add_bos: false +add_eos: false ### Splash attention block sizes # These values are tuned for small sequence lengths used in the grpo test script. @@ -35,7 +35,7 @@ sa_block_kv_dkv: 128 sa_block_kv_dkv_compute: 128 sa_block_q_dq: 128 sa_block_kv_dq: 128 -sa_use_fused_bwd_kernel: False +sa_use_fused_bwd_kernel: false sa_q_layout: "HEAD_DIM_MINOR" sa_k_layout: "HEAD_DIM_MINOR" sa_v_layout: "HEAD_DIM_MINOR" diff --git a/src/maxtext/experimental/rl/grpo_trainer_test.yml b/src/maxtext/experimental/rl/grpo_trainer_test.yml index eb30031756..094fdee999 100644 --- a/src/maxtext/experimental/rl/grpo_trainer_test.yml +++ b/src/maxtext/experimental/rl/grpo_trainer_test.yml @@ -8,7 +8,7 @@ max_prefill_predict_length: 16 dataset_type: "synthetic" dtype: "float32" matmul_precision: "high" -logits_dot_in_fp32: True +logits_dot_in_fp32: true prompt: "Hello world this is a test" init_weights_seed: 42 -enable_dropout: False +enable_dropout: false