Add fp32_lm_head flag for vLLM precision parity #526

Draft
jlamypoirier wants to merge 1 commit into main from jlp_fp32_lm_head

Conversation

@jlamypoirier
Collaborator

Summary

Adds an fp32_lm_head field on LanguageModelHeadConfig. When True, the LM head linear's input and weight are upcast to FP32 before the matmul, matching vLLM's bf16_last_layer_fp32 quantization. This lets the trainer compute log-probabilities at the same numerical precision as the actor's sampling, so the importance-sampling ratio starts near 1.0 instead of being artificially inflated by a trainer/actor precision mismatch.
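
To illustrate why the precision mismatch matters, here is a minimal pure-Python sketch. It is not code from this PR: the logit values are hypothetical, and rounding the final logits to BF16 is only a simplified stand-in for what a BF16 LM head matmul does. It compares the importance-sampling ratio when the trainer and actor agree on precision with the ratio when they do not.

```python
import math
import struct


def to_bf16(x: float) -> float:
    """Round a float to the nearest bfloat16 value (round-to-nearest-even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # bfloat16 keeps the top 16 bits of the float32 encoding.
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - lse for v in logits]


# Hypothetical logits for one position; the actor samples in FP32.
logits_fp32 = [10.1, 9.7, 3.3, -2.4]
logits_bf16 = [to_bf16(v) for v in logits_fp32]  # simplified BF16 trainer head
token = 1  # index of the sampled token

actor_logp = log_softmax(logits_fp32)[token]

# Trainer at the same precision as the actor: the ratio is exactly 1.
ratio_matched = math.exp(log_softmax(logits_fp32)[token] - actor_logp)

# Trainer with BF16-rounded logits: the ratio drifts away from 1
# even though the weights and inputs are identical.
ratio_mismatched = math.exp(log_softmax(logits_bf16)[token] - actor_logp)

print(ratio_matched, ratio_mismatched)
```

The drift in `ratio_mismatched` comes purely from rounding, not from any policy update, which is the effect the flag is meant to remove.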

The detached FP32 weight has requires_grad=False, which makes output_parallel_linear_backward skip the weight-grad path. The FSDP gradient contract is restored by computing grad_weight = grad.t() @ saved_input explicitly and accumulating into the original BF16 param's grad_buffer via accumulate_gradient.
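
The gradient handling described above can be sketched in plain PyTorch. This is an illustration under stated assumptions, not the Fast-LLM implementation: `fp32_lm_head_step` is a hypothetical helper, and adding into `weight.grad` stands in for accumulating into the param's `grad_buffer` via `accumulate_gradient`.

```python
import torch


def fp32_lm_head_step(hidden, weight, grad_logits):
    """One forward/backward pass of an LM head whose matmul runs in FP32.

    hidden: (tokens, hidden_size) BF16 activations.
    weight: (vocab_size, hidden_size) BF16 parameter.
    grad_logits: (tokens, vocab_size) FP32 gradient of the loss w.r.t. the logits.
    """
    # Upcast both operands; the detached FP32 copy has requires_grad=False.
    input_fp32 = hidden.detach().float().requires_grad_()
    weight_fp32 = weight.detach().float()
    logits = input_fp32 @ weight_fp32.t()

    # Autograd only reaches the input here, so the weight-grad path is skipped.
    logits.backward(grad_logits)

    # Recompute the weight gradient explicitly: grad.t() @ saved_input.
    grad_weight = grad_logits.t() @ input_fp32.detach()

    # Stand-in for accumulate_gradient: add into the original BF16 param's grad.
    if weight.grad is None:
        weight.grad = torch.zeros_like(weight)
    weight.grad += grad_weight.to(weight.dtype)
    return logits.detach(), input_fp32.grad, grad_weight


torch.manual_seed(0)
hidden = torch.randn(4, 8).bfloat16()
weight = torch.nn.Parameter(torch.randn(16, 8).bfloat16())
grad_logits = torch.randn(4, 16)

logits, grad_input, grad_weight = fp32_lm_head_step(hidden, weight, grad_logits)
print(logits.dtype, weight.grad.dtype)
```

In this sketch the logits stay in FP32 while the accumulated gradient is cast back to the parameter's BF16 dtype, so the optimizer sees the same gradient layout as on the disabled path.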

Off by default. With the flag disabled, the code path is byte-identical to before.

Test plan

  • pytest tests/layers/test_lm_head.py — passes

Originally part of #502.

When True, upcasts the LM head linear's input and weight to FP32 before
the matmul, matching vLLM's bf16_last_layer_fp32 quantization. This lets
the trainer compute log-probabilities at the same numerical precision as
the actor's sampling, so the importance-sampling ratio starts near 1.0
instead of being inflated by trainer/actor precision mismatch.

The detached FP32 weight has requires_grad=False, which makes
output_parallel_linear_backward skip the weight-grad path. The FSDP
gradient contract is restored by computing grad_weight explicitly and
accumulating into the original BF16 param's grad_buffer via
accumulate_gradient.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

jlamypoirier added a commit that referenced this pull request on May 28, 2026

Adds an `fp32_lm_head` field on `LanguageModelHeadConfig`. When `True`,
the LM head linear's input and weight are upcast to FP32 before the
matmul, matching vLLM's `bf16_last_layer_fp32` quantization. This lets
the trainer compute log-probabilities at the same numerical precision
as the actor's sampling, so the importance-sampling ratio starts near
1.0 instead of being artificially inflated by a trainer/actor precision
mismatch.

The detached FP32 weight has `requires_grad=False`, which makes
`output_parallel_linear_backward` skip the weight-grad path. The FSDP
gradient contract is restored by computing `grad_weight = grad.t() @ saved_input`
explicitly and accumulating into the original BF16 param's `grad_buffer`
via `accumulate_gradient`.

Off by default — disabled path is byte-identical to before.

Cherry-picked from #526 to unblock the precision-evaluation tool's
GSPO smoke test, which compares fp32_lm_head=true vs false.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>