Add Gemma 4 E2B / E4B (text) support to MaxText#3904

Open: gagika wants to merge 1 commit into main from agagik-gemma

Conversation

@gagika
Collaborator

@gagika gagika commented May 14, 2026

Description

Adds Gemma 4 small variants — E2B and E4B (text-only) — to MaxText.

These are the smaller members of the Gemma 4 family. They share the
broader Gemma 4 attention / norm structure but introduce two new features
that drive their parameter efficiency:

  • Per-Layer Embeddings (PLE). Each decoder layer consumes a per-layer
    slice of an extra embedding tensor injected by a new Gemma4SmallPLE
    block. Controlled by hidden_size_per_layer_input /
    vocab_size_per_layer_input.
  • KV sharing. The last num_kv_shared_layers decoder layers reuse
    K / V from the most recent non-shared layer of the same attention type
    (sliding↔sliding, full↔full). E2B additionally widens the MLP on those
    shared layers (use_double_wide_mlp: true) to compensate for the
    missing parameters.
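
As a rough illustration of the KV-sharing rule above, the donor for each shared layer can be derived from the layer-type tuple alone. This is a minimal sketch, not the code in gemma4_small.py: the function name `kv_donor_map` and the toy six-layer pattern are hypothetical and exist only for this example.

```python
def kv_donor_map(layer_types, num_kv_shared_layers):
    """Map each KV-shared layer index to the layer whose K/V it reuses.

    Hypothetical helper: the last `num_kv_shared_layers` layers are shared,
    and each one reuses K/V from the most recent non-shared layer that has
    the same attention type.
    """
    num_layers = len(layer_types)
    if not 0 <= num_kv_shared_layers <= num_layers:
        raise ValueError("num_kv_shared_layers must be in [0, num_layers]")
    first_shared = num_layers - num_kv_shared_layers
    donors = {}
    for shared_idx in range(first_shared, num_layers):
        # Walk backwards through the non-shared layers only.
        for donor_idx in range(first_shared - 1, -1, -1):
            if layer_types[donor_idx] == layer_types[shared_idx]:
                donors[shared_idx] = donor_idx
                break
        else:
            raise ValueError(
                f"no non-shared {layer_types[shared_idx]!r} layer before layer {shared_idx}"
            )
    return donors


# Toy pattern, not the real E2B / E4B layout.
toy_layer_types = ("sliding", "sliding", "full", "sliding", "sliding", "full")
print(kv_donor_map(toy_layer_types, num_kv_shared_layers=3))
# {3: 1, 4: 1, 5: 2}
```

With the toy pattern, both shared sliding layers point at layer 1 (the last non-shared sliding layer) and the shared full layer points at layer 2.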

Both features carry per-layer state that is not expressible inside
nn.scan, so a new GEMMA4_SMALL DecoderBlockType is added with its
own non-scanned execution path (Decoder._apply_gemma4_small_layers).
The model validator enforces scan_layers=False for these variants.
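
To make the per-layer state concrete, here is a toy NumPy sketch of an unrolled layer loop. It is not Decoder._apply_gemma4_small_layers: the helpers `toy_kv`, `toy_attend`, and `run_unrolled` are hypothetical, the attention is unmasked, and the per-layer inputs are assumed to be already projected to the model width. The only point is the control flow: each layer reads its own slice of the per-layer input, and a shared layer looks up K / V stored by an earlier layer.

```python
import numpy as np


def toy_kv(hidden, layer_idx):
    """Stand-in for a layer's K/V projections (not the real Gemma 4 math)."""
    scale = 1.0 / (layer_idx + 1)
    return hidden * scale, hidden * (1.0 - 0.5 * scale)


def toy_attend(hidden, k, v):
    """Plain softmax attention without masking, for illustration only."""
    scores = hidden @ k.T
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v


def run_unrolled(hidden, per_layer_inputs, donors):
    """hidden: [seq, d]; per_layer_inputs: [seq, num_layers, d];
    donors: {shared_layer_index: donor_layer_index}."""
    num_layers = per_layer_inputs.shape[1]
    kv_by_layer = {}  # per-layer state, keyed by layer index
    trace = []        # (layer, layer whose K/V it used)
    for i in range(num_layers):
        if i in donors:
            k, v = kv_by_layer[donors[i]]  # reuse the donor's K/V
            trace.append((i, donors[i]))
        else:
            k, v = toy_kv(hidden, i)       # compute and store own K/V
            kv_by_layer[i] = (k, v)
            trace.append((i, i))
        ple_slice = per_layer_inputs[:, i, :]  # this layer's PLE slice
        hidden = hidden + toy_attend(hidden, k, v) + ple_slice
    return hidden, trace


rng = np.random.default_rng(0)
hidden = rng.normal(size=(4, 8))
per_layer_inputs = rng.normal(size=(4, 4, 8)) * 0.1
out, trace = run_unrolled(hidden, per_layer_inputs, donors={2: 0, 3: 1})
print(trace)  # [(0, 0), (1, 1), (2, 0), (3, 1)]
```

The dictionary keyed by layer index and the layer-specific slice are the per-layer state referred to above; a plain Python loop handles both directly.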

What's included

  • New model file src/maxtext/models/gemma4_small.py (PLE + attention
    with optional KV sharing + decoder layer).
  • New configs configs/models/gemma4-e2b.yml and gemma4-e4b.yml.
  • HF round-trip: hf_model_configs.py, hf_shape.py,
    param_mapping.py updated to handle PLE params, KV-shared layers, and
    the (optional) double-wide MLP.
  • TFLOP/MFU accounting: calculate_gemma4_small_tflops_training_per_device.
  • Config plumbing: DecoderBlockType.GEMMA4_SMALL, four new
    Attention fields in configs/types.py, base.yml defaults, and
    validation that rejects scan_layers=true / use_multimodal=true for
    E2B / E4B.

Out of scope

  • Multimodal. E2B / E4B ship a vision tower in their HF configs, but
    MaxText support for the gemma4-small vision encoder (clipped linears
    in particular) is not in this PR. use_multimodal=true is rejected by
    the validator with a clear error.
  • Scanned layers. Per-layer KV sharing isn't expressible under
    nn.scan; rejected by the validator.
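
The two rejections above amount to a check along these lines. This is a self-contained sketch under stated assumptions: `ToyConfig` and `validate_gemma4_small` are hypothetical names, and the real validator and config types in MaxText may be structured differently.

```python
from dataclasses import dataclass


@dataclass
class ToyConfig:
    """Hypothetical stand-in for the relevant MaxText config fields."""
    decoder_block: str
    scan_layers: bool = False
    use_multimodal: bool = False


def validate_gemma4_small(cfg):
    """Reject settings that the E2B / E4B path does not support."""
    if cfg.decoder_block != "gemma4_small":
        return  # other block types are not affected by this check
    if cfg.scan_layers:
        raise ValueError(
            "gemma4_small requires scan_layers=False: per-layer KV sharing "
            "is not expressible under nn.scan."
        )
    if cfg.use_multimodal:
        raise ValueError(
            "gemma4_small is text-only in this PR: use_multimodal=True is "
            "not supported."
        )


validate_gemma4_small(ToyConfig(decoder_block="gemma4_small"))  # passes silently
```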

Tests

  • New unit tests:
    • tests/unit/gemma4_small_test.py — attention-pattern dispatch,
      layer-type tuples, KV donor/shared-layer mapping for both variants.
    • tests/unit/flop_calculation_test.py::test_calculate_gemma4_small_tflops_*:
      closed-form TFLOP accounting matching the layer/donor structure.
    • tests/unit/configs_test.py — E2B / E4B yml configs are loaded by
      the existing config-instantiation sweep.
  • End-to-end forward-pass logit checks:
    • tests/end_to_end/tpu/gemma4/e2b/{convert_gemma4,convert_gemma4_pt}.sh
    • tests/end_to_end/tpu/gemma4/e4b/{convert_gemma4,convert_gemma4_pt}.sh
    • Each converts the HF checkpoint with to_maxtext, then runs
      forward_pass_logit_checker against the HF model with
      --max_kl_div=0.03. This is the recommended smoke test after
      touching the model code, param map, or either YAML.
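
For readers unfamiliar with the threshold, the sketch below shows one way a per-token KL bound such as 0.03 can be evaluated on two logit arrays. It is not the forward_pass_logit_checker implementation: the function names are hypothetical, and the KL direction (reference relative to candidate) and the max-over-tokens reduction are assumptions made for this example.

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable log-softmax over the last (vocab) axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def max_kl_div(ref_logits, test_logits):
    """Largest per-token KL(ref || test) for logits of shape [tokens, vocab]."""
    ref_logp = log_softmax(np.asarray(ref_logits, dtype=np.float64))
    test_logp = log_softmax(np.asarray(test_logits, dtype=np.float64))
    per_token_kl = (np.exp(ref_logp) * (ref_logp - test_logp)).sum(axis=-1)
    return float(per_token_kl.max())


rng = np.random.default_rng(0)
ref = rng.normal(size=(16, 32))                      # toy "HF" logits
close = ref + rng.normal(scale=1e-3, size=ref.shape)  # toy "MaxText" logits
assert max_kl_div(ref, close) < 0.03
```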

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@github-actions

🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions

🤖 I'm sorry @gagika, but I was unable to process your request. Please see the logs for more details.

Collaborator

@aireenmei aireenmei left a comment

Thanks for the rapid implementation! Do you have test results from forward_pass_logit_checker? Also, did you add unit tests that compare the new modules (Gemma4SmallPLE, Gemma4SmallAttention, Gemma4SmallDecoderLayer) against torch? See https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/unit/gemma4_layers_test.py, which was added recently.

Comment thread src/maxtext/models/gemma4_small.py Outdated
Comment thread src/maxtext/models/gemma4_small.py Outdated
Collaborator

@shuningjin shuningjin left a comment

Thanks! I agree with @aireenmei that it would be good to reuse the existing rope/attention code and to add component-wise unit tests (potentially as a follow-up). Some minor comments below.

Comment thread src/maxtext/models/gemma4_small.py Outdated
Comment thread src/maxtext/models/gemma4_small.py Outdated
Comment thread src/maxtext/models/gemma4_small.py Outdated
Comment thread src/maxtext/checkpoint_conversion/utils/param_mapping.py Outdated
Comment thread src/maxtext/checkpoint_conversion/utils/param_mapping.py Outdated
Comment thread tests/end_to_end/tpu/gemma4/e2b/convert_gemma4_base.sh
Comment thread tests/end_to_end/tpu/gemma4/Run_Gemma4.md Outdated
@github-actions

🤖 Hi @shuningjin, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions

🤖 I'm sorry @shuningjin, but I was unable to process your request. Please see the logs for more details.

@gagika
Collaborator Author

gagika commented May 19, 2026

> Thanks for the rapid implementation! I wonder if you have test results from forward_pass_logit_checker? Also do you add some unit tests for comparison with torch on the new modules such as Gemma4SmallPLE, Gemma4SmallAttention, Gemma4SmallDecoderLayer? https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/unit/gemma4_layers_test.py This was added recently.

Added those unit tests, PTAL.

For the forward-pass logit test: yes, I ran it for both models,
and I ran it again after addressing the PR feedback:

https://paste.googleplex.com/5790253452492800
https://paste.googleplex.com/6349200758538240

@github-actions

🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions

🤖 I'm sorry @gagika, but I was unable to process your request. Please see the logs for more details.
