diff --git a/lightllm/server/function_call_parser.py b/lightllm/server/function_call_parser.py index dfcb2f8d9e..093c0c5681 100644 --- a/lightllm/server/function_call_parser.py +++ b/lightllm/server/function_call_parser.py @@ -1758,6 +1758,10 @@ def __init__(self): self.parameter_regex = re.compile( r"|(?=)|$)", re.DOTALL ) + self.parameter_stream_regex = re.compile( + r"]*)>(.*?)(|(?=)|$)", re.DOTALL + ) + self.function_start_regex = re.compile(r"]*)>", re.DOTALL) self._normal_text_buffer = "" def has_tool_call(self, text: str) -> bool: @@ -1782,8 +1786,7 @@ def _convert_param_value(self, value: str, param_name: str, param_config: Dict, if param_name not in param_config: return value - prop = param_config.get(param_name, {}) - param_type = str(prop.get("type", "string")).strip().lower() if isinstance(prop, dict) else "string" + param_type = self._get_qwen3_param_type(param_name, param_config) if param_type in ("string", "str", "enum"): return value @@ -1833,12 +1836,7 @@ def _parse_function_call(self, function_str: str, tools: List[Tool]) -> Optional except ValueError: continue param_name = match[:idx].strip() - param_value = match[idx + 1 :] - # Strip leading/trailing newlines from value - if param_value.startswith("\n"): - param_value = param_value[1:] - if param_value.endswith("\n"): - param_value = param_value[:-1] + param_value = self._strip_value_newlines(match[idx + 1 :]) param_dict[param_name] = self._convert_param_value(param_value, param_name, param_config, func_name) @@ -1848,47 +1846,120 @@ def _parse_function_call(self, function_str: str, tools: List[Tool]) -> Optional parameters=json.dumps(param_dict, ensure_ascii=False), ) - def _build_partial_arguments_json(self, func_name: str, partial_body: str, tools: List[Tool]) -> Optional[str]: - """Build the current argument JSON from a partial XML tool-call body.""" - param_matches = self.parameter_regex.findall(partial_body) - if not param_matches: - return None + def _get_qwen3_param_type(self, param_name: str, param_config: Dict) -> str: + prop = param_config.get(param_name, {}) + return str(prop.get("type", "string")).strip().lower() if isinstance(prop, dict) else "string" + + def _strip_value_newlines(self, value: str) -> str: + """Strip the single leading/trailing newline the Qwen3 template wraps each value in.""" + if value.startswith("\n"): + value = value[1:] + if value.endswith("\n"): + value = value[:-1] + return value + def _strip_partial_xml_suffix(self, value: str) -> str: + for token in ("", "", self.eot_token): + max_len = min(len(value), len(token) - 1) + for suffix_len in range(max_len, 0, -1): + if token.startswith(value[-suffix_len:]): + return value[:-suffix_len] + return value + + def _build_streaming_arguments_json( + self, + func_name: str, + partial_body: str, + tools: List[Tool], + close_object: bool = False, + ) -> Optional[str]: + """Build a monotonic JSON arguments prefix for XML tool-call streaming. + + The result is always a byte-exact prefix of json.dumps(final_arguments) so the + serving layer (api_openai.py) can reconcile the streamed args at stream end. + String values stream character-by-character (a string prefix stays a prefix); + non-string values are only emitted once their arrives, because a + partial number/array/bool is not guaranteed to be a prefix of its json.dumps form. + """ param_config = self._get_param_config(func_name, tools) - param_dict = {} - has_visible_value = False + parts = ["{"] + has_param = False - for match in param_matches: - try: - idx = match.index(">") - except ValueError: + for match in self.parameter_stream_regex.finditer(partial_body): + param_name = match.group(1).strip() + if not param_name: continue - param_name = match[:idx].strip() - param_value = match[idx + 1 :] - if param_value.startswith("\n"): - param_value = param_value[1:] - if param_value.endswith("\n"): - param_value = param_value[:-1] - - if param_value.strip(): - has_visible_value = True - elif ( - f"" in partial_body - and f"{param_value}" in partial_body - ): - # Closed empty-string parameter. We can safely emit it. - has_visible_value = True + # The value is complete only when an explicit closed it, or a + # sibling follows. Otherwise it is still streaming. + # (We can't key off match.end()==len: `$` matches before a trailing newline, + # and the template wraps every value in one, which would look "complete".) + rest = partial_body[match.end() :] + value_open = ( + match.group(3) != "" + and not rest.startswith("") + ) + + if has_param: + parts.append(", ") + parts.append(json.dumps(param_name, ensure_ascii=False)) + parts.append(": ") + has_param = True + + param_type = self._get_qwen3_param_type(param_name, param_config) + is_string = param_type in ("string", "str", "enum") + + if value_open: + # In-progress (and therefore last) parameter. + if is_string: + value = self._strip_value_newlines(self._strip_partial_xml_suffix(match.group(2))) + # Drop the closing quote so the stream stays an extendable prefix. + parts.append(json.dumps(value, ensure_ascii=False)[:-1]) + # Non-string values cannot be emitted as a safe partial prefix, so stop + # after the key and wait for the value to close. + return "".join(parts) + + value = self._strip_value_newlines(match.group(2)) + if is_string: + parts.append(json.dumps(value, ensure_ascii=False)) else: - # Parameter tag is present but its value has not started streaming yet. - continue + converted = self._convert_param_value(value, param_name, param_config, func_name) + parts.append(json.dumps(converted, ensure_ascii=False)) - param_dict[param_name] = self._convert_param_value(param_value, param_name, param_config, func_name) + if not has_param: + return "{}" if close_object else None - if not param_dict and not has_visible_value: - return None + if close_object: + parts.append("}") - return json.dumps(param_dict, ensure_ascii=False) + return "".join(parts) + + def _ensure_qwen3_stream_state(self, tool_index: int) -> None: + while len(self.prev_tool_call_arr) <= tool_index: + self.prev_tool_call_arr.append({}) + while len(self.streamed_args_for_tool) <= tool_index: + self.streamed_args_for_tool.append("") + + def _append_qwen3_arguments_delta(self, calls: List[ToolCallItem], tool_index: int, current_args_json: str) -> None: + sent_args = self.streamed_args_for_tool[tool_index] + if not current_args_json.startswith(sent_args): + logger.warning( + "Qwen3-Coder streaming arguments are not monotonic for tool index %s; skip delta.", + tool_index, + ) + return + + argument_diff = current_args_json[len(sent_args) :] + if argument_diff: + calls.append( + ToolCallItem( + tool_index=tool_index, + name=None, + parameters=argument_diff, + ) + ) + self.streamed_args_for_tool[tool_index] += argument_diff def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: idx = text.find(self.bot_token) @@ -1941,23 +2012,100 @@ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> Streami self._buffer = current_text[tool_call_start:] current_text = self._buffer + if self.current_tool_id == -1: + self.current_tool_id = 0 + + self._ensure_qwen3_stream_state(self.current_tool_id) + + function_match = self.function_start_regex.search(current_text) + if not function_match: + return StreamingParseResult(normal_text=normal_text, calls=calls) + + func_name = function_match.group(1).strip() eot_pos = current_text.find(self.eot_token) + func_defined = func_name in self._tool_indices + + # Undefined function whose block has not finished yet: wait for more text + # (the block may also contain a valid function we shouldn't drop). + if not func_defined and eot_pos == -1: + return StreamingParseResult(normal_text=normal_text, calls=calls) + + if func_defined: + if not self.current_tool_name_sent: + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=func_name, + parameters="", + ) + ) + self.current_tool_name_sent = True + self.prev_tool_call_arr[self.current_tool_id] = { + "name": func_name, + "arguments": {}, + } + + # The function body is complete once we hit either or the + # enclosing ; treating eot as an implicit close lets us emit + # the closing '}' inside the same args delta (so the serving layer's + # stop-time reconciliation sees one delta, not a separate trailing '}'). + function_close_pos = current_text.find("", function_match.end()) + if function_close_pos != -1 and (eot_pos == -1 or function_close_pos < eot_pos): + partial_end = function_close_pos + close_object = True + elif eot_pos != -1: + partial_end = eot_pos + close_object = True + else: + partial_end = len(current_text) + close_object = False + partial_body = current_text[function_match.end() : partial_end] + current_args_json = self._build_streaming_arguments_json( + func_name, + partial_body, + tools, + close_object=close_object, + ) + if current_args_json: + self._append_qwen3_arguments_delta(calls, self.current_tool_id, current_args_json) + if eot_pos == -1: return StreamingParseResult(normal_text=normal_text, calls=calls) complete_block = current_text[: eot_pos + len(self.eot_token)] func_matches = self.function_regex.findall(complete_block) - if self.current_tool_id == -1: - self.current_tool_id = 0 - + # Flush every completed function in the block. _parse_function_call returns + # None for undefined ones, so they are skipped without advancing the index. for match in func_matches: func_str = match[0] if match[0] else match[1] item = self._parse_function_call(func_str, tools) - if item: - item.tool_index = self.current_tool_id - calls.append(item) - self.current_tool_id += 1 + if not item: + continue + completed_tool_id = self.current_tool_id + self._ensure_qwen3_stream_state(completed_tool_id) + if not self.current_tool_name_sent: + calls.append( + ToolCallItem( + tool_index=completed_tool_id, + name=item.name, + parameters="", + ) + ) + self.current_tool_name_sent = True + try: + parsed_args = json.loads(item.parameters) + except json.JSONDecodeError: + parsed_args = {} + self.prev_tool_call_arr[completed_tool_id] = { + "name": item.name, + "arguments": parsed_args, + } + sent_args = self.streamed_args_for_tool[completed_tool_id] + if item.parameters.startswith(sent_args): + self._append_qwen3_arguments_delta(calls, completed_tool_id, item.parameters) + self.current_tool_id += 1 + self.current_tool_name_sent = False self._buffer = current_text[eot_pos + len(self.eot_token) :].lstrip() diff --git a/unit_tests/server/test_qwen3_coder_stream_fc.py b/unit_tests/server/test_qwen3_coder_stream_fc.py new file mode 100644 index 0000000000..f72a642b11 --- /dev/null +++ b/unit_tests/server/test_qwen3_coder_stream_fc.py @@ -0,0 +1,194 @@ +"""Unit tests for Qwen3-Coder XML streaming tool-call parsing. + +These drive ``Qwen3CoderDetector.parse_streaming_increment`` directly (no server), +reassembling the per-tool argument string exactly the way ``api_openai.py`` does for +streamed responses, including the ``finish_reason == "stop"`` reconciliation. The key +invariant under test: the reassembled arguments are always valid JSON equal to what the +one-shot ``detect_and_parse`` would produce, for every chunk boundary. +""" + +import json +import pytest + +from lightllm.server.api_models import Function, Tool +from lightllm.server.function_call_parser import Qwen3CoderDetector + +CHUNK_SIZES = [1, 2, 3, 5, 13, 10_000] + + +def _tool(name, properties): + return Tool( + type="function", + function=Function(name=name, description="", parameters={"type": "object", "properties": properties}), + ) + + +def _stream_and_reassemble(text, tools, chunk): + """Feed ``text`` to the detector in fixed-size chunks and rebuild the client-visible + tool calls the way api_openai.py stream_results does (stop-rewrite on the last chunk).""" + det = Qwen3CoderDetector() + chunks = [text[i : i + chunk] for i in range(0, len(text), chunk)] + per_tool = {} + for ci, piece in enumerate(chunks): + result = det.parse_streaming_increment(piece, tools) + is_last = ci == len(chunks) - 1 + for call in result.calls: + ti = call.tool_index + per_tool.setdefault(ti, {"name": None, "args": ""}) + if call.name is not None: + per_tool[ti]["name"] = call.name + params = call.parameters + if is_last and params: + # Mirror api_openai.py:559-575 (REPLACE semantics). + latest_delta_len = len(params) + expected = json.dumps(det.prev_tool_call_arr[ti].get("arguments", {}), ensure_ascii=False) + actual = det.streamed_args_for_tool[ti] + if latest_delta_len > 0: + actual = actual[:-latest_delta_len] + params = expected.replace(actual, "", 1) + if params: + per_tool[ti]["args"] += params + return det, per_tool + + +def _assert_tool_calls(text, tools, expected, chunk_sizes=CHUNK_SIZES): + """expected: {tool_index: (name, args_dict)}.""" + for chunk in chunk_sizes: + det, per_tool = _stream_and_reassemble(text, tools, chunk) + assert len(per_tool) == len(expected), f"chunk={chunk}: tool count {len(per_tool)} != {len(expected)}" + for ti, (name, args) in expected.items(): + got = per_tool[ti] + assert got["name"] == name, f"chunk={chunk}: tool {ti} name {got['name']!r} != {name!r}" + parsed = json.loads(got["args"]) # must be valid JSON + assert parsed == args, f"chunk={chunk}: tool {ti} args {parsed!r} != {args!r}" + + +@pytest.mark.parametrize("chunk", CHUNK_SIZES) +def test_single_string_param(chunk): + text = ( + "\n\n\n" + "San Francisco\n\n\n" + ) + _assert_tool_calls( + text, + [_tool("get_weather", {"location": {"type": "string"}})], + {0: ("get_weather", {"location": "San Francisco"})}, + [chunk], + ) + + +def test_array_param_compact_spacing(): + # Regression: a non-string value whose raw text ("[1,2]") differs from json.dumps + # ("[1, 2]") used to break the streamed-args prefix invariant -> duplicated/invalid JSON. + text = "\n\n\n[1,2]\n\n\n" + _assert_tool_calls(text, [_tool("calc", {"nums": {"type": "array"}})], {0: ("calc", {"nums": [1, 2]})}) + + +def test_number_param_reformatted(): + # "1.0" is parsed to int 1 and json.dumps'd as "1"; the stream must agree. + text = "\n\n\n1.0\n\n\n" + _assert_tool_calls(text, [_tool("calc", {"v": {"type": "number"}})], {0: ("calc", {"v": 1})}) + + +def test_boolean_param(): + text = "\n\n\ntrue\n\n\n" + _assert_tool_calls(text, [_tool("set", {"flag": {"type": "boolean"}})], {0: ("set", {"flag": True})}) + + +def test_object_param(): + text = '\n\n\n{"a":1,"b":[2,3]}\n\n\n' + _assert_tool_calls(text, [_tool("f", {"cfg": {"type": "object"}})], {0: ("f", {"cfg": {"a": 1, "b": [2, 3]}})}) + + +def test_two_params_mixed_types(): + text = ( + "\n\n\nNYC\n\n" + "\n3\n\n\n" + ) + _assert_tool_calls( + text, + [_tool("f", {"city": {"type": "string"}, "days": {"type": "integer"}})], + {0: ("f", {"city": "NYC", "days": 3})}, + ) + + +def test_multiline_string_value(): + text = ( + "\n\n\nline1\nline2\nline3\n" "\n\n" + ) + _assert_tool_calls(text, [_tool("f", {"code": {"type": "string"}})], {0: ("f", {"code": "line1\nline2\nline3"})}) + + +def test_string_with_json_special_chars(): + text = '\n\n\nsay "hi"\\path\n\n\n' + _assert_tool_calls(text, [_tool("f", {"s": {"type": "string"}})], {0: ("f", {"s": 'say "hi"\\path'})}) + + +def test_empty_string_value(): + text = "\n\n\n\n\n\n" + _assert_tool_calls(text, [_tool("f", {"s": {"type": "string"}})], {0: ("f", {"s": ""})}) + + +def test_no_param_function(): + text = "\n\n\n" + _assert_tool_calls(text, [_tool("ping", {})], {0: ("ping", {})}) + + +def test_two_separate_tool_call_blocks(): + text = ( + "\n\n\nhi\n\n\n\n" + "\n\n\nyo\n\n\n" + ) + _assert_tool_calls( + text, + [_tool("a", {"x": {"type": "string"}}), _tool("b", {"y": {"type": "string"}})], + {0: ("a", {"x": "hi"}), 1: ("b", {"y": "yo"})}, + ) + + +def test_two_functions_in_one_block(): + # Regression: used to raise IndexError on the second function in a single block. + text = ( + "\n\n\nhi\n\n\n" + "\n\nyo\n\n\n" + ) + _assert_tool_calls( + text, + [_tool("a", {"x": {"type": "string"}}), _tool("b", {"y": {"type": "string"}})], + {0: ("a", {"x": "hi"}), 1: ("b", {"y": "yo"})}, + ) + + +def test_undefined_then_valid_in_same_block(): + # Regression: an undefined first function used to discard the whole block, dropping + # the valid call that followed it. + text = ( + "\n\n\nhi\n\n\n" + "\n\nyo\n\n\n" + ) + _assert_tool_calls(text, [_tool("valid", {"y": {"type": "string"}})], {0: ("valid", {"y": "yo"})}) + + +def test_truncated_call_missing_function_close(): + # Regression: a typed value with no before used to leave the + # streamed args unterminated (missing closing brace). + text = "\n\n\n0.50\n\n" + _assert_tool_calls(text, [_tool("calc", {"x": {"type": "number"}})], {0: ("calc", {"x": 0.5})}) + + +def test_streaming_matches_non_stream(): + # The reassembled streamed args must equal the one-shot detect_and_parse output. + tools = [_tool("f", {"city": {"type": "string"}, "n": {"type": "integer"}, "tags": {"type": "array"}})] + text = ( + "\n\n\nLondon\n\n" + '\n7\n\n\n["a","b"]\n\n\n' + ) + oneshot = Qwen3CoderDetector().detect_and_parse(text, tools) + expected_args = json.loads(oneshot.calls[0].parameters) + _assert_tool_calls(text, tools, {0: ("f", expected_args)}) + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main([__file__, "-v"]))