From 202acdce11c03066808a6722efc2020e3978e94b Mon Sep 17 00:00:00 2001 From: Yufeng He <40085740+he-yufeng@users.noreply.github.com> Date: Fri, 5 Jun 2026 05:27:43 +0800 Subject: [PATCH] fix: load artifacts from workflow text responses --- src/google/adk/tools/load_artifacts_tool.py | 116 ++++++++++++------ .../tools/test_load_artifacts_tool.py | 69 +++++++++++ 2 files changed, 146 insertions(+), 39 deletions(-) diff --git a/src/google/adk/tools/load_artifacts_tool.py b/src/google/adk/tools/load_artifacts_tool.py index ec717bad4c..ffdada5908 100644 --- a/src/google/adk/tools/load_artifacts_tool.py +++ b/src/google/adk/tools/load_artifacts_tool.py @@ -14,6 +14,7 @@ from __future__ import annotations +import ast import base64 import binascii import json @@ -46,6 +47,7 @@ from .tool_context import ToolContext logger = logging.getLogger('google_adk.' + __name__) +_LOAD_ARTIFACTS_TEXT_MARKER = '`load_artifacts` tool returned result:' def _normalize_mime_type(mime_type: str | None) -> str | None: @@ -121,6 +123,47 @@ def _as_safe_part_for_llm( ) +def _artifact_names_from_response(response: Any) -> list[str]: + if not isinstance(response, dict): + return [] + + artifact_names = response.get('artifact_names', []) + if isinstance(artifact_names, str): + return [artifact_names] + if not isinstance(artifact_names, list): + return [] + return [name for name in artifact_names if isinstance(name, str)] + + +def _artifact_names_from_text_response(text: str | None) -> list[str]: + if not text or _LOAD_ARTIFACTS_TEXT_MARKER not in text: + return [] + + payload = text.split(_LOAD_ARTIFACTS_TEXT_MARKER, 1)[1].strip() + try: + response = ast.literal_eval(payload) + except (SyntaxError, ValueError) as exc: + logger.debug('Could not parse load_artifacts text response: %s', exc) + return [] + + return _artifact_names_from_response(response) + + +def _requested_artifact_names(content: types.Content) -> list[str]: + artifact_names: list[str] = [] + for part in content.parts or []: + function_response = part.function_response + if function_response and function_response.name == 'load_artifacts': + artifact_names.extend( + _artifact_names_from_response(function_response.response or {}) + ) + continue + + artifact_names.extend(_artifact_names_from_text_response(part.text)) + + return artifact_names + + class LoadArtifactsTool(BaseTool): """A tool that loads the artifacts and adds them to the session.""" @@ -210,46 +253,41 @@ async def _append_artifacts_to_llm_request( # Attach the content of the artifacts if the model requests them. # This only adds the content to the model request, instead of the session. if llm_request.contents and llm_request.contents[-1].parts: - function_response = llm_request.contents[-1].parts[0].function_response - if function_response and function_response.name == 'load_artifacts': - response = function_response.response or {} - artifact_names = response.get('artifact_names', []) - for artifact_name in artifact_names: - # Try session-scoped first (default behavior) - artifact = await tool_context.load_artifact(artifact_name) - - # If not found and name doesn't already have user: prefix, - # try cross-session artifacts with user: prefix - if artifact is None and not artifact_name.startswith('user:'): - prefixed_name = f'user:{artifact_name}' - artifact = await tool_context.load_artifact(prefixed_name) - - if artifact is None: - logger.warning('Artifact "%s" not found, skipping', artifact_name) - continue - - artifact_part = _as_safe_part_for_llm(artifact, artifact_name) - if artifact_part is not artifact: - mime_type = ( - artifact.inline_data.mime_type if artifact.inline_data else None - ) - logger.debug( - 'Converted artifact "%s" (mime_type=%s) to text Part', - artifact_name, - mime_type, - ) - - llm_request.contents.append( - types.Content( - role='user', - parts=[ - types.Part.from_text( - text=f'Artifact {artifact_name} is:' - ), - artifact_part, - ], - ) + artifact_names = _requested_artifact_names(llm_request.contents[-1]) + for artifact_name in artifact_names: + # Try session-scoped first (default behavior) + artifact = await tool_context.load_artifact(artifact_name) + + # If not found and name doesn't already have user: prefix, + # try cross-session artifacts with user: prefix + if artifact is None and not artifact_name.startswith('user:'): + prefixed_name = f'user:{artifact_name}' + artifact = await tool_context.load_artifact(prefixed_name) + + if artifact is None: + logger.warning('Artifact "%s" not found, skipping', artifact_name) + continue + + artifact_part = _as_safe_part_for_llm(artifact, artifact_name) + if artifact_part is not artifact: + mime_type = ( + artifact.inline_data.mime_type if artifact.inline_data else None + ) + logger.debug( + 'Converted artifact "%s" (mime_type=%s) to text Part', + artifact_name, + mime_type, ) + llm_request.contents.append( + types.Content( + role='user', + parts=[ + types.Part.from_text(text=f'Artifact {artifact_name} is:'), + artifact_part, + ], + ) + ) + load_artifacts_tool = LoadArtifactsTool() diff --git a/tests/unittests/tools/test_load_artifacts_tool.py b/tests/unittests/tools/test_load_artifacts_tool.py index 6a420574f0..a800474d17 100644 --- a/tests/unittests/tools/test_load_artifacts_tool.py +++ b/tests/unittests/tools/test_load_artifacts_tool.py @@ -144,6 +144,75 @@ async def test_load_artifacts_keeps_supported_mime_types(): assert artifact_part.inline_data.mime_type == 'application/pdf' +@mark.asyncio +async def test_load_artifacts_reads_workflow_text_response(): + """Workflow context can stringify tool responses from other nodes.""" + artifact_name = 'invoice.txt' + artifact = types.Part.from_text(text='invoice total: 42') + + tool_context = _StubToolContext({artifact_name: artifact}) + llm_request = LlmRequest( + contents=[ + types.Content( + role='user', + parts=[ + types.Part.from_text(text='For context:'), + types.Part.from_text( + text=( + '[workflow_node] `load_artifacts` tool returned' + " result: {'artifact_names': ['invoice.txt']," + " 'status': 'ok'}" + ) + ), + ], + ) + ] + ) + + await load_artifacts_tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + assert llm_request.contents[-1].parts[0].text == ( + f'Artifact {artifact_name} is:' + ) + assert llm_request.contents[-1].parts[1].text == 'invoice total: 42' + + +@mark.asyncio +async def test_load_artifacts_checks_all_function_response_parts(): + """The load_artifacts response may not be the first part in a turn.""" + artifact_name = 'notes.txt' + artifact = types.Part.from_text(text='important notes') + + tool_context = _StubToolContext({artifact_name: artifact}) + llm_request = LlmRequest( + contents=[ + types.Content( + role='user', + parts=[ + types.Part.from_text(text='Done.'), + types.Part( + function_response=types.FunctionResponse( + name='load_artifacts', + response={'artifact_names': [artifact_name]}, + ) + ), + ], + ) + ] + ) + + await load_artifacts_tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + assert llm_request.contents[-1].parts[0].text == ( + f'Artifact {artifact_name} is:' + ) + assert llm_request.contents[-1].parts[1].text == 'important notes' + + def test_maybe_base64_to_bytes_decodes_standard_base64(): """Standard base64 encoded strings are decoded correctly.""" original = b'hello world'