-
Notifications
You must be signed in to change notification settings - Fork 334
fix(sampling_params): validate token-id input types via a shared helper #1365
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,79 @@ | ||||||||||||||||||||||||||||||
| import pytest | ||||||||||||||||||||||||||||||
| from lightllm.server.core.objs.sampling_params import ( | ||||||||||||||||||||||||||||||
| StopSequence, | ||||||||||||||||||||||||||||||
| AllowedTokenIds, | ||||||||||||||||||||||||||||||
| InvalidTokenIds, | ||||||||||||||||||||||||||||||
| store_int_token_ids, | ||||||||||||||||||||||||||||||
| STOP_SEQUENCE_MAX_LENGTH, | ||||||||||||||||||||||||||||||
| ALLOWED_TOKEN_IDS_MAX_LENGTH, | ||||||||||||||||||||||||||||||
| INVALID_TOKEN_IDS_MAX_LENGTH, | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def test_allowed_token_ids_accepts_valid_ints(): | ||||||||||||||||||||||||||||||
| allowed_ids = AllowedTokenIds() | ||||||||||||||||||||||||||||||
| allowed_ids.initialize([1, 2, 3]) | ||||||||||||||||||||||||||||||
| assert allowed_ids.size == 3 | ||||||||||||||||||||||||||||||
| assert allowed_ids.to_list() == [1, 2, 3] | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| @pytest.mark.parametrize("bad_ids", [[1, 2, "3"], [1, 2.5], [None, 1]]) | ||||||||||||||||||||||||||||||
| def test_allowed_token_ids_rejects_non_int(bad_ids): | ||||||||||||||||||||||||||||||
| # A non-int entry must fail with the explicit "all must be int" guard, | ||||||||||||||||||||||||||||||
| # not slip past validation into an opaque ctypes TypeError. | ||||||||||||||||||||||||||||||
| allowed_ids = AllowedTokenIds() | ||||||||||||||||||||||||||||||
| with pytest.raises(AssertionError): | ||||||||||||||||||||||||||||||
| allowed_ids.initialize(bad_ids) | ||||||||||||||||||||||||||||||
|
Comment on lines
+20
to
+26
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Update the test to expect
Suggested change
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def test_allowed_token_ids_rejects_too_many(): | ||||||||||||||||||||||||||||||
| allowed_ids = AllowedTokenIds() | ||||||||||||||||||||||||||||||
| with pytest.raises(AssertionError): | ||||||||||||||||||||||||||||||
| allowed_ids.initialize([1] * (ALLOWED_TOKEN_IDS_MAX_LENGTH + 1)) | ||||||||||||||||||||||||||||||
|
Comment on lines
+29
to
+32
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Update the test to expect
Suggested change
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def test_invalid_token_ids_accepts_valid_ints(): | ||||||||||||||||||||||||||||||
| invalid_ids = InvalidTokenIds() | ||||||||||||||||||||||||||||||
| invalid_ids.initialize([4, 5, 6]) | ||||||||||||||||||||||||||||||
| assert invalid_ids.size == 3 | ||||||||||||||||||||||||||||||
| assert invalid_ids.to_list() == [4, 5, 6] | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| @pytest.mark.parametrize("bad_ids", [[4, "5"], [4, 5.0]]) | ||||||||||||||||||||||||||||||
| def test_invalid_token_ids_rejects_non_int(bad_ids): | ||||||||||||||||||||||||||||||
| invalid_ids = InvalidTokenIds() | ||||||||||||||||||||||||||||||
| with pytest.raises(AssertionError): | ||||||||||||||||||||||||||||||
| invalid_ids.initialize(bad_ids) | ||||||||||||||||||||||||||||||
|
Comment on lines
+42
to
+46
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Update the test to expect
Suggested change
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def test_invalid_token_ids_rejects_too_many(): | ||||||||||||||||||||||||||||||
| invalid_ids = InvalidTokenIds() | ||||||||||||||||||||||||||||||
| with pytest.raises(AssertionError): | ||||||||||||||||||||||||||||||
| invalid_ids.initialize([1] * (INVALID_TOKEN_IDS_MAX_LENGTH + 1)) | ||||||||||||||||||||||||||||||
|
Comment on lines
+49
to
+52
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Update the test to expect
Suggested change
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def test_stop_sequence_rejects_non_int(): | ||||||||||||||||||||||||||||||
| seq = StopSequence() | ||||||||||||||||||||||||||||||
| with pytest.raises(AssertionError): | ||||||||||||||||||||||||||||||
| seq.initialize([1, "2"]) | ||||||||||||||||||||||||||||||
|
Comment on lines
+55
to
+58
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Update the test to expect
Suggested change
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def test_store_int_token_ids_returns_size_and_writes_buffer(): | ||||||||||||||||||||||||||||||
| import ctypes | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| buf = (ctypes.c_int * 8)() | ||||||||||||||||||||||||||||||
| size = store_int_token_ids(buf, [7, 8, 9], 8, "test ids") | ||||||||||||||||||||||||||||||
| assert size == 3 | ||||||||||||||||||||||||||||||
| assert list(buf[:size]) == [7, 8, 9] | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def test_store_int_token_ids_rejects_overflow(): | ||||||||||||||||||||||||||||||
| import ctypes | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| buf = (ctypes.c_int * 2)() | ||||||||||||||||||||||||||||||
| with pytest.raises(AssertionError): | ||||||||||||||||||||||||||||||
| store_int_token_ids(buf, [1, 2, 3], 2, "test ids") | ||||||||||||||||||||||||||||||
|
Comment on lines
+70
to
+75
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Update the test to expect
Suggested change
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| if __name__ == "__main__": | ||||||||||||||||||||||||||||||
| pytest.main([__file__, "-v"]) | ||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using
assertstatements for runtime validation of input arguments is discouraged because assertions can be globally disabled in Python when run with optimization flags (e.g.,python -O). Instead, standard exceptions likeValueErrorandTypeErrorshould be raised.Additionally,
isinstance(e, int)returnsTruefor boolean values (e.g.,True,False) becauseboolis a subclass ofintin Python. To strictly validate that the token IDs are integers and not booleans, we should usetype(e) is int.References