diff --git a/src/agents/function_schema.py b/src/agents/function_schema.py index cff7f987e6..3149f48cc1 100644 --- a/src/agents/function_schema.py +++ b/src/agents/function_schema.py @@ -35,6 +35,8 @@ class FuncSchema: """The signature of the function.""" takes_context: bool = False """Whether the function takes a RunContextWrapper argument (must be the first argument).""" + self_or_cls_skipped: bool = False + """Whether an unannotated self/cls first parameter was skipped during schema generation.""" strict_json_schema: bool = True """Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True, as it increases the likelihood of correct JSON input.""" @@ -285,6 +287,7 @@ def function_schema( sig = inspect.signature(func) params = list(sig.parameters.items()) takes_context = False + self_or_cls_skipped = False filtered_params = [] if params: @@ -297,15 +300,26 @@ def function_schema( takes_context = True # Mark that the function takes context else: filtered_params.append((first_name, first_param)) + elif first_name in ("self", "cls"): + # Skip unannotated self/cls so @function_tool works on class methods. + # Bound methods already have self stripped by Python, so this handles + # the unbound case (decoration time). The caller must invoke with a + # bound method for correct runtime behavior. + self_or_cls_skipped = True else: filtered_params.append((first_name, first_param)) - # For parameters other than the first, raise error if any use RunContextWrapper or ToolContext. - for name, param in params[1:]: + # For remaining parameters: if self/cls was skipped, the second param is effectively first + # and may be a context parameter. + remaining_params = params[1:] + for idx, (name, param) in enumerate(remaining_params): ann = type_hints.get(name, param.annotation) if ann != inspect._empty: origin = get_origin(ann) or ann if origin is RunContextWrapper or origin is ToolContext: + if self_or_cls_skipped and idx == 0: + takes_context = True + continue raise UserError( f"RunContextWrapper/ToolContext param found at non-first position in function" f" {func.__name__}" @@ -409,14 +423,21 @@ def function_schema( if strict_json_schema: json_schema = ensure_strict_json_schema(json_schema) - # 5. Return as a FuncSchema dataclass + # 5. Build stored signature excluding self/cls if it was skipped + stored_sig = sig + if self_or_cls_skipped: + new_params = [p for p in sig.parameters.values() if p.name not in ("self", "cls")] + stored_sig = sig.replace(parameters=new_params) + + # 6. Return as a FuncSchema dataclass return FuncSchema( name=func_name, # Ensure description_override takes precedence even if docstring info is disabled. description=description_override or (doc_info.description if doc_info else None), params_pydantic_model=dynamic_model, params_json_schema=json_schema, - signature=sig, + signature=stored_sig, takes_context=takes_context, + self_or_cls_skipped=self_or_cls_skipped, strict_json_schema=strict_json_schema, ) diff --git a/tests/test_function_schema.py b/tests/test_function_schema.py index 9771bda99d..17e6e2c859 100644 --- a/tests/test_function_schema.py +++ b/tests/test_function_schema.py @@ -885,3 +885,132 @@ def func_with_annotated_multiple_field_constraints( with pytest.raises(ValidationError): # zero factor fs.params_pydantic_model(**{"score": 50, "factor": 0.0}) + + +def test_bound_method_self_not_in_schema(): + """Test that bound methods work normally (Python already strips self).""" + + class MyTools: + def greet(self, name: str) -> str: + return f"Hello, {name}" + + obj = MyTools() + fs = function_schema(obj.greet, use_docstring_info=False) + props = fs.params_json_schema.get("properties", {}) + assert "self" not in props + assert "name" in props + assert fs.params_json_schema.get("required") == ["name"] + + +def test_unbound_cls_param_skipped(): + """Test that unbound classmethods with unannotated cls have cls skipped.""" + + # Simulate a function whose first param is named cls with no annotation + code = compile("def greet(cls, name: str) -> str: ...", "", "exec") + ns: dict[str, Any] = {} + exec(code, ns) # noqa: S102 + fn = ns["greet"] + fn.__annotations__ = {"name": str, "return": str} + + fs = function_schema(fn, use_docstring_info=False) + props = fs.params_json_schema.get("properties", {}) + assert "cls" not in props + assert "name" in props + assert fs.self_or_cls_skipped is True + assert "cls" not in fs.signature.parameters + + +def test_bound_method_with_context_second_param(): + """Test that bound methods with RunContextWrapper as second param work correctly.""" + + class MyTools: + def greet(self, ctx: RunContextWrapper[None], name: str) -> str: + return f"Hello, {name}" + + obj = MyTools() + fs = function_schema(obj.greet, use_docstring_info=False) + props = fs.params_json_schema.get("properties", {}) + # self is already stripped by Python for bound methods + assert "self" not in props + assert "ctx" not in props + assert "name" in props + assert fs.takes_context is True + + +def test_method_context_not_immediately_after_self_raises(): + """Test that RunContextWrapper at position 3+ (not immediately after self) raises UserError.""" + + class MyTools: + def greet(self, name: str, ctx: RunContextWrapper[None]) -> str: + return f"Hello, {name}" + + obj = MyTools() + with pytest.raises(UserError, match="non-first position"): + function_schema(obj.greet, use_docstring_info=False) + + +def test_unbound_method_self_skipped_with_context(): + """Test that unbound methods with self+context have self skipped and context recognized.""" + + # Simulate an unbound method with self as first param + code = compile( + "def greet(self, ctx, name: str) -> str: ...", "", "exec" + ) + ns: dict[str, Any] = {} + exec(code, ns) # noqa: S102 + fn = ns["greet"] + fn.__annotations__ = {"ctx": RunContextWrapper[None], "name": str, "return": str} + + fs = function_schema(fn, use_docstring_info=False) + props = fs.params_json_schema.get("properties", {}) + assert "self" not in props + assert "ctx" not in props + assert "name" in props + assert fs.self_or_cls_skipped is True + assert fs.takes_context is True + assert "self" not in fs.signature.parameters + + +def test_unbound_method_to_call_args_alignment(): + """Test that to_call_args produces correct args when self was skipped.""" + + code = compile("def greet(self, name: str, count: int = 1) -> str: ...", "", "exec") + ns: dict[str, Any] = {} + exec(code, ns) # noqa: S102 + fn = ns["greet"] + fn.__annotations__ = {"name": str, "count": int, "return": str} + + fs = function_schema(fn, use_docstring_info=False) + assert fs.self_or_cls_skipped is True + + parsed = fs.params_pydantic_model(name="world", count=3) + args, kwargs = fs.to_call_args(parsed) + assert args == ["world", 3] + assert kwargs == {} + + +def test_decorator_pattern_does_not_raise(): + """Test that function_schema works on unbound methods (decorator pattern).""" + + # This simulates @function_tool applied at class definition time + class MyTools: + def search(self, query: str) -> str: + return query + + # At decoration time, MyTools.search is unbound + fs = function_schema(MyTools.search, use_docstring_info=False) + props = fs.params_json_schema.get("properties", {}) + assert "self" not in props + assert "query" in props + assert fs.self_or_cls_skipped is True + + +def test_regular_unannotated_first_param_still_included(): + """Test that a regular unannotated first param (not self/cls) is still included.""" + + def process(data, flag: bool = False) -> str: + return str(data) + + fs = function_schema(process, use_docstring_info=False) + props = fs.params_json_schema.get("properties", {}) + assert "data" in props