Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 26 additions & 13 deletions lightllm/server/core/objs/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,29 @@
INVALID_TOKEN_IDS_MAX_LENGTH = int(os.getenv("LIGHTLLM_INVALID_TOKEN_IDS_MAX_LENGTH", 10))


def store_int_token_ids(dst, ids: List[int], max_length: int, name: str) -> int:
    """Validate a caller-supplied list of token ids and copy it into a fixed-size ctypes ``c_int`` array.

    Both checks run against the input ``ids`` (not the zero-filled destination buffer) and
    complete before anything is written, so a rejected call leaves ``dst`` untouched and a
    non-int entry fails with a clear message instead of an opaque ctypes ``TypeError``.

    Args:
        dst: destination ctypes array, declared as ``c_int * max_length``.
        ids: caller-supplied token ids; every element must be an ``int``.
            Note that ``bool`` values pass this check because ``bool`` subclasses ``int``.
        max_length: capacity of ``dst``.
        name: field name used in the error messages.

    Returns:
        The number of ids written, i.e. ``len(ids)``.

    Raises:
        AssertionError: if ``ids`` has more than ``max_length`` entries, or if any entry
            is not an ``int``.
    """
    size = len(ids)
    # Raise explicitly instead of using ``assert``: assert statements are removed when
    # Python runs with ``-O``, which would silently disable this input validation.
    # AssertionError is kept as the exception type so existing callers are unaffected.
    if size > max_length:
        raise AssertionError(f"Too many {name}: {size} > {max_length}.")
    if not all(isinstance(e, int) for e in ids):
        raise AssertionError(f"all {name} must be int.")
    # ctypes slice assignment reads straight from the sequence; no defensive copy is needed.
    dst[:size] = ids
    return size
Comment on lines +39 to +43

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using assert statements for runtime validation of input arguments is discouraged because assertions can be globally disabled in Python when run with optimization flags (e.g., python -O). Instead, standard exceptions like ValueError and TypeError should be raised.

Additionally, isinstance(e, int) returns True for boolean values (e.g., True, False) because bool is a subclass of int in Python. To strictly validate that the token IDs are integers and not booleans, we should use type(e) is int.

Suggested change
size = len(ids)
assert size <= max_length, f"Too many {name}: {size} > {max_length}."
assert all(isinstance(e, int) for e in ids), f"all {name} must be int."
dst[:size] = ids[:]
return size
size = len(ids)
if size > max_length:
raise ValueError(f"Too many {name}: {size} > {max_length}.")
if not all(type(e) is int for e in ids):
raise TypeError(f"all {name} must be int.")
dst[:size] = ids[:]
return size
References
  1. Do not use assert statements for runtime validation of arguments, as they can be compiled out when Python is run with optimization flags (-O).



class StopSequence(ctypes.Structure):
_pack_ = 4
_fields_ = [
Expand All @@ -30,10 +53,7 @@ class StopSequence(ctypes.Structure):
]

def initialize(self, sequence: List[int], sequence_str: Optional[str] = None):
self.size = len(sequence)
assert self.size <= STOP_SEQUENCE_MAX_LENGTH, "stop token length too long."
assert all(isinstance(e, int) for e in sequence), "all must be int"
self.sequence[: self.size] = sequence[:]
self.size = store_int_token_ids(self.sequence, sequence, STOP_SEQUENCE_MAX_LENGTH, "stop token ids")

if sequence_str is not None:
sequence_str_bytes = sequence_str.encode("utf-8")
Expand Down Expand Up @@ -197,10 +217,7 @@ class AllowedTokenIds(ctypes.Structure):
]

def initialize(self, ids: List[int]):
    """Validate ``ids`` and copy them into this structure's ``ids`` buffer.

    Args:
        ids: allowed token ids; at most ``ALLOWED_TOKEN_IDS_MAX_LENGTH`` entries,
            each of which must be an ``int``.

    ``self.size`` is assigned only after validation succeeds, so a rejected call
    does not leave a length that disagrees with the buffer contents.
    """
    # All validation is delegated to the shared helper, which type-checks the
    # *input* list. The previous inline check iterated ``self.ids`` (the destination
    # buffer), so it could never reject a bad input.
    self.size = store_int_token_ids(self.ids, ids, ALLOWED_TOKEN_IDS_MAX_LENGTH, "allowed token ids")

def to_list(self):
    """Return the stored token ids (the first ``self.size`` entries of ``self.ids``) as a new list."""
    count = self.size
    stored = self.ids[:count]
    return list(stored)
Expand All @@ -214,11 +231,7 @@ class InvalidTokenIds(ctypes.Structure):
]

def initialize(self, ids: List[int]):
    """Validate ``ids`` and copy them into this structure's ``ids`` buffer.

    Args:
        ids: invalid token ids; at most ``INVALID_TOKEN_IDS_MAX_LENGTH`` entries,
            each of which must be an ``int``.

    ``self.size`` is assigned only after validation succeeds.
    """
    # The shared helper performs the length check, the per-element type check and
    # the copy. Repeating the length check and an unchecked copy here first would
    # let non-int entries reach the ctypes assignment before the type check runs.
    self.size = store_int_token_ids(self.ids, ids, INVALID_TOKEN_IDS_MAX_LENGTH, "invalid token ids")

def to_list(self):
Expand Down
79 changes: 79 additions & 0 deletions unit_tests/server/core/objs/test_token_ids_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import pytest
from lightllm.server.core.objs.sampling_params import (
StopSequence,
AllowedTokenIds,
InvalidTokenIds,
store_int_token_ids,
STOP_SEQUENCE_MAX_LENGTH,
ALLOWED_TOKEN_IDS_MAX_LENGTH,
INVALID_TOKEN_IDS_MAX_LENGTH,
)


def test_allowed_token_ids_accepts_valid_ints():
    """A list of plain ints is stored verbatim and reported with the right size."""
    expected = [1, 2, 3]
    holder = AllowedTokenIds()
    holder.initialize(expected)
    assert holder.size == 3
    assert holder.to_list() == expected


@pytest.mark.parametrize("bad_ids", [[1, 2, "3"], [1, 2.5], [None, 1]])
def test_allowed_token_ids_rejects_non_int(bad_ids):
    """Non-int entries are rejected by the explicit type guard.

    The guard must fire before the ctypes assignment, which would otherwise
    surface an opaque ctypes TypeError.
    """
    holder = AllowedTokenIds()
    with pytest.raises(AssertionError):
        holder.initialize(bad_ids)
Comment on lines +20 to +26

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Update the test to expect TypeError instead of AssertionError following the validation improvement in store_int_token_ids. Also, add a boolean value (e.g., [1, True]) to the parameterized test cases to verify that booleans are correctly rejected.

Suggested change
@pytest.mark.parametrize("bad_ids", [[1, 2, "3"], [1, 2.5], [None, 1]])
def test_allowed_token_ids_rejects_non_int(bad_ids):
# A non-int entry must fail with the explicit "all must be int" guard,
# not slip past validation into an opaque ctypes TypeError.
allowed_ids = AllowedTokenIds()
with pytest.raises(AssertionError):
allowed_ids.initialize(bad_ids)
@pytest.mark.parametrize("bad_ids", [[1, 2, "3"], [1, 2.5], [None, 1], [1, True]])
def test_allowed_token_ids_rejects_non_int(bad_ids):
# A non-int entry must fail with the explicit "all must be int" guard,
# not slip past validation into an opaque ctypes TypeError.
allowed_ids = AllowedTokenIds()
with pytest.raises(TypeError):
allowed_ids.initialize(bad_ids)



def test_allowed_token_ids_rejects_too_many():
    """One id more than the buffer capacity is rejected."""
    oversized = [1] * (ALLOWED_TOKEN_IDS_MAX_LENGTH + 1)
    holder = AllowedTokenIds()
    with pytest.raises(AssertionError):
        holder.initialize(oversized)
Comment on lines +29 to +32

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Update the test to expect ValueError instead of AssertionError for length overflow validation.

Suggested change
def test_allowed_token_ids_rejects_too_many():
allowed_ids = AllowedTokenIds()
with pytest.raises(AssertionError):
allowed_ids.initialize([1] * (ALLOWED_TOKEN_IDS_MAX_LENGTH + 1))
def test_allowed_token_ids_rejects_too_many():
allowed_ids = AllowedTokenIds()
with pytest.raises(ValueError):
allowed_ids.initialize([1] * (ALLOWED_TOKEN_IDS_MAX_LENGTH + 1))



