Skip to content

[WIP] Add gemma4 drafter model support#2240

Draft
athitten wants to merge 2 commits into
mainfrom
athitten/gemma4_drafter_support
Draft

[WIP] Add gemma4 drafter model support#2240
athitten wants to merge 2 commits into
mainfrom
athitten/gemma4_drafter_support

Conversation

@athitten
Copy link
Copy Markdown
Contributor

What does this PR do ?

Add a one line overview of what this PR aims to accomplish.

Changelog

  • Add specific line by line info of high level changes in this PR.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?

If you haven't finished some of the above items you can still open "Draft" PR.

Additional Information

  • Related to # (issue)

Signed-off-by: Abhishree <abhishreetm@gmail.com>
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 15, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@athitten athitten changed the title Add gemma4 drafter model support [WIP] Add gemma4 drafter model support May 15, 2026
Apply fixes for joint base + drafter training:

* Drop ``use_cache=False`` override in ``composite.forward``. Without the
  ``DynamicCache``, HF's sliding-window mask path silently degrades
  (SDPA mask-skip can collapse sliding layers into plain causal attention),
  inflating the initial training loss. The YAML's
  ``text_config.use_cache: true`` now takes effect.
* Change drafter label shift from ``k + 1`` to ``k``. The VLM collate
  pre-shifts labels by 1 so ``labels[t] == input_ids[t + 1]``; the prior
  ``k + 1`` shift was training the drafter to predict ``input_ids[t + 2]``
  instead of ``input_ids[t + 1]``.
* Add hard asserts: ``cp_size == 1`` and ``torch_dtype == bfloat16`` in
  ``Gemma4WithDrafter.from_pretrained``.
* Add plan knobs: ``freeze_base_for_drafter``, ``share_embedding_with_base``
  (one-shot init copy; FSDP2-safe), ``base_activation_checkpointing``.
* Recipe: factor joint loss into ``FinetuneRecipeForVLM._maybe_add_drafter_loss``,
  gate log on ``is_remote_logging_step`` (was per-microbatch), and make
  validation drafter-aware so ``val_loss`` reflects drafter drift.
* Remove dead ``from_pretrained`` override in drafter wrapper.
* Drop redundant ``text_config.output_hidden_states`` from YAML; expand the
  ``use_cache: true`` comment to explain the real reason (sliding-window mask,
  not KV sharing).
* Add ``test_post_collate_semantic_alignment`` that pins the label-shift
  convention so a future regression to ``k + 1`` fails loudly. Refine
  ``test_drafter_loss_reaches_drafter_params`` to reflect that
  ``post_projection`` only sees gradient in multi-step chains.

Signed-off-by: Abhishree <abhishreetm@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant