Skip to content

fix stream fc for qwen3_coder#1364

Open
shihaobai wants to merge 2 commits into
mainfrom
stream_fc
Open

fix stream fc for qwen3_coder#1364
shihaobai wants to merge 2 commits into
mainfrom
stream_fc

Conversation

@shihaobai

Copy link
Copy Markdown
Collaborator

No description provided.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces XML tool-call streaming parsing support for Qwen3-Coder in lightllm/server/function_call_parser.py by adding regexes, helper methods for monotonic JSON argument building, and updating the streaming increment parser. The review feedback identifies three critical issues: first, implicitly closed parameters are incorrectly treated as unclosed, which prematurely breaks the streaming loop; second, parallel tool calls can trigger an IndexError because state arrays are not dynamically resized as the tool ID increments; and third, fallback logic for non-string parameter conversion can output invalid JSON instead of properly serialized values.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread lightllm/server/function_call_parser.py Outdated
Comment on lines +1931 to +1933
closed = match.group(3) == "</parameter>"
if not closed:
param_value = self._strip_partial_xml_suffix(param_value)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

In streaming scenarios with multiple parameters, a parameter can be implicitly closed when a subsequent parameter starts (e.g., matching (?=<parameter=)) or when the function block closes (e.g., matching (?=</function>)). In these cases, match.group(3) is empty, so closed is incorrectly evaluated as False. This causes the loop to break early, completely ignoring any subsequent parameters. Checking the remaining text for lookahead matches or close_object ensures that implicitly closed parameters are correctly processed and the loop continues to subsequent parameters.

            remaining = partial_body[match.end():]
            closed = (
                match.group(3) == "</parameter>"
                or remaining.startswith("<parameter=")
                or remaining.startswith("</function>")
                or close_object
            )
            if not closed:
                param_value = self._strip_partial_xml_suffix(param_value)

Comment thread lightllm/server/function_call_parser.py Outdated
Comment on lines +2111 to +2119
completed_tool_id = self.current_tool_id
try:
parsed_args = json.loads(item.parameters)
except json.JSONDecodeError:
parsed_args = {}
self.prev_tool_call_arr[completed_tool_id] = {
"name": item.name,
"arguments": parsed_args,
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

When processing parallel tool calls, self.current_tool_id is incremented inside the func_matches loop. For any subsequent matches, completed_tool_id will exceed the current bounds of self.prev_tool_call_arr and self.streamed_args_for_tool, leading to an IndexError and crashing the server. Calling self._ensure_qwen3_stream_state(completed_tool_id) inside the loop ensures the state arrays are dynamically resized to accommodate parallel tool calls.

Suggested change
completed_tool_id = self.current_tool_id
try:
parsed_args = json.loads(item.parameters)
except json.JSONDecodeError:
parsed_args = {}
self.prev_tool_call_arr[completed_tool_id] = {
"name": item.name,
"arguments": parsed_args,
}
completed_tool_id = self.current_tool_id
self._ensure_qwen3_stream_state(completed_tool_id)
try:
parsed_args = json.loads(item.parameters)
except json.JSONDecodeError:
parsed_args = {}
self.prev_tool_call_arr[completed_tool_id] = {
"name": item.name,
"arguments": parsed_args,
}

Comment thread lightllm/server/function_call_parser.py Outdated
Comment on lines +1952 to +1958
elif closed:
converted = self._convert_param_value(param_value, param_name, param_config, func_name)
converted_json = json.dumps(converted, ensure_ascii=False)
if converted_json.startswith(param_value):
parts.append(converted_json)
else:
parts.append(param_value)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

When closed is True, the parameter value is complete. If the parameter type is non-string (e.g., boolean, integer, object), and the value fails to convert or has a different representation (like True vs true), the startswith check will fail and append the raw param_value directly. This can result in invalid JSON (e.g., {"param": True} or {"param": abc}). Directly appending json.dumps(converted, ensure_ascii=False) guarantees valid JSON output.

Suggested change
elif closed:
converted = self._convert_param_value(param_value, param_name, param_config, func_name)
converted_json = json.dumps(converted, ensure_ascii=False)
if converted_json.startswith(param_value):
parts.append(converted_json)
else:
parts.append(param_value)
elif closed:
converted = self._convert_param_value(param_value, param_name, param_config, func_name)
parts.append(json.dumps(converted, ensure_ascii=False))

Keep streamed_args_for_tool a byte-exact prefix of json.dumps(arguments) so
the serving-layer stop-time reconciliation can't produce duplicated/invalid
JSON. String values stream incrementally; non-string values are emitted only
once </parameter> arrives (a partial number/array/bool isn't a prefix of its
json.dumps form). Detect "value still streaming" via the terminator rather
than match position ($ matches before the template's trailing newline).

Also:
- don't crash (IndexError) on >=2 <function=> in one <tool_call> block, and
  emit a name head for each; flush via _ensure_qwen3_stream_state.
- an undefined first function no longer discards the whole block (a valid
  function after it is still emitted).
- treat </tool_call> as an implicit function close so the closing '}' rides
  in the same args delta (no separate trailing-'}' delta to double-count).
- drop dead _build_partial_arguments_json; dedup newline-strip / param-type
  helpers.

Add unit_tests/server/test_qwen3_coder_stream_fc.py covering these across
chunk boundaries.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants