Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions fast_llm/layers/language_model/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,13 @@ class LanguageModelHeadConfig(BlockConfig):
hint=FieldHint.architecture,
valid=skip_valid_if_none(check_field(Assert.gt, 0)),
)
# Opt-in switch (default False) for running the LM-head projection in float32.
# In the head code shown alongside this field, enabling it casts the input and a
# detached copy of the output weights to torch.float32 before the linear; the
# weight gradient is then accumulated manually and the input gradient is cast
# back to the input's original dtype.
# NOTE(review): the parity with vLLM's `bf16_last_layer_fp32` claimed in `desc`
# cannot be verified from this file — confirm against the vLLM setup in use.
fp32_lm_head: bool = Field(
default=False,
desc="Upcast input and weight to float32 before the lm_head linear. "
"Matches vLLM's bf16_last_layer_fp32 quantization so new_logprobs and old_logprobs "
"are computed at the same numerical precision, keeping the IS ratio near 1 at init.",
hint=FieldHint.feature,
)
prediction_heads: int = Field(
default=1,
desc="Prediction heads.",
Expand Down
34 changes: 28 additions & 6 deletions fast_llm/layers/language_model/head.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)
from fast_llm.layers.language_model.loss.config import LanguageModelLabelEntropyLossConfig
from fast_llm.layers.language_model.loss.loss import LanguageModelLoss
from fast_llm.tensor import TensorMeta
from fast_llm.tensor import TensorMeta, accumulate_gradient
from fast_llm.utils import Assert, safe_merge_dicts

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -252,9 +252,17 @@ def _logits_loss_forward_backward_partial(
split_index: int = 0,
return_logits: bool = False,
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
if self._config.fp32_lm_head:
input_dtype = input_.dtype
input_ = input_.to(torch.float32)
# detach → requires_grad=False → output_parallel_linear_backward skips weight grad
weight = self.output_weights.detach().to(torch.float32)
else:
weight = self.output_weights

logits, context = output_parallel_linear_forward(
input_=input_,
weight=self.output_weights,
weight=weight,
bias=None,
group=self._parallel_dim.group if self._vocab_parallel else None,
sequence_parallel=self._sequence_parallel and self._vocab_parallel,
Expand Down Expand Up @@ -285,12 +293,26 @@ def _logits_loss_forward_backward_partial(
if loss_value is not None:
losses_.append(loss_value.detach())

if grad is not None and self._config.final_logit_softcap is not None:
if not self.training or grad is None:
return sum(losses_) if losses_ else None, None

if self._config.final_logit_softcap is not None:
grad = _softcap_backward(grad, logits, self._config.final_logit_softcap)

return sum(losses_) if losses_ else None, (
output_parallel_linear_backward(grad, context) if self.training else None
)
input_grad = output_parallel_linear_backward(grad, context)
if self._config.fp32_lm_head:
# Weight grad was skipped because weight.requires_grad=False; accumulate manually.
# context: (input_, weight, bias, group, sequence_parallel, ...)
saved_input = context[0]
if context[4]: # sequence_parallel
from fast_llm.core.ops import gather_op

saved_input = gather_op(saved_input, context[3], dim=0)
grad_weight = grad.flatten(0, -2).t().mm(saved_input.flatten(0, -2))
accumulate_gradient(self.output_weights, grad_weight.to(self.output_weights.dtype))
input_grad = input_grad.to(input_dtype)

return sum(losses_) if losses_ else None, input_grad

def get_loss_definitions(self) -> list[LossDef]:
return [
Expand Down
Loading