diff --git a/python/packages/core/agent_framework/_clients.py b/python/packages/core/agent_framework/_clients.py index ddd765e654..a9145c2e85 100644 --- a/python/packages/core/agent_framework/_clients.py +++ b/python/packages/core/agent_framework/_clients.py @@ -31,7 +31,9 @@ from ._serialization import SerializationMixin from ._tools import ( FunctionInvocationConfiguration, + FunctionInvocationLayer, ToolTypes, + normalize_function_invocation_configuration, ) from ._types import ( ChatResponse, @@ -647,7 +649,15 @@ def as_agent( "additional_properties": dict(additional_properties) if additional_properties is not None else None, } if function_invocation_configuration is not None: - agent_kwargs["function_invocation_configuration"] = function_invocation_configuration + if isinstance(self, FunctionInvocationLayer): + self.function_invocation_configuration = normalize_function_invocation_configuration( + function_invocation_configuration + ) + else: + logger.warning( + "function_invocation_configuration was provided, but the chat client does not support " + "function invoking." + ) return Agent(**agent_kwargs) diff --git a/python/packages/core/tests/core/test_clients.py b/python/packages/core/tests/core/test_clients.py index 7657993d56..b7248ef941 100644 --- a/python/packages/core/tests/core/test_clients.py +++ b/python/packages/core/tests/core/test_clients.py @@ -15,6 +15,7 @@ SlidingWindowStrategy, SupportsChatGetResponse, TruncationStrategy, + normalize_function_invocation_configuration, ) @@ -58,6 +59,18 @@ def test_base_client_as_agent_uses_explicit_additional_properties(chat_client_ba assert agent.additional_properties == {"team": "core"} +def test_base_client_as_agent_applies_function_invocation_configuration_to_client( + chat_client_base: SupportsChatGetResponse, +) -> None: + config = normalize_function_invocation_configuration({"max_iterations": 1, "include_detailed_errors": True}) + + agent = chat_client_base.as_agent(function_invocation_configuration=config) + + assert agent.client is chat_client_base + assert chat_client_base.function_invocation_configuration["max_iterations"] == 1 # type: ignore[attr-defined] + assert chat_client_base.function_invocation_configuration["include_detailed_errors"] is True # type: ignore[attr-defined] + + async def test_base_client_get_response_uses_explicit_client_kwargs(chat_client_base: SupportsChatGetResponse) -> None: async def fake_inner_get_response(**kwargs): assert kwargs["trace_id"] == "trace-123"