diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py
index f531a1d46..a6057d67f 100644
--- a/fast_llm/layers/language_model/loss/config.py
+++ b/fast_llm/layers/language_model/loss/config.py
@@ -14,6 +14,7 @@
         LanguageModelDistillationLoss,
         LanguageModelLabelEntropyLoss,
     )
+    from fast_llm.layers.language_model.loss.grpo import LanguageModelGRPOLoss
    from fast_llm.layers.language_model.loss.loss import LanguageModelLoss
    from fast_llm.layers.language_model.loss.z_loss import LanguageModelZLoss
 
@@ -166,3 +167,18 @@ def loss_class(self) -> "type[LanguageModelZLoss]":
         from fast_llm.layers.language_model.loss.z_loss import LanguageModelZLoss
 
         return LanguageModelZLoss
+
+
+@config_class(dynamic_type={LanguageModelLossConfig: "grpo"})
+class LanguageModelGRPOLossConfig(LanguageModelLossConfig):
+
+    _abstract: typing.ClassVar[bool] = False
+
+    epsilon_low: float = Field(default=0.2, desc="Lower clipping bound for the policy probability ratio.")
+    epsilon_high: float = Field(default=0.2, desc="Upper clipping bound for the policy probability ratio.")
+
+    @property
+    def loss_class(self) -> "type[LanguageModelGRPOLoss]":
+        from fast_llm.layers.language_model.loss.grpo import LanguageModelGRPOLoss
+
+        return LanguageModelGRPOLoss
diff --git a/fast_llm/layers/language_model/loss/dpo.py b/fast_llm/layers/language_model/loss/dpo.py
index 15c4c788c..d6bca1c52 100644
--- a/fast_llm/layers/language_model/loss/dpo.py
+++ b/fast_llm/layers/language_model/loss/dpo.py
@@ -54,12 +54,12 @@ def dpo_loss(
     # TODO: Make more efficient.
     logits = logits * logits_scale_factor
 
-    policy_log_probabilities = _get_target_log_probabilities(logits, targets)
+    policy_log_probabilities = get_target_log_probabilities(logits, targets)
     policy_log_ratios = _get_target_log_probability_for_spans(
         policy_log_probabilities, chosen_spans
     ) - _get_target_log_probability_for_spans(policy_log_probabilities, rejected_spans)
 
-    reference_log_probabilities = _get_target_log_probabilities(reference_model_logits.float().detach(), targets)
+    reference_log_probabilities = get_target_log_probabilities(reference_model_logits.float().detach(), targets)
     reference_log_ratios = _get_target_log_probability_for_spans(
         reference_log_probabilities, chosen_spans
     ) - _get_target_log_probability_for_spans(reference_log_probabilities, rejected_spans)
@@ -68,14 +68,17 @@
     return -torch.nn.functional.logsigmoid(beta * (policy_log_ratios - reference_log_ratios)).mean()
 
 
-def _get_target_log_probabilities(logits: torch.Tensor, targets: torch.Tensor):
-    # Gather log probabilities corresponding to the target tokens
-    return torch.nn.functional.log_softmax(logits, dim=-1).gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)
-
-
 def _get_target_log_probability_for_spans(log_probabilities: torch.Tensor, spans: list[list[tuple[int, int]]]):
     return sum(
         log_probabilities[sample_index, begin:end].sum()
         for sample_index, sample_spans in enumerate(spans)
         for begin, end in sample_spans
     )
+
+
+@torch.compile
+def get_target_log_probabilities(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+    # Avoid negative (masked) labels.
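+    # Clamping keeps the gather() below in bounds; values gathered at masked positions
+    # are discarded downstream (span sums in DPO, the loss mask in GRPO).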
+    targets = targets * (targets >= 0)
+    # Gather log probabilities corresponding to the target tokens
+    return torch.nn.functional.log_softmax(logits, dim=-1).gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)
diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py
new file mode 100644
index 000000000..8b01a5248
--- /dev/null
+++ b/fast_llm/layers/language_model/loss/grpo.py
@@ -0,0 +1,60 @@
+import typing
+
+import torch
+
+from fast_llm.layers.language_model.loss.config import LanguageModelGRPOLossConfig, LanguageModelLossKwargs
+from fast_llm.layers.language_model.loss.dpo import get_target_log_probabilities
+from fast_llm.layers.language_model.loss.loss import LanguageModelLoss, loss_forward_backward
+
+
+class LanguageModelGRPOLoss[ConfigType: LanguageModelGRPOLossConfig](LanguageModelLoss[ConfigType]):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # TODO: Support vocab_parallel
+        if self._prediction_distance > 0:
+            raise NotImplementedError()
+        if self._vocab_parallel:
+            raise NotImplementedError()
+
+    def forward_backward(
+        self,
+        logits: "torch.Tensor",
+        kwargs: dict[str, typing.Any],
+        split_index: int = 0,
+    ) -> "tuple[torch.Tensor, torch.Tensor | None]":
+        return loss_forward_backward(
+            self._get_grad_output(kwargs),
+            grpo_loss,
+            logits,
+            self._get_loss_mask(kwargs, split_index),
+            self._get_labels(kwargs, split_index),
+            self._prepare_target(kwargs[LanguageModelLossKwargs.advantages], kwargs, split_index),
+            self._prepare_target(kwargs[LanguageModelLossKwargs.old_log_probabilities], kwargs, split_index),
+            self._config.epsilon_low,
+            self._config.epsilon_high,
+            self._logits_scale_factor,
+        )
+
+
+@torch.compile
+def grpo_loss(
+    logits: torch.Tensor,
+    loss_mask: "torch.Tensor | None",
+    labels: torch.Tensor,
+    advantages: torch.Tensor,
+    old_log_probabilities: torch.Tensor,
+    epsilon_low: float = 0.2,
+    epsilon_high: float = 0.2,
+    logits_scale_factor: float = 1.0,
+) -> torch.Tensor:
+    if logits_scale_factor != 1.0:
+        # TODO: Make more efficient.
+        logits = logits * logits_scale_factor
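+    # Importance ratio r = pi(label) / pi_old(label), computed in log space for stability.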
+    probability_ratio = torch.exp(get_target_log_probabilities(logits, labels) - old_log_probabilities)
+    loss = -torch.min(
+        probability_ratio * advantages,
+        torch.clamp(probability_ratio, 1 - epsilon_low, 1 + epsilon_high) * advantages,
+    )
+    if loss_mask is not None:
+        loss = loss * loss_mask
+    return loss.mean()
diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py
index 840e3846d..6471a516f 100644
--- a/tests/functional/test_functional.py
+++ b/tests/functional/test_functional.py
@@ -1,13 +1,10 @@
-import numpy as np
 import pytest
 import torch
 
 from fast_llm.functional.config import ActivationType, MLPRecomputeLevel
 from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped, torch_mlp_activation
 from fast_llm.functional.triton.sparse_copy import get_sparse_map
-from fast_llm.layers.language_model.loss.dpo import dpo_loss
 from fast_llm.utils import Assert
-from tests.utils.dataset import get_random_spans
 
 
 def _get_target_log_probability_for_spans(log_probabilities: torch.Tensor, spans: list[list[tuple[int, int]]]):
@@ -18,59 +15,6 @@ def _get_target_log_probability_for_spans(log_probabilities: torch.Tensor, spans: list[list[tuple[int, int]]]):
     return sum(
         log_probabilities[sample_index, begin:end].sum()
         for sample_index, sample_spans in enumerate(spans)
         for begin, end in sample_spans
     )
 
 
-def reference_dpo_loss(
-    logits: torch.Tensor,
-    targets: torch.Tensor,
-    reference_model_logits: torch.Tensor,
-    chosen_spans: torch.Tensor,
-    rejected_spans: torch.Tensor,
-    beta: float,
-) -> torch.Tensor:
-    # TODO: Too similar to the actual implementation.
-    policy_log_probs = (
-        torch.nn.functional.log_softmax(logits.float(), dim=-1).gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)
-    )
-    policy_chosen_logps = sum(
-        policy_log_probs[sample_index, begin:end].sum()
-        for sample_index, sample_spans in enumerate(chosen_spans)
-        for begin, end in sample_spans
-    )
-    policy_rejected_logps = sum(
-        policy_log_probs[sample_index, begin:end].sum()
-        for sample_index, sample_spans in enumerate(rejected_spans)
-        for begin, end in sample_spans
-    )
-    reference_log_probs = (
-        torch.nn.functional.log_softmax(reference_model_logits.float(), dim=-1)
-        .gather(dim=-1, index=targets.unsqueeze(-1))
-        .squeeze(-1)
-    )
-    reference_chosen_logps = sum(
-        reference_log_probs[sample_index, begin:end].sum()
-        for sample_index, sample_spans in enumerate(chosen_spans)
-        for begin, end in sample_spans
-    )
-    reference_rejected_logps = sum(
-        reference_log_probs[sample_index, begin:end].sum()
-        for sample_index, sample_spans in enumerate(rejected_spans)
-        for begin, end in sample_spans
-    )
-    pi_logratios = policy_chosen_logps - policy_rejected_logps
-    ref_logratios = reference_chosen_logps - reference_rejected_logps
-    return -torch.nn.functional.logsigmoid(beta * (pi_logratios - ref_logratios)).mean()
-
-
-def test_dpo_loss():
-    logits = torch.normal(0, 1, (10, 50, 100))
-    reference_model_logits = torch.normal(0, 1, (10, 50, 100))
-    targets = torch.randint(0, 100, (10, 50))
-    spans = get_random_spans(np.full(10, 50), 0, 10)
-
-    fastllm_loss = dpo_loss(logits, targets, reference_model_logits, spans[::2], spans[1::2])
-    reference_loss = reference_dpo_loss(logits, targets, reference_model_logits, spans[::2], spans[1::2], beta=1)
-    Assert.rms_close(fastllm_loss, reference_loss, 1e-5)
-
-
 @pytest.mark.parametrize("gated", [True, False])
 @pytest.mark.parametrize(
     "activation", [ActivationType.gelu, ActivationType.silu, ActivationType.relu, ActivationType.squared_relu]
diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py
index 1d08986f8..cb42a941d 100644
--- a/tests/layers/test_lm_head.py
+++ b/tests/layers/test_lm_head.py
@@ -9,8 +9,10 @@
 from fast_llm.layers.attention.config import AttentionKwargs
 from fast_llm.layers.language_model.config import LM_HEAD_LOSS_NAME, LanguageModelKwargs
 from fast_llm.layers.language_model.head import LanguageModelHead
+from fast_llm.layers.language_model.loss.config import LanguageModelLossKwargs
 from fast_llm.models.gpt.config import GPTModelConfig
 from fast_llm.utils import Assert
+from tests.layers.test_lm_losses import reference_grpo_loss
 from tests.utils.utils import get_base_model, get_stage
 
 SEQUENCE_LENGTH = 200
@@ -25,6 +27,7 @@ class LMHeadTestConfig:
     label_loss: bool | float = False
     distillation_loss: bool | float = False
     z_loss: bool | float = False
+    grpo_loss: bool | float = False
     logits_scale_factor: float = 1.0
     compute_dtype: DataType = DataType.float32
     full_precision_residual: bool = False
@@ -38,7 +41,10 @@ class LMHeadTestConfig:
     def actual_label_loss(self):
         return (
             True
-            if self.label_loss is False and self.distillation_loss is False and self.z_loss is False
+            if self.label_loss is False
+            and self.distillation_loss is False
+            and self.z_loss is False
+            and self.grpo_loss is False
             else self.label_loss
         )
 
@@ -61,6 +67,10 @@ def get_config(self) -> GPTModelConfig:
             losses["z_loss"] = {"type": "z_loss"}
             if isinstance(self.z_loss, float):
                 losses["z_loss"]["weight"] = self.z_loss
+        if self.grpo_loss is not False:
+            losses["grpo_loss"] = {"type": "grpo"}
+            if isinstance(self.grpo_loss, float):
+                losses["grpo_loss"]["weight"] = self.grpo_loss
         if losses:
             head_config["losses"] = losses
 
@@ -108,7 +118,7 @@ def get_inputs(self) -> tuple[torch.Tensor, dict[str, typing.Any]]:
         }
         if self.loss_masking:
             kwargs[LanguageModelKwargs.loss_mask] = torch.randint(0, 2, label_shape, dtype=torch.bool, device=device)
-        if self.actual_label_loss is not False:
+        if self.actual_label_loss is not False or self.grpo_loss is not False:
             labels = torch.randint(
                 0,
                 VOCAB_SIZE,
@@ -117,7 +127,7 @@ def get_inputs(self) -> tuple[torch.Tensor, dict[str, typing.Any]]:
                 device=device,
             )
             if LanguageModelKwargs.loss_mask in kwargs:
-                labels = torch.where(kwargs[LanguageModelKwargs.loss_mask], -100, labels)
+                labels = torch.where(kwargs[LanguageModelKwargs.loss_mask], labels, -100)
             kwargs[LanguageModelKwargs.labels] = labels
 
         if self.distillation_loss is not False:
@@ -127,6 +137,19 @@ def get_inputs(self) -> tuple[torch.Tensor, dict[str, typing.Any]]:
                 dtype=input_.dtype,
                 device=device,
             )
+
+        if self.grpo_loss is not False:
+            kwargs[LanguageModelLossKwargs.advantages] = torch.randn(
+                input_.shape[:-1],
+                dtype=torch.float32,
+                device=device,
+            )
+            kwargs[LanguageModelLossKwargs.old_log_probabilities] = torch.randn(
+                input_.shape[:-1],
+                dtype=torch.float32,
+                device=device,
+            )
+
         return input_, kwargs
 
     def get_reference_outputs(
@@ -152,7 +175,7 @@
 
         total_loss = 0
         losses = {}
-        if self.actual_label_loss is not False:
+        if self.actual_label_loss is not False or self.grpo_loss is not False:
             if self.sequence_first:
                 labels = kwargs[LanguageModelKwargs.labels][
                     head._prediction_distance : head._prediction_distance + logits.size(0)
@@ -161,6 +184,7 @@
                 labels = kwargs[LanguageModelKwargs.labels][
                     :, head._prediction_distance : head._prediction_distance + logits.size(1)
                 ]
-            label_loss = torch.nn.functional.cross_entropy(
-                logits.flatten(0, -2), labels.flatten(), reduction="none"
-            ).mean()
+            if self.actual_label_loss is not False:
+                label_loss = torch.nn.functional.cross_entropy(
+                    logits.flatten(0, -2), labels.flatten(), reduction="none"
+                ).mean()
@@ -187,6 +211,17 @@
             losses["z_loss"] = z_loss.detach()
             total_loss = total_loss + float(self.z_loss) * z_loss
 
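+        # Reference GRPO contribution; reference_grpo_loss (tests/layers/test_lm_losses.py) mirrors grpo_loss.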
+        if self.grpo_loss is not False:
+            grpo_loss = reference_grpo_loss(
+                logits,
+                kwargs.get(LanguageModelKwargs.loss_mask),
+                labels,
+                kwargs[LanguageModelLossKwargs.advantages],
+                kwargs[LanguageModelLossKwargs.old_log_probabilities],
+            )
+            losses["grpo_loss"] = grpo_loss.detach()
+            total_loss = total_loss + float(self.grpo_loss) * grpo_loss
+
         total_loss.backward()
 
         if len(losses) > 1:
@@ -227,6 +262,7 @@ def _add_configs(base_name: str, **kwargs):
 _add_configs("label_loss", label_loss=True)
 _add_configs("distillation_loss", distillation_loss=True)
 _add_configs("z_loss", z_loss=True)
+_add_configs("grpo_loss", grpo_loss=True)
 _add_configs("label_and_distillation_loss", label_loss=True, distillation_loss=True)
 _add_configs("label_and_z_loss_weighted", label_loss=True, z_loss=0.5)
 _add_configs("label_and_distillation_loss_zero_weight", label_loss=True, distillation_loss=0.0)
diff --git a/tests/layers/test_lm_losses.py b/tests/layers/test_lm_losses.py
new file mode 100644
index 000000000..c10e3110b
--- /dev/null
+++ b/tests/layers/test_lm_losses.py
@@ -0,0 +1,101 @@
+import numpy as np
+import pytest
+import torch
+
+from fast_llm.layers.language_model.loss.dpo import dpo_loss
+from fast_llm.layers.language_model.loss.grpo import grpo_loss
+from fast_llm.utils import Assert
+from tests.utils.dataset import get_random_spans
+
+VOCAB_SIZE = 100
+NUM_TOKENS = 200
+
+
+def reference_dpo_loss(
+    logits: torch.Tensor,
+    labels: torch.Tensor,
+    reference_model_logits: torch.Tensor,
+    chosen_spans: torch.Tensor,
+    rejected_spans: torch.Tensor,
+    beta: float,
+) -> torch.Tensor:
+    # TODO: Too similar to the actual implementation.
+    policy_log_probs = (
+        torch.nn.functional.log_softmax(logits.float(), dim=-1).gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
+    )
+    policy_chosen_logps = sum(
+        policy_log_probs[sample_index, begin:end].sum()
+        for sample_index, sample_spans in enumerate(chosen_spans)
+        for begin, end in sample_spans
+    )
+    policy_rejected_logps = sum(
+        policy_log_probs[sample_index, begin:end].sum()
+        for sample_index, sample_spans in enumerate(rejected_spans)
+        for begin, end in sample_spans
+    )
+    reference_log_probs = (
+        torch.nn.functional.log_softmax(reference_model_logits.float(), dim=-1)
+        .gather(dim=-1, index=labels.unsqueeze(-1))
+        .squeeze(-1)
+    )
+    reference_chosen_logps = sum(
+        reference_log_probs[sample_index, begin:end].sum()
+        for sample_index, sample_spans in enumerate(chosen_spans)
+        for begin, end in sample_spans
+    )
+    reference_rejected_logps = sum(
+        reference_log_probs[sample_index, begin:end].sum()
+        for sample_index, sample_spans in enumerate(rejected_spans)
+        for begin, end in sample_spans
+    )
+    pi_logratios = policy_chosen_logps - policy_rejected_logps
+    ref_logratios = reference_chosen_logps - reference_rejected_logps
+    return -torch.nn.functional.logsigmoid(beta * (pi_logratios - ref_logratios)).mean()
+
+
+@pytest.mark.skip(reason="DPO loss is broken")
+def test_dpo_loss():
+    logits = torch.normal(0, 1, (NUM_TOKENS, VOCAB_SIZE))
+    reference_model_logits = torch.normal(0, 1, (NUM_TOKENS, VOCAB_SIZE))
+    labels = torch.randint(0, VOCAB_SIZE, (NUM_TOKENS,))
+    spans = get_random_spans(np.full(10, 50), 0, 10)
+
+    fast_llm_loss = dpo_loss(logits, labels, reference_model_logits, spans[::2], spans[1::2])
+    reference_loss = reference_dpo_loss(logits, labels, reference_model_logits, spans[::2], spans[1::2], beta=1)
+    Assert.rms_close(fast_llm_loss, reference_loss, 1e-5)
+
+
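+# Reference for the clipped surrogate objective used by GRPO (PPO-style):
+#   loss = -mean(min(r * A, clip(r, 1 - eps_low, 1 + eps_high) * A)), with r = exp(log_p - old_log_p).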
+def reference_grpo_loss(
+    logits: torch.Tensor,
+    loss_mask: "torch.Tensor | None",
+    labels: torch.Tensor,
+    advantages: torch.Tensor,
+    old_log_probabilities: torch.Tensor,
+    epsilon_low: float = 0.2,
+    epsilon_high: float = 0.2,
+) -> torch.Tensor:
+    # Log probabilities. Masked labels are zeroed so gather() stays in bounds; those
+    # positions are dropped from the loss below.
+    if loss_mask is not None:
+        labels = labels * loss_mask
+    target_log_probabilities = (
+        torch.nn.functional.log_softmax(logits, dim=-1).gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
+    )
+    probability_ratio = torch.exp(target_log_probabilities - old_log_probabilities)
+    loss = -torch.min(
+        probability_ratio * advantages,
+        torch.clamp(probability_ratio, 1 - epsilon_low, 1 + epsilon_high) * advantages,
+    )
+    if loss_mask is not None:
+        loss = loss * loss_mask
+    return loss.mean()
+
+
+def test_grpo_loss():
+    logits = torch.normal(0, 1, (NUM_TOKENS, VOCAB_SIZE))
+    loss_mask = torch.randint(0, 2, (NUM_TOKENS,), dtype=torch.bool)
+    labels = torch.randint(0, VOCAB_SIZE, (NUM_TOKENS,))
+    advantages = torch.normal(0, 1, (NUM_TOKENS,))
+    old_log_probabilities = torch.normal(0, 1, (NUM_TOKENS,))
+    fast_llm_loss = grpo_loss(logits, loss_mask, labels, advantages, old_log_probabilities)
+    reference_loss = reference_grpo_loss(logits, loss_mask, labels, advantages, old_log_probabilities)
+    Assert.rms_close(fast_llm_loss, reference_loss, 1e-5)