diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index b5f6fcf91..9d57049ba 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -232,6 +232,7 @@ def format_request( "name": tool_spec["name"], "description": tool_spec["description"], "input_schema": tool_spec["inputSchema"]["json"], + **({"strict": tool_spec["strict"]} if "strict" in tool_spec else {}), } for tool_spec in tool_specs or [] ], diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index bab4031ed..9a886c14d 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -255,6 +255,7 @@ def _format_request( "name": tool_spec["name"], "description": tool_spec["description"], "inputSchema": tool_spec["inputSchema"], + **({"strict": tool_spec["strict"]} if "strict" in tool_spec else {}), } } for tool_spec in tool_specs diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index 2b217ad91..a2a40a8a6 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -461,6 +461,7 @@ def format_request( "name": tool_spec["name"], "description": tool_spec["description"], "parameters": tool_spec["inputSchema"]["json"], + **({"strict": tool_spec["strict"]} if "strict" in tool_spec else {}), }, } for tool_spec in tool_specs or [] diff --git a/src/strands/models/openai_responses.py b/src/strands/models/openai_responses.py index 0ace9645f..570a0e0e3 100644 --- a/src/strands/models/openai_responses.py +++ b/src/strands/models/openai_responses.py @@ -418,6 +418,7 @@ def _format_request( "name": tool_spec["name"], "description": tool_spec.get("description", ""), "parameters": tool_spec["inputSchema"]["json"], + **({"strict": tool_spec["strict"]} if "strict" in tool_spec else {}), } for tool_spec in tool_specs ] @@ -502,9 +503,7 @@ def _format_request_messages(cls, messages: Messages) -> list[dict[str, Any]]: ] @classmethod - def _format_request_message_content( - cls, content: ContentBlock, *, role: Role = "user" - ) -> dict[str, Any]: + def _format_request_message_content(cls, content: ContentBlock, *, role: Role = "user") -> dict[str, Any]: """Format an OpenAI compatible content block. Args: diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index 9207df9b8..2a778bc31 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -725,6 +725,7 @@ def tool( inputSchema: JSONSchema | None = None, name: str | None = None, context: bool | str = False, + strict: bool | None = None, ) -> Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]: ... # Suppressing the type error because we want callers to be able to use both `tool` and `tool()` at the # call site, but the actual implementation handles that and it's not representable via the type-system @@ -734,6 +735,7 @@ def tool( # type: ignore inputSchema: JSONSchema | None = None, name: str | None = None, context: bool | str = False, + strict: bool | None = None, ) -> DecoratedFunctionTool[P, R] | Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]: """Decorator that transforms a Python function into a Strands tool. @@ -762,6 +764,8 @@ def tool( # type: ignore context: When provided, places an object in the designated parameter. If True, the param name defaults to 'tool_context', or if an override is needed, set context equal to a string to designate the param name. + strict: Optional Boolean that ensures the model will only output tool calls containing parameters + that perfectly match the defined input schema. Returns: An AgentTool that also mimics the original function when invoked @@ -816,6 +820,8 @@ def decorator(f: T) -> "DecoratedFunctionTool[P, R]": tool_spec["description"] = description if inputSchema is not None: tool_spec["inputSchema"] = inputSchema + if strict is not None: + tool_spec["strict"] = strict tool_name = tool_spec.get("name", f.__name__) diff --git a/src/strands/types/tools.py b/src/strands/types/tools.py index 088c83bdb..8c7e5109f 100644 --- a/src/strands/types/tools.py +++ b/src/strands/types/tools.py @@ -30,12 +30,17 @@ class ToolSpec(TypedDict): outputSchema: Optional JSON Schema defining the expected output format. Note: Not all model providers support this field. Providers that don't support it should filter it out before sending to their API. + strict: Optional Boolean that ensures the model will only output tool calls + containing parameters that perfectly match the defined input schema. + Note: Not all model providers support this field. Providers that don't + support it should filter it out before sending to their API. """ description: str inputSchema: JSONSchema name: str outputSchema: NotRequired[JSONSchema] + strict: NotRequired[bool] class Tool(TypedDict): diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py index c5aff8062..b80cd9d56 100644 --- a/tests/strands/models/test_anthropic.py +++ b/tests/strands/models/test_anthropic.py @@ -440,6 +440,31 @@ def test_format_request_tool_choice_auto(model, messages, model_id, max_tokens): assert tru_request == exp_request +def test_format_request_tool_specs_with_strict(model, messages, model_id, max_tokens): + tool_specs = [ + {"description": "test tool", "name": "test_tool", "inputSchema": {"json": {"key": "value"}}, "strict": True} + ] + tru_request = model.format_request(messages, tool_specs) + + assert tru_request["tools"][0]["strict"] is True + + +def test_format_request_tool_specs_with_strict_false(model, messages, model_id, max_tokens): + tool_specs = [ + {"description": "test tool", "name": "test_tool", "inputSchema": {"json": {"key": "value"}}, "strict": False} + ] + tru_request = model.format_request(messages, tool_specs) + + assert tru_request["tools"][0]["strict"] is False + + +def test_format_request_tool_specs_without_strict(model, messages, model_id, max_tokens): + tool_specs = [{"description": "test tool", "name": "test_tool", "inputSchema": {"json": {"key": "value"}}}] + tru_request = model.format_request(messages, tool_specs) + + assert "strict" not in tru_request["tools"][0] + + def test_format_request_tool_choice_any(model, messages, model_id, max_tokens): tool_specs = [{"description": "test tool", "name": "test_tool", "inputSchema": {"json": {"key": "value"}}}] tool_choice = {"any": {}} diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 89c4df70d..e074d9860 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -471,6 +471,44 @@ def test_format_request_tool_specs(model, messages, model_id, tool_spec): assert tru_request == exp_request +def test_format_request_tool_specs_with_strict(model, messages, model_id): + strict_tool_spec = { + "description": "description", + "name": "name", + "inputSchema": {"key": "val"}, + "strict": True, + } + tru_request = model._format_request(messages, tool_specs=[strict_tool_spec]) + tool_in_request = tru_request["toolConfig"]["tools"][0]["toolSpec"] + + assert tool_in_request["strict"] is True + + +def test_format_request_tool_specs_with_strict_false(model, messages, model_id): + strict_false_tool_spec = { + "description": "description", + "name": "name", + "inputSchema": {"key": "val"}, + "strict": False, + } + tru_request = model._format_request(messages, tool_specs=[strict_false_tool_spec]) + tool_in_request = tru_request["toolConfig"]["tools"][0]["toolSpec"] + + assert tool_in_request["strict"] is False + + +def test_format_request_tool_specs_without_strict(model, messages, model_id): + tool_spec_no_strict = { + "description": "description", + "name": "name", + "inputSchema": {"key": "val"}, + } + tru_request = model._format_request(messages, tool_specs=[tool_spec_no_strict]) + tool_in_request = tru_request["toolConfig"]["tools"][0]["toolSpec"] + + assert "strict" not in tool_in_request + + def test_format_request_tool_choice_auto(model, messages, model_id, tool_spec): tool_choice = {"auto": {}} tru_request = model._format_request(messages, [tool_spec], tool_choice=tool_choice) diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py index 241c22b64..f806223f4 100644 --- a/tests/strands/models/test_openai.py +++ b/tests/strands/models/test_openai.py @@ -503,6 +503,44 @@ def test_format_request(model, messages, tool_specs, system_prompt): assert tru_request == exp_request +def test_format_request_tool_specs_with_strict(model, messages, system_prompt): + strict_tool_specs = [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": {"type": "object", "properties": {"input": {"type": "string"}}, "required": ["input"]} + }, + "strict": True, + } + ] + tru_request = model.format_request(messages, strict_tool_specs, system_prompt) + + assert tru_request["tools"][0]["function"]["strict"] is True + + +def test_format_request_tool_specs_with_strict_false(model, messages, system_prompt): + strict_false_tool_specs = [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": {"type": "object", "properties": {"input": {"type": "string"}}, "required": ["input"]} + }, + "strict": False, + } + ] + tru_request = model.format_request(messages, strict_false_tool_specs, system_prompt) + + assert tru_request["tools"][0]["function"]["strict"] is False + + +def test_format_request_tool_specs_without_strict(model, messages, tool_specs, system_prompt): + tru_request = model.format_request(messages, tool_specs, system_prompt) + + assert "strict" not in tru_request["tools"][0]["function"] + + def test_format_request_with_tool_choice_auto(model, messages, tool_specs, system_prompt): tool_choice = {"auto": {}} tru_request = model.format_request(messages, tool_specs, system_prompt, tool_choice) diff --git a/tests/strands/models/test_openai_responses.py b/tests/strands/models/test_openai_responses.py index 9c84f4ed4..1c1607345 100644 --- a/tests/strands/models/test_openai_responses.py +++ b/tests/strands/models/test_openai_responses.py @@ -348,6 +348,53 @@ def test_format_request(model, messages, tool_specs, system_prompt): assert tru_request == exp_request +def test_format_request_tool_specs_with_strict(model, messages, system_prompt): + strict_tool_specs = [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": {"type": "object", "properties": {"input": {"type": "string"}}, "required": ["input"]} + }, + "strict": True, + } + ] + tru_request = model._format_request(messages, strict_tool_specs, system_prompt) + + assert tru_request["tools"][0]["strict"] is True + + +def test_format_request_tool_specs_with_strict_false(model, messages, system_prompt): + strict_false_tool_specs = [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": {"type": "object", "properties": {"input": {"type": "string"}}, "required": ["input"]} + }, + "strict": False, + } + ] + tru_request = model._format_request(messages, strict_false_tool_specs, system_prompt) + + assert tru_request["tools"][0]["strict"] is False + + +def test_format_request_tool_specs_without_strict(model, messages, system_prompt): + tool_specs = [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": {"type": "object", "properties": {"input": {"type": "string"}}, "required": ["input"]} + }, + } + ] + tru_request = model._format_request(messages, tool_specs, system_prompt) + + assert "strict" not in tru_request["tools"][0] + + @pytest.mark.parametrize( ("event", "exp_chunk"), [ diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index cc1158983..579bc0a5b 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -90,6 +90,33 @@ def test_tool_spec(identity_tool): assert tru_spec == exp_spec +def test_tool_spec_with_strict_true(): + @strands.tool(strict=True) + def my_tool(param: str) -> str: + """A tool.""" + return param + + assert my_tool.tool_spec["strict"] is True + + +def test_tool_spec_with_strict_false(): + @strands.tool(strict=False) + def my_tool(param: str) -> str: + """A tool.""" + return param + + assert my_tool.tool_spec["strict"] is False + + +def test_tool_spec_without_strict(): + @strands.tool + def my_tool(param: str) -> str: + """A tool.""" + return param + + assert "strict" not in my_tool.tool_spec + + @pytest.mark.parametrize("identity_tool", ["identity_invoke", "identity_invoke_async"], indirect=True) def test_tool_type(identity_tool): tru_type = identity_tool.tool_type