def test_invalid_token_ids_accepts_valid_ints():
    """A list of plain ints is stored verbatim and reported with the right size."""
    expected = [4, 5, 6]
    holder = InvalidTokenIds()
    holder.initialize(expected)
    assert holder.size == 3
    assert holder.to_list() == expected


@pytest.mark.parametrize("bad_ids", [[4, "5"], [4, 5.0]])
def test_invalid_token_ids_rejects_non_int(bad_ids):
    """String and float entries are rejected by the type guard."""
    holder = InvalidTokenIds()
    with pytest.raises(AssertionError):
        holder.initialize(bad_ids)
Comment on lines +42 to +46

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Update the test to expect TypeError instead of AssertionError and add a boolean test case (e.g., [4, False]).

Suggested change
@pytest.mark.parametrize("bad_ids", [[4, "5"], [4, 5.0]])
def test_invalid_token_ids_rejects_non_int(bad_ids):
invalid_ids = InvalidTokenIds()
with pytest.raises(AssertionError):
invalid_ids.initialize(bad_ids)
@pytest.mark.parametrize("bad_ids", [[4, "5"], [4, 5.0], [4, False]])
def test_invalid_token_ids_rejects_non_int(bad_ids):
invalid_ids = InvalidTokenIds()
with pytest.raises(TypeError):
invalid_ids.initialize(bad_ids)



def test_invalid_token_ids_rejects_too_many():
    """One id more than the buffer capacity is rejected."""
    oversized = [1] * (INVALID_TOKEN_IDS_MAX_LENGTH + 1)
    holder = InvalidTokenIds()
    with pytest.raises(AssertionError):
        holder.initialize(oversized)
Comment on lines +49 to +52

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Update the test to expect ValueError instead of AssertionError.

Suggested change
def test_invalid_token_ids_rejects_too_many():
invalid_ids = InvalidTokenIds()
with pytest.raises(AssertionError):
invalid_ids.initialize([1] * (INVALID_TOKEN_IDS_MAX_LENGTH + 1))
def test_invalid_token_ids_rejects_too_many():
invalid_ids = InvalidTokenIds()
with pytest.raises(ValueError):
invalid_ids.initialize([1] * (INVALID_TOKEN_IDS_MAX_LENGTH + 1))



def test_stop_sequence_rejects_non_int():
    """A stop sequence containing a string entry is rejected."""
    mixed_sequence = [1, "2"]
    stop = StopSequence()
    with pytest.raises(AssertionError):
        stop.initialize(mixed_sequence)
Comment on lines +55 to +58

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Update the test to expect TypeError instead of AssertionError.

Suggested change
def test_stop_sequence_rejects_non_int():
seq = StopSequence()
with pytest.raises(AssertionError):
seq.initialize([1, "2"])
def test_stop_sequence_rejects_non_int():
seq = StopSequence()
with pytest.raises(TypeError):
seq.initialize([1, "2"])



def test_store_int_token_ids_returns_size_and_writes_buffer():
    """The helper returns the number of ids written and fills the buffer prefix."""
    import ctypes

    capacity = 8
    payload = [7, 8, 9]
    target = (ctypes.c_int * capacity)()
    written = store_int_token_ids(target, payload, capacity, "test ids")
    assert written == 3
    assert list(target[:written]) == payload


def test_store_int_token_ids_rejects_overflow():
    """Three ids do not fit into a two-slot buffer and must be rejected."""
    import ctypes

    capacity = 2
    target = (ctypes.c_int * capacity)()
    with pytest.raises(AssertionError):
        store_int_token_ids(target, [1, 2, 3], capacity, "test ids")
Comment on lines +70 to +75

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Update the test to expect ValueError instead of AssertionError.

Suggested change
def test_store_int_token_ids_rejects_overflow():
import ctypes
buf = (ctypes.c_int * 2)()
with pytest.raises(AssertionError):
store_int_token_ids(buf, [1, 2, 3], 2, "test ids")
def test_store_int_token_ids_rejects_overflow():
import ctypes
buf = (ctypes.c_int * 2)()
with pytest.raises(ValueError):
store_int_token_ids(buf, [1, 2, 3], 2, "test ids")



if __name__ == "__main__":
    # Allow running this test module directly; pytest is invoked on this file in verbose mode.
    pytest_args = [__file__, "-v"]
    pytest.main(pytest_args)
Loading