diff --git a/tests/unit/rl_utils_test.py b/tests/unit/rl_utils_test.py
index 2fd04f93f6..98920ab71b 100644
--- a/tests/unit/rl_utils_test.py
+++ b/tests/unit/rl_utils_test.py
@@ -16,6 +16,7 @@
import unittest
import pytest
+from unittest import mock
from types import SimpleNamespace
evaluate_rl = pytest.importorskip(
@@ -330,5 +331,112 @@ def test_without_hash(self):
self.assertIsNone(utils_rl.extract_hash_answer(""))
+class TestCheckAnswer(unittest.TestCase):
+ """Tests for utils_rl.check_answer."""
+
+ def setUp(self):
+ self.config = _make_config()
+
+ @pytest.mark.cpu_only
+ def test_exact_match(self):
+ """Test when the guess exactly matches the answer."""
+ scores = utils_rl.check_answer(
+ prompts=[""],
+ completions=["r42"],
+ answer=["42"],
+ tmvp_config=self.config,
+ )
+ self.assertEqual(scores, [self.config.reward_exact_format_match])
+
+ @pytest.mark.cpu_only
+ def test_whitespace_match(self):
+ """Test when the guess matches the answer after stripping whitespace."""
+ scores = utils_rl.check_answer(
+ prompts=[""],
+ completions=["r 42 "],
+ answer=["42"],
+ tmvp_config=self.config,
+ )
+ self.assertEqual(scores, [self.config.reward_white_space_format_match])
+
+ @pytest.mark.cpu_only
+ def test_ratio_high_match(self):
+ """Test when the guess is within the high ratio threshold (0.9 to 1.1)."""
+ scores = utils_rl.check_answer(
+ prompts=[""],
+ completions=["r10.5"],
+ answer=["10"],
+ tmvp_config=self.config,
+ )
+ self.assertEqual(scores, [self.config.reward_ratio_guess_to_answer_high])
+
+ @pytest.mark.cpu_only
+ def test_ratio_low_match(self):
+ """Test when the guess is within the low ratio threshold (0.8 to 1.2)."""
+ scores = utils_rl.check_answer(
+ prompts=[""],
+ completions=["r11.5"],
+ answer=["10"],
+ tmvp_config=self.config,
+ )
+ self.assertEqual(scores, [self.config.reward_ratio_guess_to_answer_low])
+
+ @pytest.mark.cpu_only
+ def test_incorrect_answer(self):
+ """Test when the guess is outside the acceptable ratio thresholds."""
+ scores = utils_rl.check_answer(
+ prompts=[""],
+ completions=["r15"],
+ answer=["10"],
+ tmvp_config=self.config,
+ )
+ self.assertEqual(scores, [self.config.penalty_incorrect_answer])
+
+ @pytest.mark.cpu_only
+ def test_no_format_match(self):
+ """Test when the completion does not match the expected format."""
+ scores = utils_rl.check_answer(
+ prompts=[""],
+ completions=["Just some random text without the right tags."],
+ answer=["42"],
+ tmvp_config=self.config,
+ )
+ self.assertEqual(scores, [0])
+
+ @pytest.mark.cpu_only
+ @mock.patch.object(utils_rl, "math_verify_func")
+ @mock.patch.object(utils_rl, "parse")
+ def test_dataset_specific_normalization(self, mock_parse, mock_math_verify):
+ """Test that specific datasets trigger normalize_final_answer."""
+        # Mock math_verify_func and parse to fail, so we only rely on exact
+        # string match after normalization.
+ mock_math_verify.return_value = [0.0]
+ mock_parse.side_effect = Exception("parse failed")
+
+ complex_guess = r"x=$\text{\textbf{\overline{\boxed{\frac12 \sqrt3}}}}$"
+ expected_normalized = r"\frac{1}{2} \sqrt{3}"
+
+ # These datasets should trigger normalization and match exactly
+ for dataset in ["DAPO-Math-17k", "OpenMathInstruct-2"]:
+ self.config.dataset_name = dataset
+ scores = utils_rl.check_answer(
+ prompts=[""],
+ completions=[f"r{complex_guess}"],
+ answer=[expected_normalized],
+ tmvp_config=self.config,
+ )
+ self.assertEqual(scores, [self.config.reward_exact_format_match])
+
+ # Other datasets should not normalize, leading to a mismatch/parse error
+ self.config.dataset_name = "gsm8k"
+ scores = utils_rl.check_answer(
+ prompts=[""],
+ completions=[f"r{complex_guess}"],
+ answer=[expected_normalized],
+ tmvp_config=self.config,
+ )
+ self.assertEqual(scores, [self.config.penalty_incorrect_format])
+
+
if __name__ == "__main__":
unittest.main()