diff --git a/tests/unit/rl_utils_test.py b/tests/unit/rl_utils_test.py
index 2fd04f93f6..98920ab71b 100644
--- a/tests/unit/rl_utils_test.py
+++ b/tests/unit/rl_utils_test.py
@@ -16,6 +16,7 @@
 import unittest
 
 import pytest
+from unittest import mock
 from types import SimpleNamespace
 
 evaluate_rl = pytest.importorskip(
@@ -330,5 +331,112 @@ def test_without_hash(self):
         self.assertIsNone(utils_rl.extract_hash_answer(""))
+
+class TestCheckAnswer(unittest.TestCase):
+    """Tests for utils_rl.check_answer."""
+
+    def setUp(self):
+        self.config = _make_config()
+
+    @pytest.mark.cpu_only
+    def test_exact_match(self):
+        """Test when the guess exactly matches the answer."""
+        scores = utils_rl.check_answer(
+            prompts=[""],
+            completions=["r42"],
+            answer=["42"],
+            tmvp_config=self.config,
+        )
+        self.assertEqual(scores, [self.config.reward_exact_format_match])
+
+    @pytest.mark.cpu_only
+    def test_whitespace_match(self):
+        """Test when the guess matches the answer after stripping whitespace."""
+        scores = utils_rl.check_answer(
+            prompts=[""],
+            completions=["r 42 "],
+            answer=["42"],
+            tmvp_config=self.config,
+        )
+        self.assertEqual(scores, [self.config.reward_white_space_format_match])
+
+    @pytest.mark.cpu_only
+    def test_ratio_high_match(self):
+        """Test when the guess is within the high ratio threshold (0.9 to 1.1)."""
+        scores = utils_rl.check_answer(
+            prompts=[""],
+            completions=["r10.5"],
+            answer=["10"],
+            tmvp_config=self.config,
+        )
+        self.assertEqual(scores, [self.config.reward_ratio_guess_to_answer_high])
+
+    @pytest.mark.cpu_only
+    def test_ratio_low_match(self):
+        """Test when the guess is within the low ratio threshold (0.8 to 1.2)."""
+        scores = utils_rl.check_answer(
+            prompts=[""],
+            completions=["r11.5"],
+            answer=["10"],
+            tmvp_config=self.config,
+        )
+        self.assertEqual(scores, [self.config.reward_ratio_guess_to_answer_low])
+
+    @pytest.mark.cpu_only
+    def test_incorrect_answer(self):
+        """Test when the guess is outside the acceptable ratio thresholds."""
+        scores = utils_rl.check_answer(
+            prompts=[""],
+            completions=["r15"],
+            answer=["10"],
+            tmvp_config=self.config,
+        )
+        self.assertEqual(scores, [self.config.penalty_incorrect_answer])
+
+    @pytest.mark.cpu_only
+    def test_no_format_match(self):
+        """Test when the completion does not match the expected format."""
+        scores = utils_rl.check_answer(
+            prompts=[""],
+            completions=["Just some random text without the right tags."],
+            answer=["42"],
+            tmvp_config=self.config,
+        )
+        self.assertEqual(scores, [0])
+
+    @pytest.mark.cpu_only
+    @mock.patch.object(utils_rl, "math_verify_func")
+    @mock.patch.object(utils_rl, "parse")
+    def test_dataset_specific_normalization(self, mock_parse, mock_math_verify):
+        """Test that specific datasets trigger normalize_final_answer."""
+        # Mock math_verify and parse to fail, so we only rely on exact string match
+        # after normalization.
+        mock_math_verify.return_value = [0.0]
+        mock_parse.side_effect = Exception("parse failed")
+
+        complex_guess = r"x=$\text{\textbf{\overline{\boxed{\frac12 \sqrt3}}}}$"
+        expected_normalized = r"\frac{1}{2} \sqrt{3}"
+
+        # These datasets should trigger normalization and match exactly
+        for dataset in ["DAPO-Math-17k", "OpenMathInstruct-2"]:
+            self.config.dataset_name = dataset
+            scores = utils_rl.check_answer(
+                prompts=[""],
+                completions=[f"r{complex_guess}"],
+                answer=[expected_normalized],
+                tmvp_config=self.config,
+            )
+            self.assertEqual(scores, [self.config.reward_exact_format_match])
+
+        # Other datasets should not normalize, leading to a mismatch/parse error
+        self.config.dataset_name = "gsm8k"
+        scores = utils_rl.check_answer(
+            prompts=[""],
+            completions=[f"r{complex_guess}"],
+            answer=[expected_normalized],
+            tmvp_config=self.config,
+        )
+        self.assertEqual(scores, [self.config.penalty_incorrect_format])
+
 
 
 if __name__ == "__main__":
     unittest.main()