fix(sampling_params): validate token-id input types via a shared helper#1365
fix(sampling_params): validate token-id input types via a shared helper#1365SuperMarioYL wants to merge 1 commit into
Conversation
StopSequence, AllowedTokenIds and InvalidTokenIds each duplicated the same "validate a list of ints fits a fixed c_int array, then copy" logic, and two of the three did it incorrectly: - AllowedTokenIds.initialize asserted isinstance(e, int) over self.ids, the zero-filled ctypes destination array, instead of the caller-supplied ids. That guard can never fail, so a non-int entry in allowed_token_ids slipped past the intended "all must be int" check and surfaced an opaque ctypes TypeError at the array assignment instead. - InvalidTokenIds.initialize performed no element-type check at all. Extract a single store_int_token_ids helper that validates the input list (length and element type) and copies it into the destination, and route all three initializers through it. Valid inputs round-trip identically; add regression tests covering the type and overflow guards.
There was a problem hiding this comment.
Code Review
This pull request introduces a helper function store_int_token_ids to centralize and deduplicate token ID validation and ctypes array copying, refactoring StopSequence, AllowedTokenIds, and InvalidTokenIds to use it. It also adds a comprehensive suite of unit tests. The review feedback suggests replacing assert statements with standard exceptions (ValueError and TypeError) to prevent validation from being bypassed when Python is run with optimization flags. Additionally, it recommends using type(e) is int instead of isinstance(e, int) to strictly reject boolean values, and updating the corresponding unit tests to reflect these changes.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| size = len(ids) | ||
| assert size <= max_length, f"Too many {name}: {size} > {max_length}." | ||
| assert all(isinstance(e, int) for e in ids), f"all {name} must be int." | ||
| dst[:size] = ids[:] | ||
| return size |
There was a problem hiding this comment.
Using assert statements for runtime validation of input arguments is discouraged because assertions can be globally disabled in Python when run with optimization flags (e.g., python -O). Instead, standard exceptions like ValueError and TypeError should be raised.
Additionally, isinstance(e, int) returns True for boolean values (e.g., True, False) because bool is a subclass of int in Python. To strictly validate that the token IDs are integers and not booleans, we should use type(e) is int.
| size = len(ids) | |
| assert size <= max_length, f"Too many {name}: {size} > {max_length}." | |
| assert all(isinstance(e, int) for e in ids), f"all {name} must be int." | |
| dst[:size] = ids[:] | |
| return size | |
| size = len(ids) | |
| if size > max_length: | |
| raise ValueError(f"Too many {name}: {size} > {max_length}.") | |
| if not all(type(e) is int for e in ids): | |
| raise TypeError(f"all {name} must be int.") | |
| dst[:size] = ids[:] | |
| return size |
References
- Do not use assert statements for runtime validation of arguments, as they can be compiled out when Python is run with optimization flags (-O).
| @pytest.mark.parametrize("bad_ids", [[1, 2, "3"], [1, 2.5], [None, 1]]) | ||
| def test_allowed_token_ids_rejects_non_int(bad_ids): | ||
| # A non-int entry must fail with the explicit "all must be int" guard, | ||
| # not slip past validation into an opaque ctypes TypeError. | ||
| allowed_ids = AllowedTokenIds() | ||
| with pytest.raises(AssertionError): | ||
| allowed_ids.initialize(bad_ids) |
There was a problem hiding this comment.
Update the test to expect TypeError instead of AssertionError following the validation improvement in store_int_token_ids. Also, add a boolean value (e.g., [1, True]) to the parameterized test cases to verify that booleans are correctly rejected.
| @pytest.mark.parametrize("bad_ids", [[1, 2, "3"], [1, 2.5], [None, 1]]) | |
| def test_allowed_token_ids_rejects_non_int(bad_ids): | |
| # A non-int entry must fail with the explicit "all must be int" guard, | |
| # not slip past validation into an opaque ctypes TypeError. | |
| allowed_ids = AllowedTokenIds() | |
| with pytest.raises(AssertionError): | |
| allowed_ids.initialize(bad_ids) | |
| @pytest.mark.parametrize("bad_ids", [[1, 2, "3"], [1, 2.5], [None, 1], [1, True]]) | |
| def test_allowed_token_ids_rejects_non_int(bad_ids): | |
| # A non-int entry must fail with the explicit "all must be int" guard, | |
| # not slip past validation into an opaque ctypes TypeError. | |
| allowed_ids = AllowedTokenIds() | |
| with pytest.raises(TypeError): | |
| allowed_ids.initialize(bad_ids) |
| def test_allowed_token_ids_rejects_too_many(): | ||
| allowed_ids = AllowedTokenIds() | ||
| with pytest.raises(AssertionError): | ||
| allowed_ids.initialize([1] * (ALLOWED_TOKEN_IDS_MAX_LENGTH + 1)) |
There was a problem hiding this comment.
Update the test to expect ValueError instead of AssertionError for length overflow validation.
| def test_allowed_token_ids_rejects_too_many(): | |
| allowed_ids = AllowedTokenIds() | |
| with pytest.raises(AssertionError): | |
| allowed_ids.initialize([1] * (ALLOWED_TOKEN_IDS_MAX_LENGTH + 1)) | |
| def test_allowed_token_ids_rejects_too_many(): | |
| allowed_ids = AllowedTokenIds() | |
| with pytest.raises(ValueError): | |
| allowed_ids.initialize([1] * (ALLOWED_TOKEN_IDS_MAX_LENGTH + 1)) |
| @pytest.mark.parametrize("bad_ids", [[4, "5"], [4, 5.0]]) | ||
| def test_invalid_token_ids_rejects_non_int(bad_ids): | ||
| invalid_ids = InvalidTokenIds() | ||
| with pytest.raises(AssertionError): | ||
| invalid_ids.initialize(bad_ids) |
There was a problem hiding this comment.
Update the test to expect TypeError instead of AssertionError and add a boolean test case (e.g., [4, False]).
| @pytest.mark.parametrize("bad_ids", [[4, "5"], [4, 5.0]]) | |
| def test_invalid_token_ids_rejects_non_int(bad_ids): | |
| invalid_ids = InvalidTokenIds() | |
| with pytest.raises(AssertionError): | |
| invalid_ids.initialize(bad_ids) | |
| @pytest.mark.parametrize("bad_ids", [[4, "5"], [4, 5.0], [4, False]]) | |
| def test_invalid_token_ids_rejects_non_int(bad_ids): | |
| invalid_ids = InvalidTokenIds() | |
| with pytest.raises(TypeError): | |
| invalid_ids.initialize(bad_ids) |
| def test_invalid_token_ids_rejects_too_many(): | ||
| invalid_ids = InvalidTokenIds() | ||
| with pytest.raises(AssertionError): | ||
| invalid_ids.initialize([1] * (INVALID_TOKEN_IDS_MAX_LENGTH + 1)) |
There was a problem hiding this comment.
Update the test to expect ValueError instead of AssertionError.
| def test_invalid_token_ids_rejects_too_many(): | |
| invalid_ids = InvalidTokenIds() | |
| with pytest.raises(AssertionError): | |
| invalid_ids.initialize([1] * (INVALID_TOKEN_IDS_MAX_LENGTH + 1)) | |
| def test_invalid_token_ids_rejects_too_many(): | |
| invalid_ids = InvalidTokenIds() | |
| with pytest.raises(ValueError): | |
| invalid_ids.initialize([1] * (INVALID_TOKEN_IDS_MAX_LENGTH + 1)) |
| def test_stop_sequence_rejects_non_int(): | ||
| seq = StopSequence() | ||
| with pytest.raises(AssertionError): | ||
| seq.initialize([1, "2"]) |
There was a problem hiding this comment.
Update the test to expect TypeError instead of AssertionError.
| def test_stop_sequence_rejects_non_int(): | |
| seq = StopSequence() | |
| with pytest.raises(AssertionError): | |
| seq.initialize([1, "2"]) | |
| def test_stop_sequence_rejects_non_int(): | |
| seq = StopSequence() | |
| with pytest.raises(TypeError): | |
| seq.initialize([1, "2"]) |
| def test_store_int_token_ids_rejects_overflow(): | ||
| import ctypes | ||
|
|
||
| buf = (ctypes.c_int * 2)() | ||
| with pytest.raises(AssertionError): | ||
| store_int_token_ids(buf, [1, 2, 3], 2, "test ids") |
There was a problem hiding this comment.
Update the test to expect ValueError instead of AssertionError.
| def test_store_int_token_ids_rejects_overflow(): | |
| import ctypes | |
| buf = (ctypes.c_int * 2)() | |
| with pytest.raises(AssertionError): | |
| store_int_token_ids(buf, [1, 2, 3], 2, "test ids") | |
| def test_store_int_token_ids_rejects_overflow(): | |
| import ctypes | |
| buf = (ctypes.c_int * 2)() | |
| with pytest.raises(ValueError): | |
| store_int_token_ids(buf, [1, 2, 3], 2, "test ids") |
Summary
StopSequence,AllowedTokenIds, andInvalidTokenIdseach duplicate the same logic — validate that a list of ints fits a fixed-sizec_intarray, then copy it in — and two of the three implement the validation incorrectly:AllowedTokenIds.initializeassertsisinstance(e, int)overself.ids, the zero-filled ctypes destination buffer (c_int * ALLOWED_TOKEN_IDS_MAX_LENGTH), instead of the caller-suppliedids. Every element of that buffer is already anint, so the guard can never fail. A non-int entry inallowed_token_idstherefore slips past the intended"all must be int"check and instead surfaces an opaqueTypeError: 'str' object cannot be interpreted as an integerat the array assignment. It also iterates 256 elements regardless of the actual input size.InvalidTokenIds.initializeperforms no element-type check at all.StopSequence.initializealready validates the input correctly — this is the reference behavior the other two should match.Change
Extract a single helper,
store_int_token_ids(dst, ids, max_length, name), that validates the input list (length + element type) with clear messages and copies it into the destination array, then route all three initializers through it.AllowedTokenIds.InvalidTokenIds.Valid inputs round-trip identically (
to_list()/to_string()unchanged); the only observable change is that invalid (non-int) input now fails fast with a clearAssertionErrorinstead of an opaque ctypesTypeError.Tests
Added
unit_tests/server/core/objs/test_token_ids_validation.pycovering, for all three structures, that valid ids are accepted and round-trip, that a non-int element is rejected, and that exceeding the capacity is rejected — plus direct tests of the helper. The non-int cases are red before this change (they raiseTypeError, notAssertionError) and green after.The new tests live in a separate file to avoid overlapping with the in-flight test changes in #1350.