Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions src/agents/function_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ class FuncSchema:
"""The signature of the function."""
takes_context: bool = False
"""Whether the function takes a RunContextWrapper argument (must be the first argument)."""
self_or_cls_skipped: bool = False
"""Whether an unannotated self/cls first parameter was skipped during schema generation."""
strict_json_schema: bool = True
"""Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
as it increases the likelihood of correct JSON input."""
Expand Down Expand Up @@ -285,6 +287,7 @@ def function_schema(
sig = inspect.signature(func)
params = list(sig.parameters.items())
takes_context = False
self_or_cls_skipped = False
filtered_params = []

if params:
Expand All @@ -297,15 +300,26 @@ def function_schema(
takes_context = True # Mark that the function takes context
else:
filtered_params.append((first_name, first_param))
elif first_name in ("self", "cls"):
# Skip unannotated self/cls so @function_tool works on class methods.
# Bound methods already have self stripped by Python, so this handles
# the unbound case (decoration time). The caller must invoke with a
# bound method for correct runtime behavior.
self_or_cls_skipped = True
else:
filtered_params.append((first_name, first_param))

# For parameters other than the first, raise error if any use RunContextWrapper or ToolContext.
for name, param in params[1:]:
# For remaining parameters: if self/cls was skipped, the second param is effectively first
# and may be a context parameter.
remaining_params = params[1:]
for idx, (name, param) in enumerate(remaining_params):
ann = type_hints.get(name, param.annotation)
if ann != inspect._empty:
origin = get_origin(ann) or ann
if origin is RunContextWrapper or origin is ToolContext:
if self_or_cls_skipped and idx == 0:
takes_context = True
continue
raise UserError(
f"RunContextWrapper/ToolContext param found at non-first position in function"
f" {func.__name__}"
Expand Down Expand Up @@ -409,14 +423,21 @@ def function_schema(
if strict_json_schema:
json_schema = ensure_strict_json_schema(json_schema)

# 5. Return as a FuncSchema dataclass
# 5. Build stored signature excluding self/cls if it was skipped
stored_sig = sig
if self_or_cls_skipped:
new_params = [p for p in sig.parameters.values() if p.name not in ("self", "cls")]
stored_sig = sig.replace(parameters=new_params)

# 6. Return as a FuncSchema dataclass
return FuncSchema(
name=func_name,
# Ensure description_override takes precedence even if docstring info is disabled.
description=description_override or (doc_info.description if doc_info else None),
params_pydantic_model=dynamic_model,
params_json_schema=json_schema,
signature=sig,
signature=stored_sig,
takes_context=takes_context,
self_or_cls_skipped=self_or_cls_skipped,
strict_json_schema=strict_json_schema,
)
129 changes: 129 additions & 0 deletions tests/test_function_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,3 +885,132 @@ def func_with_annotated_multiple_field_constraints(

with pytest.raises(ValidationError): # zero factor
fs.params_pydantic_model(**{"score": 50, "factor": 0.0})


def test_bound_method_self_not_in_schema():
"""Test that bound methods work normally (Python already strips self)."""

class MyTools:
def greet(self, name: str) -> str:
return f"Hello, {name}"

obj = MyTools()
fs = function_schema(obj.greet, use_docstring_info=False)
props = fs.params_json_schema.get("properties", {})
assert "self" not in props
assert "name" in props
assert fs.params_json_schema.get("required") == ["name"]


def test_unbound_cls_param_skipped():
"""Test that unbound classmethods with unannotated cls have cls skipped."""

# Simulate a function whose first param is named cls with no annotation
code = compile("def greet(cls, name: str) -> str: ...", "<test>", "exec")
ns: dict[str, Any] = {}
exec(code, ns) # noqa: S102
fn = ns["greet"]
fn.__annotations__ = {"name": str, "return": str}

fs = function_schema(fn, use_docstring_info=False)
props = fs.params_json_schema.get("properties", {})
assert "cls" not in props
assert "name" in props
assert fs.self_or_cls_skipped is True
assert "cls" not in fs.signature.parameters


def test_bound_method_with_context_second_param():
"""Test that bound methods with RunContextWrapper as second param work correctly."""

class MyTools:
def greet(self, ctx: RunContextWrapper[None], name: str) -> str:
return f"Hello, {name}"

obj = MyTools()
fs = function_schema(obj.greet, use_docstring_info=False)
props = fs.params_json_schema.get("properties", {})
# self is already stripped by Python for bound methods
assert "self" not in props
assert "ctx" not in props
assert "name" in props
assert fs.takes_context is True


def test_method_context_not_immediately_after_self_raises():
"""Test that RunContextWrapper at position 3+ (not immediately after self) raises UserError."""

class MyTools:
def greet(self, name: str, ctx: RunContextWrapper[None]) -> str:
return f"Hello, {name}"

obj = MyTools()
with pytest.raises(UserError, match="non-first position"):
function_schema(obj.greet, use_docstring_info=False)


def test_unbound_method_self_skipped_with_context():
"""Test that unbound methods with self+context have self skipped and context recognized."""

# Simulate an unbound method with self as first param
code = compile(
"def greet(self, ctx, name: str) -> str: ...", "<test>", "exec"
)
ns: dict[str, Any] = {}
exec(code, ns) # noqa: S102
fn = ns["greet"]
fn.__annotations__ = {"ctx": RunContextWrapper[None], "name": str, "return": str}

fs = function_schema(fn, use_docstring_info=False)
props = fs.params_json_schema.get("properties", {})
assert "self" not in props
assert "ctx" not in props
assert "name" in props
assert fs.self_or_cls_skipped is True
assert fs.takes_context is True
assert "self" not in fs.signature.parameters


def test_unbound_method_to_call_args_alignment():
"""Test that to_call_args produces correct args when self was skipped."""

code = compile("def greet(self, name: str, count: int = 1) -> str: ...", "<test>", "exec")
ns: dict[str, Any] = {}
exec(code, ns) # noqa: S102
fn = ns["greet"]
fn.__annotations__ = {"name": str, "count": int, "return": str}

fs = function_schema(fn, use_docstring_info=False)
assert fs.self_or_cls_skipped is True

parsed = fs.params_pydantic_model(name="world", count=3)
args, kwargs = fs.to_call_args(parsed)
assert args == ["world", 3]
assert kwargs == {}


def test_decorator_pattern_does_not_raise():
"""Test that function_schema works on unbound methods (decorator pattern)."""

# This simulates @function_tool applied at class definition time
class MyTools:
def search(self, query: str) -> str:
return query

# At decoration time, MyTools.search is unbound
fs = function_schema(MyTools.search, use_docstring_info=False)
props = fs.params_json_schema.get("properties", {})
assert "self" not in props
assert "query" in props
assert fs.self_or_cls_skipped is True


def test_regular_unannotated_first_param_still_included():
"""Test that a regular unannotated first param (not self/cls) is still included."""

def process(data, flag: bool = False) -> str:
return str(data)

fs = function_schema(process, use_docstring_info=False)
props = fs.params_json_schema.get("properties", {})
assert "data" in props
Loading