Merged
Changes from all commits
85 commits
c335f6e
train with only layer distillation losses
oleksost Dec 16, 2025
e06a4b2
unscaled loss logging + training with distillation loss factor = 0
oleksost Dec 16, 2025
179ae25
make logging more explicit
oleksost Dec 17, 2025
af456f0
Merge remote-tracking branch 'origin/main' into train_only_layer_losses
oleksost Dec 17, 2025
9968aac
clean + tests
oleksost Dec 17, 2025
945c5a7
nvm
oleksost Dec 17, 2025
4b6e3d7
forward KL
oleksost Dec 19, 2025
c5fefa0
test forward kl
oleksost Dec 19, 2025
4119596
wip: report unscaled + kl loss
oleksost Dec 19, 2025
b55a0a4
loss config
oleksost Dec 22, 2025
097baeb
wip
oleksost Dec 22, 2025
d773d98
tests
oleksost Dec 22, 2025
35400c1
Merge remote-tracking branch 'origin/main' into train_only_layer_losses
oleksost Dec 22, 2025
282925c
test
oleksost Dec 22, 2025
0f73ea2
tests
oleksost Dec 22, 2025
04a0193
Merge branch 'main' into train_only_layer_losses
oleksost Dec 22, 2025
fa85c41
wip
oleksost Dec 22, 2025
feb416e
Merge branch 'train_only_layer_losses' of https://github.com/ServiceN…
oleksost Dec 22, 2025
31cfb84
wip
oleksost Dec 23, 2025
24fe67b
no grad if factor 0
oleksost Dec 23, 2025
00f6118
Merge remote-tracking branch 'origin/main' into train_only_layer_losses
oleksost Dec 23, 2025
0cadf98
Merge branch 'main' into train_only_layer_losses
oleksost Dec 23, 2025
0e562e9
addressed comments
oleksost Dec 23, 2025
2a474e2
Merge branch 'train_only_layer_losses' of https://github.com/ServiceN…
oleksost Dec 23, 2025
52c1c11
addressed comments
oleksost Dec 23, 2025
406d0a2
Removed Targets class
oleksost Dec 30, 2025
f25380a
fixes
oleksost Dec 30, 2025
8adb7dd
imports
oleksost Dec 30, 2025
1ce641d
polish naming
oleksost Jan 6, 2026
95f14af
addressing comments
oleksost Jan 8, 2026
5ad4c0c
explicit z_loss grads
oleksost Jan 8, 2026
0a66e14
removed z_loss as aux loss
oleksost Jan 8, 2026
f8f7041
move loss configs to the lm config
oleksost Jan 8, 2026
ab9c917
tests
oleksost Jan 8, 2026
89470dc
Merge branch 'main' into train_only_layer_losses
oleksost Jan 9, 2026
66f16d5
stuff
jlamypoirier Jan 9, 2026
80a4b93
Merge remote-tracking branch 'origin/main' into jlp_cpu
jlamypoirier Jan 9, 2026
f144b87
fixes
jlamypoirier Jan 10, 2026
6e54c93
comments
oleksost Jan 12, 2026
305244f
fixes
jlamypoirier Jan 12, 2026
f71319f
fix
jlamypoirier Jan 13, 2026
43c58bf
fix
jlamypoirier Jan 13, 2026
8137b8c
Merge remote-tracking branch 'origin/main' into train_only_layer_losses
jlamypoirier Jan 13, 2026
3c8f3c2
misc
jlamypoirier Jan 13, 2026
705c482
fix
jlamypoirier Jan 13, 2026
53a467f
Merge branch 'main' into jlp_cpu
jlamypoirier Jan 13, 2026
b156f4e
fix
jlamypoirier Jan 13, 2026
4fbc7a8
stuff
jlamypoirier Jan 13, 2026
99a73b5
fixes
jlamypoirier Jan 13, 2026
fb679d1
stuff
jlamypoirier Jan 16, 2026
764d636
Merge remote-tracking branch 'origin/main' into jlp_cpu
jlamypoirier Jan 16, 2026
7f96009
fix
jlamypoirier Jan 16, 2026
982f945
Merge remote-tracking branch 'origin/main' into jlp_entropy_loss
jlamypoirier Jan 16, 2026
3c8ce50
Merge branch 'main' into train_only_layer_losses
jlamypoirier Jan 16, 2026
63ac004
Merge remote-tracking branch 'origin/train_only_layer_losses' into jl…
jlamypoirier Jan 16, 2026
afc33f3
stuff
jlamypoirier Jan 16, 2026
d81f1f6
Merge branch 'main' into jlp_cpu
jlamypoirier Jan 16, 2026
f8dcce6
stuff
jlamypoirier Jan 16, 2026
f96c72f
stuff
jlamypoirier Jan 16, 2026
98ee4fb
Merge branch 'jlp_cpu' into jlp_entropy_loss
jlamypoirier Jan 16, 2026
2a4362f
fixes
jlamypoirier Jan 16, 2026
b464e4e
fixes
jlamypoirier Jan 16, 2026
ba40a40
fixes
jlamypoirier Jan 19, 2026
a2ff5fb
fixes
jlamypoirier Jan 19, 2026
44bad56
fixes
jlamypoirier Jan 19, 2026
e626a03
fix
jlamypoirier Jan 20, 2026
baa0944
GRPO loss
jlamypoirier Jan 20, 2026
966e151
GRPO loss
jlamypoirier Jan 20, 2026
7b24f94
simplify
jlamypoirier Jan 20, 2026
978be16
simplify
jlamypoirier Jan 20, 2026
9d25147
simplify
jlamypoirier Jan 20, 2026
58a3191
Loss class
jlamypoirier Jan 20, 2026
5f245a8
Loss class
jlamypoirier Jan 20, 2026
e10cf4d
stuff
jlamypoirier Jan 21, 2026
336560e
stuff
jlamypoirier Jan 21, 2026
89bda84
stuff
jlamypoirier Jan 21, 2026
58f1316
Merge branch 'jlp_entropy_loss' into jlp_grpo
jlamypoirier Jan 21, 2026
ed57346
fixes
jlamypoirier Jan 21, 2026
19e4d0f
Merge branch 'main' into jlp_cpu
oleksost Jan 22, 2026
0c57174
Merge branch 'jlp_cpu' into jlp_entropy_loss
jlamypoirier Jan 22, 2026
24a8d4d
Merge remote-tracking branch 'origin/main' into jlp_entropy_loss
jlamypoirier Jan 22, 2026
777406d
Merge branch 'jlp_entropy_loss' into jlp_grpo
jlamypoirier Jan 22, 2026
a8842b0
Merge remote-tracking branch 'origin/main' into jlp_grpo
jlamypoirier Jan 22, 2026
2794fd1
fix
jlamypoirier Jan 22, 2026
fa9b9d2
Merge branch 'jlp_pipeline_rl' into jlp_grpo
jlamypoirier Jan 22, 2026
16 changes: 16 additions & 0 deletions fast_llm/layers/language_model/loss/config.py
@@ -14,6 +14,7 @@
LanguageModelDistillationLoss,
LanguageModelLabelEntropyLoss,
)
from fast_llm.layers.language_model.loss.grpo import LanguageModelGRPOLoss
from fast_llm.layers.language_model.loss.loss import LanguageModelLoss
from fast_llm.layers.language_model.loss.z_loss import LanguageModelZLoss

@@ -166,3 +167,18 @@ def loss_class(self) -> "type[LanguageModelZLoss]":
from fast_llm.layers.language_model.loss.z_loss import LanguageModelZLoss

return LanguageModelZLoss


@config_class(dynamic_type={LanguageModelLossConfig: "grpo"})
class LanguageModelGRPOLossConfig(LanguageModelLossConfig):

_abstract: typing.ClassVar[bool] = False

epsilon_low: float = Field(default=0.2, desc="Lower clip parameter for ratio of log probs")
epsilon_high: float = Field(default=0.2, desc="Upper clip parameter for ratio of log probs")

@property
def loss_class(self) -> "type[LanguageModelGRPOLoss]":
from fast_llm.layers.language_model.loss.grpo import LanguageModelGRPOLoss

return LanguageModelGRPOLoss
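
A minimal configuration sketch (not part of the diff) of how the new "grpo" loss type can be selected through the dynamic type registered above, following the losses-dictionary pattern used in tests/layers/test_lm_head.py further down. The head_config wrapper and exact nesting are assumptions; only type, weight, epsilon_low and epsilon_high come from this PR.

# Hypothetical head config selecting the GRPO loss (sketch, not from the diff).
head_config = {
    "losses": {
        "grpo_loss": {
            "type": "grpo",       # dynamic_type registered in LanguageModelGRPOLossConfig
            "weight": 1.0,        # optional loss weight, set the same way in the test helper
            "epsilon_low": 0.2,   # lower clip bound for the probability ratio (default 0.2)
            "epsilon_high": 0.2,  # upper clip bound for the probability ratio (default 0.2)
        }
    }
}
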
17 changes: 10 additions & 7 deletions fast_llm/layers/language_model/loss/dpo.py
@@ -54,12 +54,12 @@ def dpo_loss(
# TODO: Make more efficient.
logits = logits * logits_scale_factor

policy_log_probabilities = _get_target_log_probabilities(logits, targets)
policy_log_probabilities = get_target_log_probabilities(logits, targets)
policy_log_ratios = _get_target_log_probability_for_spans(
policy_log_probabilities, chosen_spans
) - _get_target_log_probability_for_spans(policy_log_probabilities, rejected_spans)

reference_log_probabilities = _get_target_log_probabilities(reference_model_logits.float().detach(), targets)
reference_log_probabilities = get_target_log_probabilities(reference_model_logits.float().detach(), targets)
reference_log_ratios = _get_target_log_probability_for_spans(
reference_log_probabilities, chosen_spans
) - _get_target_log_probability_for_spans(reference_log_probabilities, rejected_spans)
@@ -68,14 +68,17 @@ def dpo_loss(
return -torch.nn.functional.logsigmoid(beta * (policy_log_ratios - reference_log_ratios)).mean()


def _get_target_log_probabilities(logits: torch.Tensor, targets: torch.Tensor):
# Gather log probabilities corresponding to the target tokens
return torch.nn.functional.log_softmax(logits, dim=-1).gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)


def _get_target_log_probability_for_spans(log_probabilities: torch.Tensor, spans: list[list[tuple[int, int]]]):
return sum(
log_probabilities[sample_index, begin:end].sum()
for sample_index, sample_spans in enumerate(spans)
for begin, end in sample_spans
)


@torch.compile
def get_target_log_probabilities(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
# Avoid negative (masked) labels.
targets = targets * (targets >= 0)
# Gather log probabilities corresponding to the target tokens
return torch.nn.functional.log_softmax(logits, dim=-1).gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)
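
The relocated get_target_log_probabilities helper (now shared with the GRPO loss and compiled with torch.compile) clamps negative, masked label ids to zero before the gather, so a -100 label no longer produces an invalid index; the gathered value at such a position is meaningless and is expected to be dropped by a downstream loss mask. A toy illustration, with made-up values that are not part of the PR:

import torch

# The masked position (label -100) is redirected to index 0 so the gather stays
# valid; its log probability must be discarded later via the loss mask.
logits = torch.randn(1, 3, 5)
targets = torch.tensor([[2, -100, 4]])
log_probabilities = get_target_log_probabilities(logits, targets)  # shape (1, 3)
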
60 changes: 60 additions & 0 deletions fast_llm/layers/language_model/loss/grpo.py
@@ -0,0 +1,60 @@
import typing

import torch

from fast_llm.layers.language_model.loss.config import LanguageModelGRPOLossConfig, LanguageModelLossKwargs
from fast_llm.layers.language_model.loss.dpo import get_target_log_probabilities
from fast_llm.layers.language_model.loss.loss import LanguageModelLoss, loss_forward_backward


class LanguageModelGRPOLoss[ConfigType: LanguageModelGRPOLossConfig](LanguageModelLoss[ConfigType]):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Support vocab_parallel
if self._prediction_distance > 0:
raise NotImplementedError()
if self._vocab_parallel:
raise NotImplementedError()

def forward_backward(
self,
logits: "torch.Tensor",
kwargs: dict[str, typing.Any],
split_index: int = 0,
) -> "tuple[torch.Tensor, torch.Tensor | None]":
return loss_forward_backward(
self._get_grad_output(kwargs),
grpo_loss,
logits,
self._get_loss_mask(kwargs, split_index),
self._get_labels(kwargs, split_index),
self._prepare_target(kwargs[LanguageModelLossKwargs.advantages], kwargs, split_index),
self._prepare_target(kwargs[LanguageModelLossKwargs.old_log_probabilities], kwargs, split_index),
self._config.epsilon_low,
self._config.epsilon_high,
self._logits_scale_factor,
)


@torch.compile
def grpo_loss(
logits: torch.Tensor,
loss_mask: "torch.Tensor | None",
labels: torch.Tensor,
advantages: torch.Tensor,
old_log_probabilities: torch.Tensor,
epsilon_low: float = 0.2,
epsilon_high: float = 0.2,
logits_scale_factor: float = 1.0,
) -> torch.Tensor:
if logits_scale_factor != 1.0:
# TODO: Make more efficient.
logits = logits * logits_scale_factor
probability_ratio = torch.exp(get_target_log_probabilities(logits, labels) - old_log_probabilities)
loss = -torch.min(
probability_ratio * advantages,
torch.clamp(probability_ratio, 1 - epsilon_low, 1 + epsilon_high) * advantages,
)
if loss_mask is not None:
loss = loss * loss_mask
return loss.mean()
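
grpo_loss implements the PPO-style clipped surrogate used by GRPO: with the probability ratio r = exp(log p_theta(label) - old_log_probability), the per-token objective is -min(r * advantage, clip(r, 1 - epsilon_low, 1 + epsilon_high) * advantage), optionally multiplied by the loss mask and then averaged. A minimal smoke-test sketch with arbitrary toy shapes (mirroring the test inputs below, where advantages and old log probabilities share the label shape):

import torch

# Toy call, not part of the diff: logits are (batch, sequence, vocab); labels,
# advantages and old_log_probabilities are (batch, sequence); loss_mask is optional.
logits = torch.randn(2, 8, 32, requires_grad=True)
labels = torch.randint(0, 32, (2, 8))
advantages = torch.randn(2, 8)
old_log_probabilities = torch.randn(2, 8)
loss = grpo_loss(logits, None, labels, advantages, old_log_probabilities)
loss.backward()  # gradients only flow through the current policy's logits
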
56 changes: 0 additions & 56 deletions tests/functional/test_functional.py
@@ -1,13 +1,10 @@
import numpy as np
import pytest
import torch

from fast_llm.functional.config import ActivationType, MLPRecomputeLevel
from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped, torch_mlp_activation
from fast_llm.functional.triton.sparse_copy import get_sparse_map
from fast_llm.layers.language_model.loss.dpo import dpo_loss
from fast_llm.utils import Assert
from tests.utils.dataset import get_random_spans


def _get_target_log_probability_for_spans(log_probabilities: torch.Tensor, spans: list[list[tuple[int, int]]]):
@@ -18,59 +15,6 @@ def _get_target_log_probability_for_spans(log_probabilities: torch.Tensor, spans
)


def reference_dpo_loss(
logits: torch.Tensor,
targets: torch.Tensor,
reference_model_logits: torch.Tensor,
chosen_spans: torch.Tensor,
rejected_spans: torch.Tensor,
beta: float,
) -> torch.Tensor:
# TODO: Too similar to the actual implementation.
policy_log_probs = (
torch.nn.functional.log_softmax(logits.float(), dim=-1).gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)
)
policy_chosen_logps = sum(
policy_log_probs[sample_index, begin:end].sum()
for sample_index, sample_spans in enumerate(chosen_spans)
for begin, end in sample_spans
)
policy_rejected_logps = sum(
policy_log_probs[sample_index, begin:end].sum()
for sample_index, sample_spans in enumerate(rejected_spans)
for begin, end in sample_spans
)
reference_log_probs = (
torch.nn.functional.log_softmax(reference_model_logits.float(), dim=-1)
.gather(dim=-1, index=targets.unsqueeze(-1))
.squeeze(-1)
)
reference_chosen_logps = sum(
reference_log_probs[sample_index, begin:end].sum()
for sample_index, sample_spans in enumerate(chosen_spans)
for begin, end in sample_spans
)
reference_rejected_logps = sum(
reference_log_probs[sample_index, begin:end].sum()
for sample_index, sample_spans in enumerate(rejected_spans)
for begin, end in sample_spans
)
pi_logratios = policy_chosen_logps - policy_rejected_logps
ref_logratios = reference_chosen_logps - reference_rejected_logps
return -torch.nn.functional.logsigmoid(beta * (pi_logratios - ref_logratios)).mean()


def test_dpo_loss():
logits = torch.normal(0, 1, (10, 50, 100))
reference_model_logits = torch.normal(0, 1, (10, 50, 100))
targets = torch.randint(0, 100, (10, 50))
spans = get_random_spans(np.full(10, 50), 0, 10)

fastllm_loss = dpo_loss(logits, targets, reference_model_logits, spans[::2], spans[1::2])
reference_loss = reference_dpo_loss(logits, targets, reference_model_logits, spans[::2], spans[1::2], beta=1)
Assert.rms_close(fastllm_loss, reference_loss, 1e-5)


@pytest.mark.parametrize("gated", [True, False])
@pytest.mark.parametrize(
"activation", [ActivationType.gelu, ActivationType.silu, ActivationType.relu, ActivationType.squared_relu]
44 changes: 40 additions & 4 deletions tests/layers/test_lm_head.py
@@ -9,8 +9,10 @@
from fast_llm.layers.attention.config import AttentionKwargs
from fast_llm.layers.language_model.config import LM_HEAD_LOSS_NAME, LanguageModelKwargs
from fast_llm.layers.language_model.head import LanguageModelHead
from fast_llm.layers.language_model.loss.config import LanguageModelLossKwargs
from fast_llm.models.gpt.config import GPTModelConfig
from fast_llm.utils import Assert
from tests.layers.test_lm_losses import reference_grpo_loss
from tests.utils.utils import get_base_model, get_stage

SEQUENCE_LENGTH = 200
@@ -25,6 +27,7 @@ class LMHeadTestConfig:
label_loss: bool | float = False
distillation_loss: bool | float = False
z_loss: bool | float = False
grpo_loss: bool | float = False
logits_scale_factor: float = 1.0
compute_dtype: DataType = DataType.float32
full_precision_residual: bool = False
@@ -38,7 +41,10 @@ class LMHeadTestConfig:
def actual_label_loss(self):
return (
True
if self.label_loss is False and self.distillation_loss is False and self.z_loss is False
if self.label_loss is False
and self.distillation_loss is False
and self.z_loss is False
and self.grpo_loss is False
else self.label_loss
)

@@ -61,6 +67,10 @@ def get_config(self) -> GPTModelConfig:
losses["z_loss"] = {"type": "z_loss"}
if isinstance(self.z_loss, float):
losses["z_loss"]["weight"] = self.z_loss
if self.grpo_loss is not False:
losses["grpo_loss"] = {"type": "grpo"}
if isinstance(self.grpo_loss, float):
losses["grpo_loss"]["weight"] = self.grpo_loss
if losses:
head_config["losses"] = losses

@@ -108,7 +118,7 @@ def get_inputs(self) -> tuple[torch.Tensor, dict[str, typing.Any]]:
}
if self.loss_masking:
kwargs[LanguageModelKwargs.loss_mask] = torch.randint(0, 2, label_shape, dtype=torch.bool, device=device)
if self.actual_label_loss is not False:
if self.actual_label_loss is not False or self.grpo_loss is not False:
labels = torch.randint(
0,
VOCAB_SIZE,
@@ -117,7 +127,7 @@ def get_inputs(self) -> tuple[torch.Tensor, dict[str, typing.Any]]:
device=device,
)
if LanguageModelKwargs.loss_mask in kwargs:
labels = torch.where(kwargs[LanguageModelKwargs.loss_mask], -100, labels)
labels = torch.where(kwargs[LanguageModelKwargs.loss_mask], labels, -100)
kwargs[LanguageModelKwargs.labels] = labels

if self.distillation_loss is not False:
@@ -127,6 +137,19 @@ def get_inputs(self) -> tuple[torch.Tensor, dict[str, typing.Any]]:
dtype=input_.dtype,
device=device,
)

if self.grpo_loss is not False:
kwargs[LanguageModelLossKwargs.advantages] = torch.randn(
input_.shape[:-1],
dtype=torch.float32,
device=device,
)
kwargs[LanguageModelLossKwargs.old_log_probabilities] = torch.randn(
input_.shape[:-1],
dtype=torch.float32,
device=device,
)

return input_, kwargs

def get_reference_outputs(
Expand All @@ -152,7 +175,7 @@ def get_reference_outputs(
total_loss = 0
losses = {}

if self.actual_label_loss is not False:
if self.actual_label_loss is not False or self.grpo_loss is not False:
if self.sequence_first:
labels = kwargs[LanguageModelKwargs.labels][
head._prediction_distance : head._prediction_distance + logits.size(0)
@@ -161,6 +184,7 @@
labels = kwargs[LanguageModelKwargs.labels][
:, head._prediction_distance : head._prediction_distance + logits.size(1)
]
if self.actual_label_loss is not False:
label_loss = torch.nn.functional.cross_entropy(
logits.flatten(0, -2), labels.flatten(), reduction="none"
).mean()
@@ -187,6 +211,17 @@ def get_reference_outputs(
losses["z_loss"] = z_loss.detach()
total_loss = total_loss + float(self.z_loss) * z_loss

if self.grpo_loss is not False:
grpo_loss = reference_grpo_loss(
logits,
kwargs.get(LanguageModelKwargs.loss_mask),
labels,
kwargs[LanguageModelLossKwargs.advantages],
kwargs[LanguageModelLossKwargs.old_log_probabilities],
)
losses["grpo_loss"] = grpo_loss.detach()
total_loss = total_loss + float(self.grpo_loss) * grpo_loss

total_loss.backward()

if len(losses) > 1:
@@ -227,6 +262,7 @@ def _add_configs(base_name: str, **kwargs):
_add_configs("label_loss", label_loss=True)
_add_configs("distillation_loss", distillation_loss=True)
_add_configs("z_loss", z_loss=True)
_add_configs("grpo_loss", grpo_loss=True)
_add_configs("label_and_distillation_loss", label_loss=True, distillation_loss=True)
_add_configs("label_and_z_loss_weighted", label_loss=True, z_loss=0.5)
_add_configs("label_and_distillation_loss_zero_weight", label_loss=True, distillation_loss=0.0)