From 1e655886323caaa4426d2d6953cdd630f203decf Mon Sep 17 00:00:00 2001 From: supermario_leo Date: Fri, 19 Jun 2026 21:07:32 +0800 Subject: [PATCH] fix(sampling_params): validate token-id input types via a shared helper StopSequence, AllowedTokenIds and InvalidTokenIds each duplicated the same "validate a list of ints fits a fixed c_int array, then copy" logic, and two of the three did it incorrectly: - AllowedTokenIds.initialize asserted isinstance(e, int) over self.ids, the zero-filled ctypes destination array, instead of the caller-supplied ids. That guard can never fail, so a non-int entry in allowed_token_ids slipped past the intended "all must be int" check and surfaced an opaque ctypes TypeError at the array assignment instead. - InvalidTokenIds.initialize performed no element-type check at all. Extract a single store_int_token_ids helper that validates the input list (length and element type) and copies it into the destination, and route all three initializers through it. Valid inputs round-trip identically; add regression tests covering the type and overflow guards. --- lightllm/server/core/objs/sampling_params.py | 39 ++++++--- .../core/objs/test_token_ids_validation.py | 79 +++++++++++++++++++ 2 files changed, 105 insertions(+), 13 deletions(-) create mode 100644 unit_tests/server/core/objs/test_token_ids_validation.py diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index c39559f5f6..de6f5fbcbd 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -20,6 +20,29 @@ INVALID_TOKEN_IDS_MAX_LENGTH = int(os.getenv("LIGHTLLM_INVALID_TOKEN_IDS_MAX_LENGTH", 10)) +def store_int_token_ids(dst, ids: List[int], max_length: int, name: str) -> int: + """Validate a caller-supplied list of token ids and copy it into a fixed-size ctypes ``c_int`` array. + + The type check runs against the input ``ids`` (not the zero-filled destination buffer), so a non-int + entry fails fast with a clear message instead of surfacing an opaque ctypes ``TypeError`` at the + assignment below. + + Args: + dst: destination ctypes array, declared as ``c_int * max_length``. + ids: caller-supplied token ids; every element must be an ``int``. + max_length: capacity of ``dst``. + name: field name used in the error messages. + + Returns: + The number of ids written, i.e. ``len(ids)``. + """ + size = len(ids) + assert size <= max_length, f"Too many {name}: {size} > {max_length}." + assert all(isinstance(e, int) for e in ids), f"all {name} must be int." + dst[:size] = ids[:] + return size + + class StopSequence(ctypes.Structure): _pack_ = 4 _fields_ = [ @@ -30,10 +53,7 @@ class StopSequence(ctypes.Structure): ] def initialize(self, sequence: List[int], sequence_str: Optional[str] = None): - self.size = len(sequence) - assert self.size <= STOP_SEQUENCE_MAX_LENGTH, "stop token length too long." - assert all(isinstance(e, int) for e in sequence), "all must be int" - self.sequence[: self.size] = sequence[:] + self.size = store_int_token_ids(self.sequence, sequence, STOP_SEQUENCE_MAX_LENGTH, "stop token ids") if sequence_str is not None: sequence_str_bytes = sequence_str.encode("utf-8") @@ -197,10 +217,7 @@ class AllowedTokenIds(ctypes.Structure): ] def initialize(self, ids: List[int]): - self.size = len(ids) - assert self.size <= ALLOWED_TOKEN_IDS_MAX_LENGTH, "Too many allowed token IDs." - assert all(isinstance(e, int) for e in self.ids), "all must be int" - self.ids[: self.size] = ids[:] + self.size = store_int_token_ids(self.ids, ids, ALLOWED_TOKEN_IDS_MAX_LENGTH, "allowed token ids") def to_list(self): return list(self.ids[: self.size]) @@ -214,11 +231,7 @@ class InvalidTokenIds(ctypes.Structure): ] def initialize(self, ids: List[int]): - self.size = len(ids) - assert ( - self.size <= INVALID_TOKEN_IDS_MAX_LENGTH - ), f"Too many invalid token IDs {self.size} > {INVALID_TOKEN_IDS_MAX_LENGTH}." - self.ids[: self.size] = ids[:] + self.size = store_int_token_ids(self.ids, ids, INVALID_TOKEN_IDS_MAX_LENGTH, "invalid token ids") return def to_list(self): diff --git a/unit_tests/server/core/objs/test_token_ids_validation.py b/unit_tests/server/core/objs/test_token_ids_validation.py new file mode 100644 index 0000000000..f669dbe088 --- /dev/null +++ b/unit_tests/server/core/objs/test_token_ids_validation.py @@ -0,0 +1,79 @@ +import pytest +from lightllm.server.core.objs.sampling_params import ( + StopSequence, + AllowedTokenIds, + InvalidTokenIds, + store_int_token_ids, + STOP_SEQUENCE_MAX_LENGTH, + ALLOWED_TOKEN_IDS_MAX_LENGTH, + INVALID_TOKEN_IDS_MAX_LENGTH, +) + + +def test_allowed_token_ids_accepts_valid_ints(): + allowed_ids = AllowedTokenIds() + allowed_ids.initialize([1, 2, 3]) + assert allowed_ids.size == 3 + assert allowed_ids.to_list() == [1, 2, 3] + + +@pytest.mark.parametrize("bad_ids", [[1, 2, "3"], [1, 2.5], [None, 1]]) +def test_allowed_token_ids_rejects_non_int(bad_ids): + # A non-int entry must fail with the explicit "all must be int" guard, + # not slip past validation into an opaque ctypes TypeError. + allowed_ids = AllowedTokenIds() + with pytest.raises(AssertionError): + allowed_ids.initialize(bad_ids) + + +def test_allowed_token_ids_rejects_too_many(): + allowed_ids = AllowedTokenIds() + with pytest.raises(AssertionError): + allowed_ids.initialize([1] * (ALLOWED_TOKEN_IDS_MAX_LENGTH + 1)) + + +def test_invalid_token_ids_accepts_valid_ints(): + invalid_ids = InvalidTokenIds() + invalid_ids.initialize([4, 5, 6]) + assert invalid_ids.size == 3 + assert invalid_ids.to_list() == [4, 5, 6] + + +@pytest.mark.parametrize("bad_ids", [[4, "5"], [4, 5.0]]) +def test_invalid_token_ids_rejects_non_int(bad_ids): + invalid_ids = InvalidTokenIds() + with pytest.raises(AssertionError): + invalid_ids.initialize(bad_ids) + + +def test_invalid_token_ids_rejects_too_many(): + invalid_ids = InvalidTokenIds() + with pytest.raises(AssertionError): + invalid_ids.initialize([1] * (INVALID_TOKEN_IDS_MAX_LENGTH + 1)) + + +def test_stop_sequence_rejects_non_int(): + seq = StopSequence() + with pytest.raises(AssertionError): + seq.initialize([1, "2"]) + + +def test_store_int_token_ids_returns_size_and_writes_buffer(): + import ctypes + + buf = (ctypes.c_int * 8)() + size = store_int_token_ids(buf, [7, 8, 9], 8, "test ids") + assert size == 3 + assert list(buf[:size]) == [7, 8, 9] + + +def test_store_int_token_ids_rejects_overflow(): + import ctypes + + buf = (ctypes.c_int * 2)() + with pytest.raises(AssertionError): + store_int_token_ids(buf, [1, 2, 3], 2, "test ids") + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])