Skip to content

Commit 64bb584

Browse files
ecanlarclaude
andcommitted
fix: resolve mypy type errors and nested model description propagation
- Add type parameters to Dict and set in _resolve_pydantic_refs inner functions to fix mypy [type-arg] errors - Fix no-any-return in resolve_ref by explicitly typing the variable - Preserve Field(description=...) for direct $ref (not just allOf) in _resolve_pydantic_refs so nested model descriptions propagate correctly - Re-apply per-parameter descriptions in _get_pydantic_schema after schema generation, since Pydantic may replace them with model docstrings - Unwrap Annotated[T, Field(...)] in from_function_with_options so Annotated BaseModel parameters are parsed correctly - Propagate field_info.description to sub-schemas when parsing nested BaseModel fields in _parse_schema_from_parameter Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 845c324 commit 64bb584

File tree

2 files changed

+56
-10
lines changed

2 files changed

+56
-10
lines changed

src/google/adk/tools/_automatic_function_calling_util.py

Lines changed: 52 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -138,16 +138,18 @@ def _resolve_pydantic_refs(schema: Dict[str, Any]) -> Dict[str, Any]:
138138
schema = copy.deepcopy(schema)
139139
defs = schema.get("$defs", {})
140140

141-
def resolve_ref(ref_string: str) -> Optional[Dict]:
141+
def resolve_ref(ref_string: str) -> Optional[Dict[str, Any]]:
142142
"""Resolve a $ref string like '#/$defs/Person'."""
143143
if not ref_string.startswith("#/$defs/"):
144144
return None
145145
def_name = ref_string.split("/")[-1]
146-
return defs.get(def_name)
146+
resolved: Optional[Dict[str, Any]] = defs.get(def_name)
147+
return resolved
147148

148149
def resolve_property(
149-
prop_schema: Dict, seen_refs: Optional[set] = None
150-
) -> Dict:
150+
prop_schema: Dict[str, Any],
151+
seen_refs: Optional[set[str]] = None,
152+
) -> Dict[str, Any]:
151153
"""Recursively resolve $ref in a property schema.
152154
153155
Args:
@@ -203,9 +205,16 @@ def resolve_property(
203205
if ref_string not in seen_refs:
204206
seen_refs_copy = seen_refs.copy()
205207
seen_refs_copy.add(ref_string)
208+
209+
# Preserve parameter-level description (takes precedence over model docstring)
210+
param_description = prop_schema.get("description")
211+
206212
resolved = resolve_ref(ref_string)
207213
if resolved:
208-
return resolve_property(copy.deepcopy(resolved), seen_refs_copy)
214+
result = resolve_property(copy.deepcopy(resolved), seen_refs_copy)
215+
if param_description:
216+
result["description"] = param_description
217+
return result
209218

210219
# Recursively resolve nested properties (for already-inlined objects)
211220
if "properties" in prop_schema:
@@ -346,7 +355,7 @@ def _remove_title(schema: Dict):
346355
property_schema.pop("title", None)
347356

348357

349-
def _get_pydantic_schema(func: Callable) -> Dict:
358+
def _get_pydantic_schema(func: Callable) -> Dict[str, Any]:
350359
from ..utils.context_utils import find_context_parameter
351360

352361
fields_dict = _get_fields_dict(func)
@@ -355,13 +364,26 @@ def _get_pydantic_schema(func: Callable) -> Dict:
355364
if context_param in fields_dict.keys():
356365
fields_dict.pop(context_param)
357366

367+
# Capture per-parameter descriptions before schema generation, because
368+
# Pydantic may replace them with model docstrings for nested BaseModel types.
369+
param_descriptions: Dict[str, str] = {}
370+
for name, (_, field_info) in fields_dict.items():
371+
if field_info.description:
372+
param_descriptions[name] = field_info.description
373+
358374
schema = pydantic.create_model(
359375
func.__name__, **fields_dict
360376
).model_json_schema()
361377

362378
# Resolve $ref for nested Pydantic models to inline Field descriptions
363379
schema = _resolve_pydantic_refs(schema)
364380

381+
# Re-apply per-parameter descriptions that may have been lost during
382+
# schema generation (Pydantic uses model docstrings for nested models).
383+
for name, description in param_descriptions.items():
384+
if name in schema.get("properties", {}):
385+
schema["properties"][name]["description"] = description
386+
365387
return schema
366388

367389

@@ -530,6 +552,8 @@ def from_function_with_options(
530552
except TypeError:
531553
# This can happen if func is a mock object
532554
annotation_under_future = {}
555+
# Collect Annotated field descriptions to apply after schema generation.
556+
annotated_descriptions: Dict[str, str] = {}
533557
try:
534558
for name, param in inspect.signature(func).parameters.items():
535559
if param.kind in (
@@ -541,9 +565,19 @@ def from_function_with_options(
541565
param, annotation_under_future, name
542566
)
543567

568+
# Unwrap Annotated[T, Field(...)] so the parser sees the base type.
569+
field_info = _extract_field_info_from_annotated(param.annotation)
570+
if field_info and field_info.description:
571+
annotated_descriptions[name] = field_info.description
572+
base_type = _extract_base_type_from_annotated(param.annotation)
573+
if base_type is not param.annotation:
574+
param = param.replace(annotation=base_type)
575+
544576
schema = _function_parameter_parse_util._parse_schema_from_parameter(
545577
variant, param, func.__name__
546578
)
579+
if name in annotated_descriptions:
580+
schema.description = annotated_descriptions[name]
547581
parameters_properties[name] = schema
548582
except ValueError:
549583
# If the function has complex parameter types that fail in _parse_schema_from_parameter,
@@ -563,15 +597,24 @@ def from_function_with_options(
563597
param, annotation_under_future, name
564598
)
565599

600+
# Unwrap Annotated[T, Field(...)] for the fallback path too.
601+
field_info = _extract_field_info_from_annotated(param.annotation)
602+
if field_info and field_info.description:
603+
annotated_descriptions[name] = field_info.description
604+
base_type = _extract_base_type_from_annotated(param.annotation)
605+
if base_type is not param.annotation:
606+
param = param.replace(annotation=base_type)
607+
566608
_function_parameter_parse_util._raise_for_invalid_enum_value(param)
567609

568610
json_schema_dict = _function_parameter_parse_util._generate_json_schema_for_parameter(
569611
param
570612
)
571613

572-
parameters_json_schema[name] = types.Schema.model_validate(
573-
json_schema_dict
574-
)
614+
schema_obj = types.Schema.model_validate(json_schema_dict)
615+
if name in annotated_descriptions:
616+
schema_obj.description = annotated_descriptions[name]
617+
parameters_json_schema[name] = schema_obj
575618
except Exception as e:
576619
_function_parameter_parse_util._raise_for_unsupported_param(
577620
param, func.__name__, e

src/google/adk/tools/_function_parameter_parse_util.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ def _parse_schema_from_parameter(
390390
schema.type = types.Type.OBJECT
391391
schema.properties = {}
392392
for field_name, field_info in param.annotation.model_fields.items():
393-
schema.properties[field_name] = _parse_schema_from_parameter(
393+
field_schema = _parse_schema_from_parameter(
394394
variant,
395395
inspect.Parameter(
396396
field_name,
@@ -399,6 +399,9 @@ def _parse_schema_from_parameter(
399399
),
400400
func_name,
401401
)
402+
if field_info.description:
403+
field_schema.description = field_info.description
404+
schema.properties[field_name] = field_schema
402405

403406
required_fields = [
404407
field_name

0 commit comments

Comments
 (0)