From 31c4b66b1d27f06299b94e597d32b2d67b641bcf Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Tue, 9 Jun 2026 15:59:07 +0200 Subject: [PATCH] Python: Split type checkers by target (pyright source, 5 checkers on tests/samples) Rework the typing setup along the lines of the 'too many type checkers' approach: - Pyright (strict) is now the sole source-code type checker; mypy is removed from source and its [tool.mypy] block becomes a relaxed profile used only for tests/samples. - Tests are checked by all five checkers (pyright relaxed, mypy, pyrefly, ty, zuban); samples by pyright, pyrefly, and ty. All run in a relaxed/ basic profile so authors aren't forced into over-annotation. - Add pyrightconfig.tests.json and bump sample pyright configs to basic. - Unify test/sample typing onto the same parallel fan-out used by source pyright via run_command_items in task_runner.py. - Make version-conditional imports symmetric: keep or drop the '# type: ignore' on both branches so results match across interpreter versions (local vs CI). - Update SKILL.md, DEV_SETUP.md, and CODING_STANDARD.md for the five gating checkers and pyright on source+tests+samples. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .github/workflows/python-code-quality.yml | 10 +- .../skills/python-code-quality/SKILL.md | 56 +- python/CODING_STANDARD.md | 13 +- python/DEV_SETUP.md | 21 +- .../a2a/agent_framework_a2a/_agent.py | 4 +- python/packages/a2a/tests/test_a2a_agent.py | 120 +++-- .../packages/a2a/tests/test_a2a_executor.py | 53 +- .../ag-ui/agent_framework_ag_ui/_agent.py | 2 +- .../ag-ui/agent_framework_ag_ui/_agent_run.py | 2 +- .../ag-ui/agent_framework_ag_ui/_client.py | 30 +- .../_message_adapters.py | 10 +- .../ag-ui/agent_framework_ag_ui/_types.py | 8 +- .../ag-ui/agent_framework_ag_ui/_utils.py | 6 +- .../agent_framework_ag_ui/_workflow_run.py | 4 +- .../agents/task_steps_agent.py | 2 +- .../agents/ui_generator_agent.py | 8 +- .../server/main.py | 2 +- python/packages/ag-ui/tests/ag_ui/conftest.py | 28 +- .../golden/test_scenario_agentic_chat.py | 4 +- .../golden/test_scenario_backend_tools.py | 4 +- .../test_scenario_deterministic_state.py | 4 +- .../test_scenario_generative_ui_agent.py | 9 +- .../test_scenario_generative_ui_tool.py | 4 +- .../tests/ag_ui/golden/test_scenario_hitl.py | 4 +- .../golden/test_scenario_predictive_state.py | 4 +- .../golden/test_scenario_shared_state.py | 4 +- .../ag_ui/golden/test_scenario_subgraphs.py | 2 +- .../ag_ui/golden/test_scenario_workflow.py | 43 +- .../packages/ag-ui/tests/ag_ui/sse_helpers.py | 2 +- .../ag-ui/tests/ag_ui/test_ag_ui_client.py | 22 +- .../ag_ui/test_agent_wrapper_comprehensive.py | 4 +- .../tests/ag_ui/test_approval_result_event.py | 2 +- .../ag-ui/tests/ag_ui/test_endpoint.py | 6 +- .../tests/ag_ui/test_event_converters.py | 20 +- .../ag-ui/tests/ag_ui/test_http_round_trip.py | 10 +- .../tests/ag_ui/test_message_adapters.py | 52 +- .../ag-ui/tests/ag_ui/test_message_hygiene.py | 10 +- .../ag-ui/tests/ag_ui/test_multi_turn.py | 11 +- python/packages/ag-ui/tests/ag_ui/test_run.py | 91 ++-- .../ag-ui/tests/ag_ui/test_run_common.py | 30 +- .../tests/ag_ui/test_structured_output.py | 12 +- .../ag-ui/tests/ag_ui/test_tooling.py | 15 +- .../packages/ag-ui/tests/ag_ui/test_types.py | 68 ++- .../packages/ag-ui/tests/ag_ui/test_utils.py | 2 +- .../ag-ui/tests/ag_ui/test_workflow_agent.py | 2 +- .../ag-ui/tests/ag_ui/test_workflow_run.py | 111 ++-- .../_bedrock_client.py | 4 +- .../agent_framework_anthropic/_chat_client.py | 26 +- .../_foundry_client.py | 4 +- .../_vertex_client.py | 4 +- .../anthropic/tests/test_anthropic_client.py | 19 +- .../_context_provider.py | 4 +- .../tests/test_aisearch_context_provider.py | 61 ++- .../_context_provider.py | 6 +- .../_detection.py | 2 +- .../_file_search.py | 2 +- .../tests/cu/test_context_provider.py | 38 +- .../tests/cu/test_integration.py | 20 +- .../tests/test_cosmos_history_provider.py | 16 +- .../_orchestration.py | 4 +- .../_serialization.py | 2 +- .../_workflow.py | 4 +- .../test_09_workflow_shared_state.py | 6 +- .../packages/azurefunctions/tests/test_app.py | 8 +- .../tests/test_orchestration.py | 6 +- .../agent_framework_bedrock/__init__.py | 4 +- .../agent_framework_bedrock/_chat_client.py | 18 +- .../_embedding_client.py | 8 +- .../bedrock/test_bedrock_embedding_client.py | 16 +- .../bedrock/tests/test_bedrock_client.py | 6 +- .../tests/test_bedrock_structured_output.py | 17 +- .../packages/chatkit/tests/test_converter.py | 26 +- .../packages/chatkit/tests/test_streaming.py | 15 +- .../claude/agent_framework_claude/_agent.py | 20 +- .../claude/tests/test_claude_agent.py | 9 +- .../_acquire_token.py | 4 +- .../packages/core/agent_framework/_agents.py | 34 +- .../packages/core/agent_framework/_clients.py | 8 +- .../core/agent_framework/_evaluation.py | 10 +- .../core/agent_framework/_feature_stage.py | 4 +- .../_harness/_background_agents.py | 8 +- .../core/agent_framework/_harness/_memory.py | 2 +- python/packages/core/agent_framework/_mcp.py | 24 +- .../core/agent_framework/_middleware.py | 18 +- .../core/agent_framework/_serialization.py | 10 +- .../core/agent_framework/_sessions.py | 2 +- .../core/agent_framework/_settings.py | 4 +- .../packages/core/agent_framework/_skills.py | 10 +- .../core/agent_framework/_telemetry.py | 10 +- .../packages/core/agent_framework/_tools.py | 74 ++- .../packages/core/agent_framework/_types.py | 18 +- .../core/agent_framework/_workflows/_agent.py | 18 +- .../_workflows/_agent_executor.py | 4 +- .../agent_framework/_workflows/_checkpoint.py | 2 +- .../_workflows/_checkpoint_encoding.py | 2 +- .../core/agent_framework/_workflows/_edge.py | 16 +- .../agent_framework/_workflows/_events.py | 4 +- .../_workflows/_function_executor.py | 6 +- .../agent_framework/_workflows/_functional.py | 4 +- .../_workflows/_model_utils.py | 4 +- .../_workflows/_runner_context.py | 2 +- .../_workflows/_typing_utils.py | 2 +- .../agent_framework/_workflows/_validation.py | 4 +- .../core/agent_framework/_workflows/_viz.py | 10 +- .../_workflows/_workflow_builder.py | 22 +- .../_workflows/_workflow_context.py | 2 +- .../_workflows/_workflow_executor.py | 4 +- .../core/agent_framework/exceptions.py | 4 +- .../core/agent_framework/foundry/__init__.pyi | 12 +- .../core/agent_framework/observability.py | 30 +- .../packages/core/agent_framework/security.py | 8 +- python/packages/core/pyproject.toml | 2 +- python/packages/core/tests/__init__.py | 1 + python/packages/core/tests/conftest.py | 2 +- python/packages/core/tests/core/__init__.py | 1 + python/packages/core/tests/core/conftest.py | 19 +- .../packages/core/tests/core/test_agents.py | 121 +++-- .../core/test_as_tool_kwargs_propagation.py | 18 +- .../packages/core/tests/core/test_clients.py | 88 +-- .../core/tests/core/test_compaction.py | 8 +- .../core/tests/core/test_embedding_client.py | 8 +- .../core/tests/core/test_embedding_types.py | 6 +- .../core/tests/core/test_feature_stage.py | 63 +-- .../core/tests/core/test_foundry_namespace.py | 2 +- .../core/test_function_invocation_logic.py | 509 ++++++++++-------- .../core/tests/core/test_harness_agent.py | 68 +-- .../core/test_harness_background_agents.py | 10 +- .../tests/core/test_harness_file_access.py | 70 +-- .../core/tests/core/test_harness_memory.py | 89 +-- .../core/tests/core/test_harness_mode.py | 30 +- .../core/tests/core/test_harness_todo.py | 52 +- .../tests/core/test_hyperlight_namespace.py | 2 +- .../core/tests/core/test_local_eval.py | 82 +-- python/packages/core/tests/core/test_mcp.py | 464 ++++++++-------- .../core/tests/core/test_mcp_observability.py | 44 +- .../core/tests/core/test_mcp_skills.py | 190 +++---- .../core/tests/core/test_middleware.py | 52 +- .../core/test_middleware_context_result.py | 14 +- .../tests/core/test_middleware_with_agent.py | 76 +-- .../tests/core/test_middleware_with_chat.py | 40 +- .../core/tests/core/test_observability.py | 498 +++++++++-------- .../tests/core/test_optional_dependencies.py | 2 +- .../packages/core/tests/core/test_sessions.py | 46 +- .../packages/core/tests/core/test_skills.py | 56 +- python/packages/core/tests/core/test_tools.py | 86 +-- .../core/test_tools_future_annotations.py | 2 +- python/packages/core/tests/core/test_types.py | 237 ++++---- python/packages/core/tests/test_security.py | 80 +-- .../tests/workflow/test_agent_executor.py | 16 +- .../test_agent_executor_tool_calls.py | 2 +- .../core/tests/workflow/test_agent_utils.py | 14 +- .../core/tests/workflow/test_checkpoint.py | 24 +- .../workflow/test_checkpoint_validation.py | 4 +- .../core/tests/workflow/test_executor.py | 28 +- .../tests/workflow/test_executor_future.py | 2 +- .../tests/workflow/test_full_conversation.py | 4 +- .../tests/workflow/test_function_executor.py | 16 +- .../workflow/test_function_executor_future.py | 2 +- .../workflow/test_functional_workflow.py | 14 +- .../tests/workflow/test_output_designation.py | 6 +- .../test_output_executors_contract.py | 6 +- .../test_request_info_and_response.py | 4 +- .../tests/workflow/test_request_info_mixin.py | 34 +- .../core/tests/workflow/test_runner.py | 6 +- .../core/tests/workflow/test_serialization.py | 4 +- .../test_strict_mode_event_labeling.py | 2 +- .../core/tests/workflow/test_sub_workflow.py | 8 +- .../core/tests/workflow/test_typing_utils.py | 14 +- .../core/tests/workflow/test_validation.py | 6 +- .../core/tests/workflow/test_workflow.py | 6 +- .../tests/workflow/test_workflow_agent.py | 35 +- .../test_workflow_agent_intermediate.py | 42 +- .../tests/workflow/test_workflow_builder.py | 4 +- .../tests/workflow/test_workflow_context.py | 12 +- .../tests/workflow/test_workflow_kwargs.py | 10 +- .../workflow/test_workflow_observability.py | 36 +- .../tests/workflow/test_workflow_states.py | 2 +- .../agent_framework_declarative/_loader.py | 10 +- .../agent_framework_declarative/_models.py | 74 +-- .../_workflows/_declarative_base.py | 22 +- .../_workflows/_declarative_builder.py | 2 +- .../_workflows/_executors_basic.py | 4 +- .../_workflows/_executors_mcp.py | 2 +- .../_workflows/_http_handler.py | 2 +- .../_workflows/_powerfx_functions.py | 2 +- .../_workflows/_state.py | 2 +- .../tests/test_declarative_loader.py | 28 +- .../tests/test_declarative_models.py | 20 +- .../test_default_http_request_handler.py | 2 +- .../tests/test_default_mcp_tool_handler.py | 30 +- .../tests/test_function_tool_executor.py | 2 +- .../declarative/tests/test_graph_coverage.py | 26 +- .../declarative/tests/test_graph_executors.py | 22 +- .../test_http_request_yaml_integration.py | 2 +- .../tests/test_powerfx_functions.py | 6 +- .../tests/test_workflow_factory.py | 30 +- .../agent_framework_devui/_conversations.py | 4 +- .../devui/agent_framework_devui/_mapper.py | 6 +- .../devui/agent_framework_devui/_server.py | 12 +- .../devui/agent_framework_devui/_utils.py | 20 +- .../devui/tests/devui/capture_messages.py | 17 +- python/packages/devui/tests/devui/conftest.py | 2 +- .../tests/devui/test_approval_validation.py | 4 +- .../devui/tests/devui/test_checkpoints.py | 4 +- .../devui/tests/devui/test_cleanup_hooks.py | 4 +- .../devui/tests/devui/test_conversations.py | 6 +- .../devui/tests/devui/test_discovery.py | 1 + .../devui/tests/devui/test_execution.py | 8 +- .../packages/devui/tests/devui/test_mapper.py | 4 +- .../devui/test_openai_sdk_integration.py | 17 +- .../packages/devui/tests/devui/test_server.py | 4 +- .../tests/devui/test_ui_memory_regression.py | 2 +- .../agent_framework_durabletask/_entities.py | 2 +- .../agent_framework_durabletask/_executors.py | 4 +- .../agent_framework_durabletask/_worker.py | 4 +- .../tests/integration_tests/conftest.py | 11 +- .../test_01_dt_single_agent.py | 14 +- .../test_02_dt_multi_agent.py | 14 +- .../test_03_dt_single_agent_streaming.py | 17 +- ..._dt_single_agent_orchestration_chaining.py | 13 +- ...t_multi_agent_orchestration_concurrency.py | 13 +- ..._multi_agent_orchestration_conditionals.py | 13 +- ...t_07_dt_single_agent_orchestration_hitl.py | 13 +- .../packages/durabletask/tests/test_client.py | 2 +- .../tests/test_durable_agent_state.py | 4 +- .../tests/test_durable_entities.py | 2 +- .../tests/test_orchestration_context.py | 2 +- .../packages/durabletask/tests/test_shim.py | 8 +- .../foundry/agent_framework_foundry/_agent.py | 26 +- .../agent_framework_foundry/_chat_client.py | 32 +- .../_embedding_client.py | 10 +- .../agent_framework_foundry/_foundry_evals.py | 24 +- .../_memory_provider.py | 2 +- .../_to_prompt_agent.py | 4 +- .../tests/foundry/test_foundry_agent.py | 18 +- .../tests/foundry/test_foundry_chat_client.py | 47 +- .../foundry/test_foundry_memory_provider.py | 29 +- .../tests/foundry/test_to_prompt_agent.py | 6 +- .../foundry/tests/test_foundry_evals.py | 64 +-- .../_invocations.py | 2 +- .../_responses.py | 6 +- .../foundry_hosting/tests/test_responses.py | 49 +- .../tests/test_responses_int.py | 12 +- .../_foundry_local_client.py | 10 +- .../agent_framework_gemini/_chat_client.py | 16 +- .../gemini/tests/test_gemini_client.py | 75 ++- .../agent_framework_github_copilot/_agent.py | 13 +- .../tests/test_github_copilot_agent.py | 182 ++++--- .../hyperlight/test_hyperlight_codeact.py | 26 +- .../lab/gaia/agent_framework_lab_gaia/gaia.py | 4 +- .../agent_framework_lab_lightning/__init__.py | 10 +- .../agent_framework_lab_tau2/_tau2_utils.py | 8 +- .../tau2/agent_framework_lab_tau2/runner.py | 14 +- .../lab/tau2/tests/test_message_utils.py | 2 +- .../lab/tau2/tests/test_tau2_utils.py | 9 +- .../agent_framework_mem0/_context_provider.py | 6 +- .../mem0/tests/test_mem0_context_provider.py | 77 ++- .../_embedding_client.py | 4 +- .../mistral/test_mistral_embedding_client.py | 2 +- .../agent_framework_monty/_monty_bridge.py | 2 +- .../monty/tests/monty/test_monty_codeact.py | 10 +- .../monty/test_monty_codeact_integration.py | 4 +- .../agent_framework_ollama/_chat_client.py | 12 +- .../_embedding_client.py | 8 +- .../ollama/tests/test_ollama_chat_client.py | 12 +- .../agent_framework_openai/_chat_client.py | 42 +- .../_chat_completion_client.py | 42 +- .../_embedding_client.py | 12 +- .../openai/agent_framework_openai/_shared.py | 6 +- .../packages/openai/tests/openai/conftest.py | 10 +- .../tests/openai/test_openai_chat_client.py | 210 +++++--- .../openai/test_openai_chat_client_azure.py | 73 ++- .../test_openai_chat_completion_client.py | 42 +- ...est_openai_chat_completion_client_azure.py | 38 +- ...test_openai_chat_completion_client_base.py | 6 +- .../openai/test_openai_embedding_client.py | 7 +- .../test_openai_embedding_client_azure.py | 30 +- .../openai/tests/openai/test_openai_shared.py | 15 +- .../_base_group_chat_orchestrator.py | 6 +- .../_concurrent.py | 4 +- .../_group_chat.py | 4 +- .../_handoff.py | 6 +- .../_magentic.py | 12 +- .../orchestrations/tests/test_concurrent.py | 3 +- .../orchestrations/tests/test_group_chat.py | 43 +- .../orchestrations/tests/test_handoff.py | 127 +++-- .../orchestrations/tests/test_magentic.py | 32 +- ..._orchestration_intermediate_vs_terminal.py | 101 ++-- .../tests/test_orchestration_request_info.py | 25 +- .../orchestrations/tests/test_sequential.py | 10 +- .../agent_framework_purview/_client.py | 8 +- .../agent_framework_purview/_middleware.py | 4 +- .../agent_framework_purview/_models.py | 38 +- .../agent_framework_purview/_settings.py | 2 +- .../tests/purview/test_chat_middleware.py | 64 ++- .../purview/tests/purview/test_middleware.py | 17 +- .../purview/tests/purview/test_processor.py | 68 ++- .../tests/purview/test_purview_client.py | 29 +- .../tests/purview/test_purview_models.py | 5 +- .../purview/tests/purview/test_settings.py | 1 + .../_context_provider.py | 10 +- .../_history_provider.py | 4 +- python/packages/redis/tests/test_providers.py | 55 +- .../agent_framework_tools/shell/_killtree.py | 4 +- .../tools/tests/test_docker_shell_tool.py | 2 +- .../tests/test_shell_environment_provider.py | 12 +- python/pyproject.toml | 75 ++- python/pyrefly.samples.toml | 27 + python/pyrefly.toml | 17 + python/pyrightconfig.samples.json | 8 +- python/pyrightconfig.samples.py310.json | 8 +- python/pyrightconfig.tests.json | 9 + .../02-agents/a2a/a2a_stream_reconnection.py | 6 +- .../samples/02-agents/background_responses.py | 11 +- .../chat_client/chat_response_cancellation.py | 7 +- .../azure_ai_search/search_context_agentic.py | 2 +- .../context_providers/redis/redis_basics.py | 8 +- .../simple_context_provider.py | 11 +- .../azure_openai_responses_agent.py | 6 +- .../02-agents/devui/agent_foundry/__init__.py | 2 +- .../02-agents/devui/agent_weather/__init__.py | 2 +- .../02-agents/devui/workflow_spam/__init__.py | 2 +- .../devui/workflow_with_agents/__init__.py | 2 +- .../devui/workflow_with_agents/workflow.py | 4 +- .../02-agents/evaluation/evaluate_agent.py | 2 +- .../evaluation/evaluate_multimodal.py | 2 +- .../evaluation/evaluate_with_expected.py | 2 +- .../02-agents/harness/console/agent_runner.py | 4 +- .../samples/02-agents/harness/console/app.py | 10 +- .../harness/console/commands/todo_handler.py | 4 +- .../console/components/scroll_panel.py | 2 +- .../console/observers/planning_models.py | 3 +- .../02-agents/middleware/chat_middleware.py | 2 +- .../providers/amazon/bedrock_chat_client.py | 4 +- .../providers/anthropic/anthropic_skills.py | 2 + .../anthropic/anthropic_with_shell.py | 4 +- .../providers/custom/custom_agent.py | 30 +- ...ndry_chat_client_code_interpreter_files.py | 2 +- .../foundry_chat_client_with_hosted_mcp.py | 24 +- .../providers/foundry/foundry_local_agent.py | 6 +- .../foundry/foundry_prompt_agents.py | 2 + .../github_copilot/github_copilot_basic.py | 22 +- .../github_copilot_with_file_operations.py | 6 +- .../github_copilot_with_function_approval.py | 26 +- ...ub_copilot_with_instruction_directories.py | 30 +- .../github_copilot/github_copilot_with_mcp.py | 17 +- ...ithub_copilot_with_multiple_permissions.py | 6 +- .../github_copilot_with_session.py | 14 +- .../github_copilot_with_shell.py | 6 +- .../github_copilot/github_copilot_with_url.py | 6 +- .../providers/ollama/ollama_chat_client.py | 6 +- .../client_streaming_image_generation.py | 3 +- .../openai/client_with_hosted_mcp.py | 24 +- .../openai/client_with_local_shell.py | 2 + python/samples/02-agents/response_stream.py | 10 +- .../security/repo_confidentiality_example.py | 2 +- .../file_based_skill/file_based_skill.py | 2 +- .../skills/mixed_skills/mixed_skills.py | 2 +- .../skills/script_approval/script_approval.py | 6 +- .../skills/skill_filtering/skill_filtering.py | 2 +- .../tools/function_tool_with_approval.py | 4 + ...unction_tool_with_approval_and_sessions.py | 4 + .../azure_ai_agents_with_shared_session.py | 3 - .../workflow_as_agent_human_in_the_loop.py | 2 +- .../agents/workflow_as_agent_kwargs.py | 2 +- .../composition/sub_workflow_kwargs.py | 3 +- .../control-flow/edge_condition.py | 7 +- .../intermediate_vs_terminal_outputs.py | 4 +- .../multi_selection_edge_group.py | 9 +- .../control-flow/switch_case_edge_group.py | 7 +- .../agent_to_function_tool/main.py | 3 +- .../declarative/customer_support/main.py | 14 +- .../declarative/deep_research/main.py | 4 +- .../invoke_foundry_toolbox_mcp/main.py | 5 +- .../guessing_game_with_human_input.py | 4 +- .../state-management/state_with_agents.py | 5 +- python/samples/04-hosting/a2a/a2a_server.py | 2 +- .../04-hosting/a2a/agent_definitions.py | 2 +- .../03_reliable_streaming/function_app.py | 8 +- .../function_app.py | 2 +- .../09_workflow_shared_state/function_app.py | 5 +- .../function_app.py | 7 +- .../11_workflow_parallel/function_app.py | 8 +- .../12_workflow_hitl/function_app.py | 3 +- .../durabletask/01_single_agent/sample.py | 4 +- .../durabletask/02_multi_agent/sample.py | 4 +- .../03_single_agent_streaming/client.py | 2 +- .../03_single_agent_streaming/sample.py | 4 +- .../03_single_agent_streaming/worker.py | 4 +- .../sample.py | 4 +- .../sample.py | 4 +- .../sample.py | 4 +- .../sample.py | 4 +- .../responses/using_deployed_agent.py | 2 +- python/scripts/task_runner.py | 72 +++ python/scripts/workspace_poe_tasks.py | 271 +++++++--- .../getting_started/test_agent_samples.py | 94 ++-- .../test_chat_client_samples.py | 18 +- .../getting_started/test_threads_samples.py | 10 +- .../samples/hosting/test_toolbox_endpoint.py | 7 +- python/ty.samples.toml | 17 + python/uv.lock | 67 +++ 402 files changed, 5094 insertions(+), 3931 deletions(-) create mode 100644 python/pyrefly.samples.toml create mode 100644 python/pyrefly.toml create mode 100644 python/pyrightconfig.tests.json create mode 100644 python/ty.samples.toml diff --git a/.github/workflows/python-code-quality.yml b/.github/workflows/python-code-quality.yml index 6527a89cd83..e80f0adb88c 100644 --- a/.github/workflows/python-code-quality.yml +++ b/.github/workflows/python-code-quality.yml @@ -109,8 +109,8 @@ jobs: - name: Run markdown code lint run: uv run poe markdown-code-lint - mypy: - name: Mypy Checks + test-typing: + name: Test Typing Checks if: "!cancelled()" strategy: fail-fast: false @@ -135,7 +135,5 @@ jobs: os: ${{ runner.os }} env: UV_CACHE_DIR: /tmp/.uv-cache - - name: Run Mypy - env: - GITHUB_BASE_REF: ${{ github.event.pull_request.base.ref || github.base_ref || 'main' }} - run: uv run python scripts/workspace_poe_tasks.py ci-mypy + - name: Run tests/samples type checkers (mypy, pyrefly, ty) + run: uv run python scripts/workspace_poe_tasks.py ci-test-typing diff --git a/python/.github/skills/python-code-quality/SKILL.md b/python/.github/skills/python-code-quality/SKILL.md index 29ac63e4fe5..d36aa07794f 100644 --- a/python/.github/skills/python-code-quality/SKILL.md +++ b/python/.github/skills/python-code-quality/SKILL.md @@ -21,13 +21,22 @@ uv run poe syntax -C # Check only uv run poe syntax -S # Samples only # Type checking -uv run poe pyright # Pyright fan-out across packages +# +# Division of labor (see "Type checking architecture" below): +# - Pyright (strict) is the source-code type checker. +# - Pyright (relaxed `basic`), mypy, pyrefly, ty, zuban all check the TESTS; +# pyright/pyrefly/ty also check the SAMPLES (mypy/zuban skip script-style samples). +uv run poe pyright # Pyright (strict) over SOURCE, fan-out across packages uv run poe pyright -P core uv run poe pyright -A -uv run poe mypy # MyPy fan-out across packages +uv run poe test-typing # mypy + pyrefly + ty + zuban + pyright over each package's TESTS +uv run poe test-typing -P core +uv run poe test-typing -S # samples (pyrefly + ty + pyright) +uv run poe test-typing -P core --checker mypy # narrow to one checker (repeatable) +uv run poe test-typing -P core --checker pyright # relaxed pyright over the tests +uv run poe mypy # alias: MyPy over the tests only uv run poe mypy -P core -uv run poe mypy -A -uv run poe typing # Both pyright and mypy +uv run poe typing # Pyright (source) + the tests checkers uv run poe typing -P core uv run poe typing -A @@ -67,6 +76,33 @@ when markdown files change, and sample syntax lint/pyright only when files under `samples/` change. They intentionally do not run workspace `pyright` or `mypy` by default. +## Type checking architecture + +Following the "too many type checkers" approach, type checkers are split by target: + +| Target | Checker(s) | Mode | Config | +|--------|-----------|------|--------| +| Source (`agent_framework*`) | **pyright** | strict | `[tool.pyright]` in `pyproject.toml` | +| Tests | pyright, mypy, pyrefly, ty, zuban | relaxed/basic | `pyrightconfig.tests.json`, `[tool.mypy]`, `pyrefly.toml`, `ty` rules | +| Samples | pyright, pyrefly, ty | basic | `pyrightconfig.samples.json`, `pyrefly.samples.toml`, `ty.samples.toml` | + +- **Pyright is the only *strict* source-code checker**, and it ALSO runs in a relaxed + `basic` profile over the tests and samples (so the surfaces customers copy from are + validated by every checker, including pyright). MyPy was removed from source; its + `[tool.mypy]` block is now a *relaxed* profile used only for tests/samples. +- The extra checkers run over tests/samples because those exercise the public API the way + users do. The profile is intentionally relaxed (private access allowed, untyped test + bodies allowed) so authors aren't forced into ugly over-annotation. +- **Gating checkers** are `pyright`, `mypy`, `pyrefly`, `ty`, and `zuban` — all five run by + default and gate CI. `zuban` is the strictest of the mypy-compatible pair, so the same + `[tool.mypy]` config yields more findings; suppress zuban-only friction with shared + `# type: ignore[code]`. Suppress relaxed-pyright friction with `# pyright: ignore[rule]`. +- **Samples** add `pyright` to `pyrefly` + `ty` — mypy/zuban can't resolve script-style + sample layouts (numeric-prefixed dirs, duplicate `main.py`), but pyright handles them. +- The strict source-pyright (`[tool.pyright]`) enforces `reportUnnecessaryTypeIgnoreComment` + and excludes tests/samples; the relaxed test/sample pyright configs do not flag unnecessary + ignores. + ## Ruff Configuration - Line length: 120 @@ -77,8 +113,12 @@ They intentionally do not run workspace `pyright` or `mypy` by default. ## Pyright Configuration -- Strict mode enabled -- Excludes: tests, .venv, packages/devui/frontend +- **Source**: strict mode (`[tool.pyright]`), `reportUnnecessaryTypeIgnoreComment = "error"`, + excludes tests, samples, .venv, packages/devui/frontend. +- **Tests**: relaxed `basic` profile (`pyrightconfig.tests.json`) — private import/usage and + not-required TypedDict access allowed; runs as the `pyright` checker in `test-typing`. +- **Samples**: relaxed `basic` profile (`pyrightconfig.samples.json`, with a py310 variant) — + runs as the `pyright` checker in `test-typing -S`. ## Parallel Execution @@ -90,6 +130,6 @@ in-process with streaming output. CI splits into 4 parallel jobs: 1. **Pre-commit hooks** — lightweight hooks (SKIP=poe-check) -2. **Package checks** — syntax/pyright via check-packages +2. **Package checks** — syntax/pyright (source) via check-packages 3. **Samples & markdown** — `check -S` plus `markdown-code-lint` -4. **Mypy** — change-detected mypy checks +4. **Test Typing** — change-detected mypy/pyrefly/ty over tests (`ci-test-typing`) diff --git a/python/CODING_STANDARD.md b/python/CODING_STANDARD.md index a9140ca353e..9762f89242a 100644 --- a/python/CODING_STANDARD.md +++ b/python/CODING_STANDARD.md @@ -92,11 +92,18 @@ Use typing as a helper first and suppressions as a last resort: - **Prefer explicit typing before suppression**: Start with clearer type annotations, helper types, overloads, protocols, or refactoring dynamic code into typed helpers. Prioritize performance over completeness of typing, but make a good-faith effort to reduce uncertainty with typing before ignoring. Prefer to use a cast over a typeguard function since that does add overhead. - **Avoid redundant casts**: Do not add `cast(...)` if the type already matches; casts should be reserved for - unavoidable narrowing where the runtime contract is known, we will use mypy's check on redundant casts to enforce this. + unavoidable narrowing where the runtime contract is known. - **Avoid multiple assignments**: Avoid assigning multiple variables just to get typing to pass, that has performance impact while typing should not have that. -- **Line-level pyright ignores only**: If suppression is still required, use a line-level rule-specific ignore +- **Source vs tests/samples**: Source code (`agent_framework*`) is checked **by pyright in strict mode** — use + `# pyright: ignore[...]` there, never `# type: ignore` (strict pyright flags unnecessary ignores as errors). Tests + and samples are checked by pyright (relaxed `basic`), mypy, pyrefly, ty (and zuban on tests) in a relaxed/basic + profile; prefer real fixes (`isinstance`, `cast`, annotations, asserts for Optional access) over per-line ignores, + and keep test/sample bodies readable rather than over-annotated. When a relaxed-pyright suppression is genuinely + needed in tests/samples, use `# pyright: ignore[rule]`; the relaxed test/sample configs do not flag unnecessary + ignores, so combine with a mypy/zuban `# type: ignore[code]` on the same line only where both are required. +- **Line-level pyright ignores only**: If suppression is still required in source, use a line-level rule-specific ignore (`# pyright: ignore[reportGeneralTypeIssues]`), file-level is allowed if there is a compelling reason for it, that should be documented right beneath the ignore. - Never change the global suppression flags for mypy and pyright unless the dev team okays it. + Never change the global suppression flags unless the dev team okays it. - **Private usage boundary**: Accessing private members across `agent_framework*` packages can be acceptable for this codebase, but private member usage for non-Agent Framework dependencies should remain flagged. diff --git a/python/DEV_SETUP.md b/python/DEV_SETUP.md index f2c899ed75f..ab72cf44d36 100644 --- a/python/DEV_SETUP.md +++ b/python/DEV_SETUP.md @@ -289,23 +289,36 @@ uv run poe -A # aggregate sweep where supported ``` #### `pyright` -Run Pyright type checking: +Run Pyright type checking. Pyright is the **strict source-code type checker**, and also runs +in a relaxed `basic` profile over the tests + samples (as one of the `test-typing` checkers): ```bash uv run poe pyright uv run poe pyright -P core uv run poe pyright -A ``` +#### `test-typing` +Run the **tests + samples** type checkers. Source code is owned by strict Pyright; the tests +and samples are checked by `pyright` (relaxed), `mypy`, `pyrefly`, `ty`, and `zuban` in a +deliberately relaxed/basic profile so real public-API type errors surface without forcing +test/sample authors to fully annotate their code. All five gate CI: +```bash +uv run poe test-typing # all checkers over every package's tests +uv run poe test-typing -P core # one package +uv run poe test-typing -S # samples (pyright + pyrefly + ty; mypy/zuban skip script-style samples) +uv run poe test-typing -P core --checker mypy # narrow to one checker (repeatable) +uv run poe test-typing -P core --checker pyright # relaxed pyright over the tests +``` + #### `mypy` -Run MyPy type checking: +Convenience alias that runs MyPy over the test suite (MyPy no longer runs on source): ```bash uv run poe mypy uv run poe mypy -P core -uv run poe mypy -A ``` #### `typing` -Run both Pyright and MyPy: +Run Pyright over source **and** the tests/samples checkers: ```bash uv run poe typing uv run poe typing -P core diff --git a/python/packages/a2a/agent_framework_a2a/_agent.py b/python/packages/a2a/agent_framework_a2a/_agent.py index 33c8239a1cc..c2bf1d6ffb2 100644 --- a/python/packages/a2a/agent_framework_a2a/_agent.py +++ b/python/packages/a2a/agent_framework_a2a/_agent.py @@ -344,7 +344,7 @@ def run( **kwargs: Any, ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... - def run( # pyright: ignore[reportIncompatibleMethodOverride] + def run( self, messages: AgentRunInputs | None = None, *, @@ -464,7 +464,7 @@ async def _map_a2a_stream( if session is None: raise RuntimeError("Provider session must be available when context providers are configured.") await provider.before_run( - agent=self, # type: ignore[arg-type] + agent=self, session=session, context=session_context, state=session.state.setdefault(provider.source_id, {}), diff --git a/python/packages/a2a/tests/test_a2a_agent.py b/python/packages/a2a/tests/test_a2a_agent.py index 746a9b86641..ea271f7379b 100644 --- a/python/packages/a2a/tests/test_a2a_agent.py +++ b/python/packages/a2a/tests/test_a2a_agent.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. from collections.abc import AsyncIterator -from typing import Any +from typing import Any, cast from unittest.mock import AsyncMock, MagicMock, patch from uuid import uuid4 @@ -126,14 +126,18 @@ def mock_a2a_client() -> MockA2AClient: @fixture def a2a_agent(mock_a2a_client: MockA2AClient) -> A2AAgent: """Fixture that provides an A2AAgent with a mock client.""" - return A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None) + return A2AAgent(name="Test Agent", id="test-agent", client=cast(Any, mock_a2a_client), http_client=None) def test_a2a_agent_initialization_with_client(mock_a2a_client: MockA2AClient) -> None: """Test A2AAgent initialization with provided client.""" # Use model_construct to bypass Pydantic validation for mock objects agent = A2AAgent( - name="Test Agent", id="test-agent-123", description="A test agent", client=mock_a2a_client, http_client=None + name="Test Agent", + id="test-agent-123", + description="A test agent", + client=cast(Any, mock_a2a_client), + http_client=None, ) assert agent.name == "Test Agent" @@ -148,7 +152,7 @@ def test_a2a_agent_defaults_name_description_from_agent_card(mock_a2a_client: Mo mock_card.name = "Card Agent Name" mock_card.description = "Card agent description" - agent = A2AAgent(agent_card=mock_card, client=mock_a2a_client, http_client=None) + agent = A2AAgent(agent_card=mock_card, client=cast(Any, mock_a2a_client), http_client=None) assert agent.name == "Card Agent Name" assert agent.description == "Card agent description" @@ -164,7 +168,7 @@ def test_a2a_agent_explicit_name_description_overrides_agent_card(mock_a2a_clien name="Explicit Name", description="Explicit description", agent_card=mock_card, - client=mock_a2a_client, + client=cast(Any, mock_a2a_client), http_client=None, ) @@ -182,7 +186,7 @@ def test_a2a_agent_empty_string_name_description_not_overridden(mock_a2a_client: name="", description="", agent_card=mock_card, - client=mock_a2a_client, + client=cast(Any, mock_a2a_client), http_client=None, ) @@ -327,7 +331,7 @@ def test_get_uri_data_invalid_uri() -> None: def test_parse_contents_from_a2a_conversion(a2a_agent: A2AAgent) -> None: """Test A2A parts to contents conversion.""" - agent = A2AAgent(name="Test Agent", client=MockA2AClient(), http_client=None) + agent = A2AAgent(name="Test Agent", client=cast(Any, MockA2AClient()), http_client=None) # Create A2A parts parts = [Part(text="First part"), Part(text="Second part")] @@ -431,7 +435,7 @@ async def test_context_manager_cleanup() -> None: mock_http_client = AsyncMock() mock_a2a_client = MagicMock() - agent = A2AAgent(client=mock_a2a_client) + agent = A2AAgent(client=cast(Any, mock_a2a_client)) agent._http_client = mock_http_client # Test context manager cleanup @@ -447,7 +451,7 @@ async def test_context_manager_no_cleanup_when_no_http_client() -> None: mock_a2a_client = MagicMock() - agent = A2AAgent(client=mock_a2a_client, http_client=None) + agent = A2AAgent(client=cast(Any, mock_a2a_client), http_client=None) # This should not raise any errors async with agent: @@ -535,7 +539,7 @@ def test_prepare_message_for_a2a_a2a_session_context_id_takes_precedence() -> No def test_parse_contents_from_a2a_with_data_part() -> None: """Test conversion of A2A data Part.""" from google.protobuf.json_format import ParseDict - from google.protobuf.struct_pb2 import Struct, Value + from google.protobuf.struct_pb2 import Struct, Value # ty: ignore[unresolved-import] agent = A2AAgent(client=MagicMock(), http_client=None) @@ -553,6 +557,7 @@ def test_parse_contents_from_a2a_with_data_part() -> None: # MessageToJson may format slightly differently — verify the parsed structure import json + assert contents[0].text is not None parsed = json.loads(contents[0].text) assert parsed["key"] == "value" assert parsed["number"] == 42 @@ -658,7 +663,7 @@ def test_transport_negotiation_both_fail() -> None: def test_create_timeout_config_httpx_timeout() -> None: """Test _create_timeout_config with httpx.Timeout object returns it unchanged.""" - agent = A2AAgent(name="Test Agent", client=MockA2AClient(), http_client=None) + agent = A2AAgent(name="Test Agent", client=cast(Any, MockA2AClient()), http_client=None) custom_timeout = httpx.Timeout(connect=15.0, read=180.0, write=20.0, pool=8.0) timeout_config = agent._create_timeout_config(custom_timeout) @@ -672,10 +677,10 @@ def test_create_timeout_config_httpx_timeout() -> None: def test_create_timeout_config_invalid_type() -> None: """Test _create_timeout_config with invalid type raises TypeError.""" - agent = A2AAgent(name="Test Agent", client=MockA2AClient(), http_client=None) + agent = A2AAgent(name="Test Agent", client=cast(Any, MockA2AClient()), http_client=None) with raises(TypeError, match="Invalid timeout type: . Expected float, httpx.Timeout, or None."): - agent._create_timeout_config("invalid") + agent._create_timeout_config(cast(Any, "invalid")) def test_a2a_agent_initialization_with_timeout_parameter() -> None: @@ -802,8 +807,9 @@ async def test_working_task_emits_continuation_token(a2a_agent: A2AAgent, mock_a assert isinstance(response, AgentResponse) assert response.continuation_token is not None - assert response.continuation_token["task_id"] == "task-wip" - assert response.continuation_token["context_id"] == "ctx-1" + token = cast(dict[str, Any], response.continuation_token) + assert token["task_id"] == "task-wip" + assert token["context_id"] == "ctx-1" async def test_submitted_task_emits_continuation_token(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None: @@ -813,7 +819,8 @@ async def test_submitted_task_emits_continuation_token(a2a_agent: A2AAgent, mock response = await a2a_agent.run("Submit task", background=True) assert response.continuation_token is not None - assert response.continuation_token["task_id"] == "task-sub" + token = cast(dict[str, Any], response.continuation_token) + assert token["task_id"] == "task-sub" async def test_input_required_task_emits_continuation_token( @@ -825,7 +832,8 @@ async def test_input_required_task_emits_continuation_token( response = await a2a_agent.run("Need input", background=True) assert response.continuation_token is not None - assert response.continuation_token["task_id"] == "task-input" + token = cast(dict[str, Any], response.continuation_token) + assert token["task_id"] == "task-input" async def test_working_task_no_token_without_background(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None: @@ -879,8 +887,8 @@ async def test_non_streaming_run_uses_non_streaming_client() -> None: non_streaming_client = MockA2AClient() non_streaming_client.add_task_response("task-ns", [{"id": "art-1", "content": "Non-streaming result"}]) - agent = A2AAgent(name="Test Agent", id="test-ns", client=streaming_client, http_client=None) - agent._non_streaming_client = non_streaming_client # type: ignore[assignment] + agent = A2AAgent(name="Test Agent", id="test-ns", client=cast(Any, streaming_client), http_client=None) + agent._non_streaming_client = cast(Any, non_streaming_client) # type: ignore[assignment] response = await agent.run("Hello") @@ -897,8 +905,8 @@ async def test_streaming_run_uses_streaming_client() -> None: non_streaming_client = MockA2AClient() streaming_client.add_task_response("task-s", [{"id": "art-1", "content": "Streaming result"}]) - agent = A2AAgent(name="Test Agent", id="test-s", client=streaming_client, http_client=None) - agent._non_streaming_client = non_streaming_client # type: ignore[assignment] + agent = A2AAgent(name="Test Agent", id="test-s", client=cast(Any, streaming_client), http_client=None) + agent._non_streaming_client = cast(Any, non_streaming_client) # type: ignore[assignment] updates: list[AgentResponseUpdate] = [] async for update in agent.run("Hello", stream=True): @@ -946,8 +954,9 @@ async def test_streaming_emits_continuation_token(a2a_agent: A2AAgent, mock_a2a_ assert len(updates) == 1 assert updates[0].continuation_token is not None - assert updates[0].continuation_token["task_id"] == "task-stream" - assert updates[0].continuation_token["context_id"] == "ctx-s" + token = cast(dict[str, Any], updates[0].continuation_token) + assert token["task_id"] == "task-stream" + assert token["context_id"] == "ctx-s" async def test_resume_via_continuation_token(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None: @@ -994,7 +1003,8 @@ async def test_resume_streaming_via_continuation_token(a2a_agent: A2AAgent, mock # First update: in-progress with token, second: completed with content assert len(updates) == 2 assert updates[0].continuation_token is not None - assert updates[0].continuation_token["task_id"] == "task-rs" + continuation_payload = cast(dict[str, Any], updates[0].continuation_token) + assert continuation_payload["task_id"] == "task-rs" assert updates[1].continuation_token is None assert updates[1].contents[0].text == "Stream resumed" @@ -1008,7 +1018,8 @@ async def test_poll_task_in_progress(a2a_agent: A2AAgent, mock_a2a_client: MockA response = await a2a_agent.poll_task(token) assert response.continuation_token is not None - assert response.continuation_token["task_id"] == "task-poll" + response_token = cast(dict[str, Any], response.continuation_token) + assert response_token["task_id"] == "task-poll" async def test_poll_task_completed(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None: @@ -1040,7 +1051,7 @@ async def test_poll_task_completed(a2a_agent: A2AAgent, mock_a2a_client: MockA2A @mark.asyncio async def test_run_passes_session_service_session_id_as_context_id(mock_a2a_client: MockA2AClient) -> None: """Test that run() wires session.service_session_id to the A2A message context_id.""" - agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None) + agent = A2AAgent(name="Test Agent", id="test-agent", client=cast(Any, mock_a2a_client), http_client=None) mock_a2a_client.add_message_response("msg-ctx", "reply") session = AgentSession(service_session_id="svc-session-42") @@ -1053,7 +1064,7 @@ async def test_run_passes_session_service_session_id_as_context_id(mock_a2a_clie @mark.asyncio async def test_run_a2a_session_context_id_used_over_service_session_id(mock_a2a_client: MockA2AClient) -> None: """Test that A2AAgentSession.context_id is used for outbound messages.""" - agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None) + agent = A2AAgent(name="Test Agent", id="test-agent", client=cast(Any, mock_a2a_client), http_client=None) mock_a2a_client.add_message_response("msg-ctx2", "reply") session = A2AAgentSession(context_id="a2a-ctx-99") @@ -1107,7 +1118,7 @@ async def test_run_invokes_context_providers(mock_a2a_client: MockA2AClient) -> provider = TrackingContextProvider() agent = A2AAgent( name="Test Agent", - client=mock_a2a_client, + client=cast(Any, mock_a2a_client), context_providers=[provider], http_client=None, ) @@ -1126,7 +1137,7 @@ async def test_run_streaming_invokes_context_providers(mock_a2a_client: MockA2AC provider = TrackingContextProvider() agent = A2AAgent( name="Test Agent", - client=mock_a2a_client, + client=cast(Any, mock_a2a_client), context_providers=[provider], http_client=None, ) @@ -1149,7 +1160,7 @@ async def test_context_providers_receive_response(mock_a2a_client: MockA2AClient provider = TrackingContextProvider() agent = A2AAgent( name="Test Agent", - client=mock_a2a_client, + client=cast(Any, mock_a2a_client), context_providers=[provider], http_client=None, ) @@ -1168,7 +1179,7 @@ async def test_context_providers_receive_input_messages(mock_a2a_client: MockA2A provider = TrackingContextProvider() agent = A2AAgent( name="Test Agent", - client=mock_a2a_client, + client=cast(Any, mock_a2a_client), context_providers=[provider], http_client=None, ) @@ -1186,7 +1197,7 @@ async def test_run_without_context_providers(mock_a2a_client: MockA2AClient) -> """Test that run works normally when no context providers are configured.""" agent = A2AAgent( name="Test Agent", - client=mock_a2a_client, + client=cast(Any, mock_a2a_client), http_client=None, ) mock_a2a_client.add_message_response("msg-1", "Hello") @@ -1201,7 +1212,7 @@ async def test_run_creates_session_for_providers_when_none_provided(mock_a2a_cli provider = TrackingContextProvider() agent = A2AAgent( name="Test Agent", - client=mock_a2a_client, + client=cast(Any, mock_a2a_client), context_providers=[provider], http_client=None, ) @@ -1220,7 +1231,7 @@ async def test_run_raises_when_no_messages_and_no_continuation_token( """Test that run() raises ValueError when messages is None/empty and no continuation_token is provided.""" agent = A2AAgent( name="Test Agent", - client=mock_a2a_client, + client=cast(Any, mock_a2a_client), http_client=None, ) @@ -1239,7 +1250,7 @@ async def test_run_with_continuation_token_does_not_require_messages(mock_a2a_cl agent = A2AAgent( name="Test Agent", - client=mock_a2a_client, + client=cast(Any, mock_a2a_client), http_client=None, ) @@ -1331,7 +1342,8 @@ async def test_background_with_status_message_yields_continuation_token( assert len(updates) == 1 assert updates[0].continuation_token is not None - assert updates[0].continuation_token["task_id"] == "task-bg" + token = cast(dict[str, Any], updates[0].continuation_token) + assert token["task_id"] == "task-bg" assert updates[0].contents == [] @@ -1701,8 +1713,10 @@ async def test_task_artifact_update_event_metadata_merged(a2a_agent: A2AAgent, m updates.append(update) artifact_update = updates[0] - assert artifact_update.additional_properties["a2a_metadata"]["from_artifact"] is True - assert artifact_update.additional_properties["a2a_metadata"]["from_event"] is True + assert artifact_update.additional_properties is not None + metadata = artifact_update.additional_properties["a2a_metadata"] + assert metadata["from_artifact"] is True + assert metadata["from_event"] is True async def test_task_status_update_event_metadata_merged(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None: @@ -1729,8 +1743,10 @@ async def test_task_status_update_event_metadata_merged(a2a_agent: A2AAgent, moc updates.append(update) status_update = updates[0] - assert status_update.additional_properties["a2a_metadata"]["msg_key"] == "msg_val" - assert status_update.additional_properties["a2a_metadata"]["event_key"] == "event_val" + assert status_update.additional_properties is not None + metadata = status_update.additional_properties["a2a_metadata"] + assert metadata["msg_key"] == "msg_val" + assert metadata["event_key"] == "event_val" async def test_history_message_metadata_propagated(a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient) -> None: @@ -1890,7 +1906,7 @@ async def test_non_streaming_artifact_update_surfaces_content( @mark.asyncio async def test_first_message_has_no_reference_task_ids(mock_a2a_client: MockA2AClient) -> None: """Test that the first message sent has no reference_task_ids.""" - agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None) + agent = A2AAgent(name="Test Agent", id="test-agent", client=cast(Any, mock_a2a_client), http_client=None) mock_a2a_client.add_task_response("task-first", [{"content": "Hello back"}]) session = A2AAgentSession() @@ -1903,7 +1919,7 @@ async def test_first_message_has_no_reference_task_ids(mock_a2a_client: MockA2AC @mark.asyncio async def test_follow_up_message_includes_reference_task_ids(mock_a2a_client: MockA2AClient) -> None: """Test that a follow-up message references the previous task_id.""" - agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None) + agent = A2AAgent(name="Test Agent", id="test-agent", client=cast(Any, mock_a2a_client), http_client=None) mock_a2a_client.add_task_response("task-abc-123", [{"content": "First reply"}]) session = A2AAgentSession() @@ -1923,7 +1939,7 @@ async def test_follow_up_message_includes_reference_task_ids(mock_a2a_client: Mo @mark.asyncio async def test_reference_task_ids_updated_after_each_interaction(mock_a2a_client: MockA2AClient) -> None: """Test that reference_task_ids always points to the most recent task.""" - agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None) + agent = A2AAgent(name="Test Agent", id="test-agent", client=cast(Any, mock_a2a_client), http_client=None) session = A2AAgentSession() @@ -1948,7 +1964,7 @@ async def test_reference_task_ids_updated_after_each_interaction(mock_a2a_client @mark.asyncio async def test_task_id_tracked_from_status_update_events(mock_a2a_client: MockA2AClient) -> None: """Test that task_id is tracked even when response only contains status update events.""" - agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None) + agent = A2AAgent(name="Test Agent", id="test-agent", client=cast(Any, mock_a2a_client), http_client=None) # Simulate a stream that only has status_update events (no full task payload) status_event = TaskStatusUpdateEvent( @@ -1975,7 +1991,7 @@ async def test_task_id_tracked_from_status_update_events(mock_a2a_client: MockA2 @mark.asyncio async def test_no_session_does_not_crash_reference_task_ids(mock_a2a_client: MockA2AClient) -> None: """Test that running without a session (no reference tracking) works fine.""" - agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None) + agent = A2AAgent(name="Test Agent", id="test-agent", client=cast(Any, mock_a2a_client), http_client=None) mock_a2a_client.add_task_response("task-no-session", [{"content": "Reply"}]) # Should not raise — no session means no reference_task_ids @@ -1987,7 +2003,7 @@ async def test_no_session_does_not_crash_reference_task_ids(mock_a2a_client: Moc @mark.asyncio async def test_task_id_not_tracked_from_message_payload(mock_a2a_client: MockA2AClient) -> None: """Test that task_id is NOT tracked from message payloads (simple interactions without task tracking).""" - agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None) + agent = A2AAgent(name="Test Agent", id="test-agent", client=cast(Any, mock_a2a_client), http_client=None) # Simulate a response that is a message with task_id set (no task/status_update events). # Per A2A spec, a Message response indicates simple interaction — task_id should not be persisted. @@ -2008,7 +2024,7 @@ async def test_task_id_not_tracked_from_message_payload(mock_a2a_client: MockA2A @mark.asyncio async def test_context_id_assigned_from_response(mock_a2a_client: MockA2AClient) -> None: """Test that context_id is assigned from the response when not set on session.""" - agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None) + agent = A2AAgent(name="Test Agent", id="test-agent", client=cast(Any, mock_a2a_client), http_client=None) mock_a2a_client.add_task_response("task-ctx", [{"content": "Reply"}]) session = A2AAgentSession() @@ -2024,7 +2040,7 @@ async def test_context_id_assigned_from_response(mock_a2a_client: MockA2AClient) @mark.asyncio async def test_context_id_tracked_from_message_payload(mock_a2a_client: MockA2AClient) -> None: """Test that context_id is captured from message-only responses (no task payload).""" - agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None) + agent = A2AAgent(name="Test Agent", id="test-agent", client=cast(Any, mock_a2a_client), http_client=None) # Simulate a response with only a message that has context_id but no task_id message_with_context = A2AMessage( @@ -2047,7 +2063,7 @@ async def test_context_id_tracked_from_message_payload(mock_a2a_client: MockA2AC @mark.asyncio async def test_context_id_mismatch_raises_error(mock_a2a_client: MockA2AClient) -> None: """Test that a context_id mismatch between session and response raises an error.""" - agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None) + agent = A2AAgent(name="Test Agent", id="test-agent", client=cast(Any, mock_a2a_client), http_client=None) # Task response has context_id="test-context" (from add_task_response helper) mock_a2a_client.add_task_response("task-mismatch", [{"content": "Reply"}]) @@ -2062,7 +2078,7 @@ async def test_context_id_mismatch_raises_error(mock_a2a_client: MockA2AClient) @mark.asyncio async def test_task_state_tracked_on_session(mock_a2a_client: MockA2AClient) -> None: """Test that task_state is tracked on A2AAgentSession.""" - agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None) + agent = A2AAgent(name="Test Agent", id="test-agent", client=cast(Any, mock_a2a_client), http_client=None) # Add a task that ends in INPUT_REQUIRED mock_a2a_client.add_in_progress_task_response( @@ -2082,7 +2098,7 @@ async def test_task_state_tracked_on_session(mock_a2a_client: MockA2AClient) -> @mark.asyncio async def test_plain_agent_session_no_reference_tracking(mock_a2a_client: MockA2AClient) -> None: """Test that a plain AgentSession works but does not get reference_task_ids tracking.""" - agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None) + agent = A2AAgent(name="Test Agent", id="test-agent", client=cast(Any, mock_a2a_client), http_client=None) mock_a2a_client.add_task_response("task-plain", [{"content": "Reply"}]) session = AgentSession() @@ -2118,7 +2134,7 @@ async def test_a2a_agent_session_serialization() -> None: @mark.asyncio async def test_input_required_sets_task_id_instead_of_reference(mock_a2a_client: MockA2AClient) -> None: """Test that when task_state is INPUT_REQUIRED, follow-up sets task_id (not reference_task_ids).""" - agent = A2AAgent(name="Test Agent", id="test-agent", client=mock_a2a_client, http_client=None) + agent = A2AAgent(name="Test Agent", id="test-agent", client=cast(Any, mock_a2a_client), http_client=None) # First turn: task ends in INPUT_REQUIRED mock_a2a_client.add_in_progress_task_response( diff --git a/python/packages/a2a/tests/test_a2a_executor.py b/python/packages/a2a/tests/test_a2a_executor.py index 27a2aed1ee5..368d0918bed 100644 --- a/python/packages/a2a/tests/test_a2a_executor.py +++ b/python/packages/a2a/tests/test_a2a_executor.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. from asyncio import CancelledError +from typing import Any, cast from unittest.mock import AsyncMock, MagicMock, patch from uuid import uuid4 @@ -199,8 +200,8 @@ async def test_execute_with_existing_task_succeeds( response_message = Message(role="assistant", contents=[Content.from_text(text="Hello back")]) response = MagicMock(spec=AgentResponse) response.messages = [response_message] - executor._agent.run = AsyncMock(return_value=response) - executor._agent.create_session = MagicMock() + cast(Any, executor._agent).run = AsyncMock(return_value=response) + cast(Any, executor._agent).create_session = MagicMock() with patch("agent_framework_a2a._a2a_executor.TaskUpdater") as mock_updater_class: mock_updater = MagicMock() @@ -218,8 +219,8 @@ async def test_execute_with_existing_task_succeeds( mock_updater.submit.assert_called_once() mock_updater.start_work.assert_called_once() mock_updater.complete.assert_called_once() - executor._agent.create_session.assert_called_once() - executor._agent.run.assert_called_once() + cast(Any, executor._agent.create_session).assert_called_once() + cast(Any, executor._agent.run).assert_called_once() async def test_execute_creates_task_when_not_exists( self, @@ -241,8 +242,8 @@ async def test_execute_creates_task_when_not_exists( response_message = Message(role="assistant", contents=[Content.from_text(text="Response")]) response = MagicMock(spec=AgentResponse) response.messages = [response_message] - executor._agent.run = AsyncMock(return_value=response) - executor._agent.create_session = MagicMock() + cast(Any, executor._agent).run = AsyncMock(return_value=response) + cast(Any, executor._agent).create_session = MagicMock() with patch("agent_framework_a2a._a2a_executor.new_task_from_user_message") as mock_new_task: mock_task = MagicMock(spec=Task) @@ -325,8 +326,8 @@ async def test_execute_handles_cancelled_error( mock_request_context.context_id = "ctx-123" mock_request_context.message = MagicMock() - executor._agent.run = AsyncMock(side_effect=CancelledError()) - executor._agent.create_session = MagicMock() + cast(Any, executor._agent).run = AsyncMock(side_effect=CancelledError()) + cast(Any, executor._agent).create_session = MagicMock() with patch("agent_framework_a2a._a2a_executor.TaskUpdater") as mock_updater_class: mock_updater = MagicMock() @@ -361,8 +362,8 @@ async def test_execute_handles_generic_exception( mock_request_context.message = MagicMock() error_message = "Test error" - executor._agent.run = AsyncMock(side_effect=ValueError(error_message)) - executor._agent.create_session = MagicMock() + cast(Any, executor._agent).run = AsyncMock(side_effect=ValueError(error_message)) + cast(Any, executor._agent).create_session = MagicMock() with patch("agent_framework_a2a._a2a_executor.TaskUpdater") as mock_updater_class: mock_updater = MagicMock() @@ -410,11 +411,11 @@ async def test_execute_processes_multiple_response_messages( response_message2 = Message(role="assistant", contents=[Content.from_text(text="Second")]) response = MagicMock(spec=AgentResponse) response.messages = [response_message1, response_message2] - executor._agent.run = AsyncMock(return_value=response) - executor._agent.create_session = MagicMock() + cast(Any, executor._agent).run = AsyncMock(return_value=response) + cast(Any, executor._agent).create_session = MagicMock() # Mock handle_events - executor.handle_events = AsyncMock() + cast(Any, executor).handle_events = AsyncMock() with patch("agent_framework_a2a._a2a_executor.TaskUpdater") as mock_updater_class: mock_updater = MagicMock() @@ -428,7 +429,7 @@ async def test_execute_processes_multiple_response_messages( await executor.execute(mock_request_context, mock_event_queue) # Assert - assert executor.handle_events.call_count == 2 + assert cast(Any, executor.handle_events).call_count == 2 async def test_execute_passes_query_to_run( self, @@ -451,8 +452,8 @@ async def test_execute_passes_query_to_run( response_message = Message(role="assistant", contents=[Content.from_text(text="Response")]) response = MagicMock(spec=AgentResponse) response.messages = [response_message] - executor._agent.run = AsyncMock(return_value=response) - executor._agent.create_session = MagicMock() + cast(Any, executor._agent).run = AsyncMock(return_value=response) + cast(Any, executor._agent).create_session = MagicMock() with patch("agent_framework_a2a._a2a_executor.TaskUpdater") as mock_updater_class: mock_updater = MagicMock() @@ -467,7 +468,7 @@ async def test_execute_passes_query_to_run( await executor.execute(mock_request_context, mock_event_queue) # Assert - executor._agent.run.assert_called_once_with( + cast(Any, executor._agent.run).assert_called_once_with( query_text, session=executor._agent.create_session(), stream=False ) @@ -566,14 +567,14 @@ async def test_run_method_with_single_message(self, executor: A2AExecutor, mock_ response_message = Message(role="assistant", contents=[Content.from_text(text="Response")]) response = MagicMock(spec=AgentResponse) response.messages = response_message # Not a list - executor._agent.run = AsyncMock(return_value=response) - executor.handle_events = AsyncMock() + cast(Any, executor._agent).run = AsyncMock(return_value=response) + cast(Any, executor).handle_events = AsyncMock() # Act await executor._run(query, session, mock_updater) # Assert - executor.handle_events.assert_called_once_with(response_message, mock_updater) + cast(Any, executor.handle_events).assert_called_once_with(response_message, mock_updater) @fixture def mock_updater(self) -> MagicMock: @@ -808,7 +809,7 @@ async def test_handle_agent_response_update_first_time( message_id="msg-1", ) mock_updater.add_artifact = AsyncMock() - streamed_artifact_ids = set() + streamed_artifact_ids: set[str] = set() # Act await executor.handle_events(update, mock_updater, streamed_artifact_ids=streamed_artifact_ids) @@ -844,7 +845,7 @@ async def test_handle_unsupported_content_type(self, executor: A2AExecutor, mock """Test handling messages with unsupported content types.""" # Arrange message = Message( - contents=[Content(type="unknown", text="Some text")], + contents=[Content(type=cast(Any, "unknown"), text="Some text")], # type: ignore[arg-type] role="assistant", ) @@ -884,9 +885,9 @@ async def test_full_execution_flow_with_responses( response_message.role = "assistant" response_message.additional_properties = None - executor._agent.run = AsyncMock(return_value=response) - executor._agent.create_session = MagicMock() - executor.handle_events = AsyncMock() + cast(Any, executor._agent).run = AsyncMock(return_value=response) + cast(Any, executor._agent).create_session = MagicMock() + cast(Any, executor).handle_events = AsyncMock() with patch("agent_framework_a2a._a2a_executor.TaskUpdater") as mock_updater_class: mock_updater = MagicMock() @@ -902,5 +903,5 @@ async def test_full_execution_flow_with_responses( # Assert mock_updater.submit.assert_called_once() mock_updater.start_work.assert_called_once() - executor.handle_events.assert_called_once() + cast(Any, executor.handle_events).assert_called_once() mock_updater.complete.assert_called_once() diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_agent.py b/python/packages/ag-ui/agent_framework_ag_ui/_agent.py index ecde5a67e17..38accb8b9e0 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_agent.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_agent.py @@ -53,7 +53,7 @@ def _normalize_state_schema(state_schema: Any | None) -> dict[str, Any]: base_model_type = None if base_model_type is not None and isinstance(state_schema, base_model_type): - schema_dict = state_schema.__class__.model_json_schema() # type: ignore[union-attr] + schema_dict = state_schema.__class__.model_json_schema() return schema_dict.get("properties", {}) or {} if base_model_type is not None and isinstance(state_schema, type) and issubclass(state_schema, base_model_type): diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py b/python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py index 38578f1bf2b..96878ea6994 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py @@ -591,7 +591,7 @@ async def _resolve_approval_responses( Content.from_function_result(call_id=call_id, result="Error: Tool call invocation failed.") ) - _replace_approval_contents_with_results(messages, fcc_todo, approved_results) # type: ignore + _replace_approval_contents_with_results(messages, fcc_todo, approved_results) # Post-process: Convert user messages with function_result content to proper tool messages. # After _replace_approval_contents_with_results, approved tool calls have their results diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_client.py b/python/packages/ag-ui/agent_framework_ag_ui/_client.py index 7a1b974a383..ae81162b6d5 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_client.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_client.py @@ -32,13 +32,13 @@ from ._utils import convert_tools_to_agui_format if sys.version_info >= (3, 13): - from typing import TypeVar # type: ignore # pragma: no cover + from typing import TypeVar # pragma: no cover else: - from typing_extensions import TypeVar # type: ignore # pragma: no cover + from typing_extensions import TypeVar # pragma: no cover if sys.version_info >= (3, 12): - from typing import override # type: ignore # pragma: no cover + from typing import override # pragma: no cover else: - from typing_extensions import override # type: ignore[import] # pragma: no cover + from typing_extensions import override # pragma: no cover if sys.version_info >= (3, 11): from typing import Self, TypedDict # pragma: no cover else: @@ -91,7 +91,7 @@ async def _response_wrapper_impl(self, original_func: Any, *args: Any, **kwargs: if response.messages: for message in response.messages: _unwrap_server_function_call_contents(cast(MutableSequence[Content | dict[str, Any]], message.contents)) - return response # type: ignore[no-any-return] + return response async def _stream_wrapper_impl(stream: Any) -> AsyncIterable[ChatResponseUpdate]: """Streaming wrapper implementation.""" @@ -273,7 +273,7 @@ def _register_server_tool_placeholder(self, tool_name: str) -> None: config["additional_tools"] = additional_tools registered: set[str] = getattr(self, "_registered_server_tools", set()) registered.add(tool_name) - self._registered_server_tools = registered # type: ignore[attr-defined] + self._registered_server_tools = registered logger.debug(f"[AGUIChatClient] Registered server placeholder: {tool_name}") def _extract_state_from_messages(self, messages: Sequence[Message]) -> tuple[list[Message], dict[str, Any] | None]: @@ -414,8 +414,8 @@ async def _streaming_impl( if tools: for tool in tools: if hasattr(tool, "name"): - client_tool_set.add(tool.name) # type: ignore[arg-type] - self._last_client_tool_set = client_tool_set # type: ignore[attr-defined] + client_tool_set.add(tool.name) + self._last_client_tool_set = client_tool_set logger.debug( "[AGUIChatClient] Preparing request", @@ -451,18 +451,18 @@ async def _streaming_impl( ) # Distinguish client vs server tools for i, content in enumerate(update.contents): - if content.type == "function_call": # type: ignore[attr-defined] + if content.type == "function_call": logger.debug( - f"[AGUIChatClient] Function call: {content.name}, in client_tool_set: {content.name in client_tool_set}" # type: ignore[attr-defined] + f"[AGUIChatClient] Function call: {content.name}, in client_tool_set: {content.name in client_tool_set}" ) - if content.name in client_tool_set: # type: ignore[attr-defined] + if content.name in client_tool_set: # Client tool - let function invocation execute it - if not content.additional_properties: # type: ignore[attr-defined] - content.additional_properties = {} # type: ignore[attr-defined] - content.additional_properties["agui_thread_id"] = thread_id # type: ignore[attr-defined] + if not content.additional_properties: + content.additional_properties = {} + content.additional_properties["agui_thread_id"] = thread_id else: # Server tool - wrap so function invocation ignores it - logger.debug(f"[AGUIChatClient] Wrapping server tool: {content.name}") # type: ignore[union-attr] + logger.debug(f"[AGUIChatClient] Wrapping server tool: {content.name}") self._register_server_tool_placeholder(content.name) # type: ignore[arg-type] update.contents[i] = Content(type="server_function_call", function_call=content) # type: ignore diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py b/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py index c4d2e9b2cd3..47d5841d576 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py @@ -118,13 +118,13 @@ def _sanitize_tool_history(messages: list[Message]) -> list[Message]: user_text = "" for content in msg.contents or []: if content.type == "text": - user_text = content.text # type: ignore[assignment] + user_text = content.text break if not user_text: continue try: - parsed = json.loads(user_text) # type: ignore[arg-type] + parsed = json.loads(user_text) if "accepted" in parsed: logger.info( f"Injecting synthetic tool result for confirm_changes call_id={pending_confirm_changes_id}" @@ -843,14 +843,14 @@ def _filter_modified_args( ) approval_contents.append(approval_response) - chat_msg = Message(role=role, contents=approval_contents) # type: ignore[call-overload] + chat_msg = Message(role=role, contents=approval_contents) else: # Regular message content (text or multimodal) content = msg.get("content", "") converted_contents = _convert_agui_content_to_framework(content) if not converted_contents: converted_contents = [Content.from_text(text="")] - chat_msg = Message(role=role, contents=converted_contents) # type: ignore[call-overload] + chat_msg = Message(role=role, contents=converted_contents) if "id" in msg: chat_msg.message_id = msg["id"] @@ -894,7 +894,7 @@ def agent_framework_messages_to_agui(messages: list[Message] | list[dict[str, An continue # Convert Message to AG-UI format - role_value: str = msg.role if hasattr(msg.role, "value") else msg.role # type: ignore[assignment] + role_value: str = msg.role if hasattr(msg.role, "value") else msg.role role = FRAMEWORK_TO_AGUI_ROLE.get(role_value, "user") content_text = "" diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_types.py b/python/packages/ag-ui/agent_framework_ag_ui/_types.py index c8ee728df3a..d105c3cb945 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_types.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_types.py @@ -9,13 +9,13 @@ from pydantic import AliasChoices, BaseModel, Field if sys.version_info >= (3, 13): - from typing import TypeVar # type: ignore # pragma: no cover + from typing import TypeVar # pragma: no cover else: - from typing_extensions import TypeVar # type: ignore # pragma: no cover + from typing_extensions import TypeVar # pragma: no cover if sys.version_info >= (3, 11): - from typing import TypedDict # type: ignore # pragma: no cover + from typing import TypedDict # pragma: no cover else: - from typing_extensions import TypedDict # type: ignore # pragma: no cover + from typing_extensions import TypedDict # pragma: no cover AGUIChatOptionsT = TypeVar("AGUIChatOptionsT", bound=TypedDict, default="AGUIChatOptions", covariant=True) # type: ignore[valid-type] diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_utils.py b/python/packages/ag-ui/agent_framework_ag_ui/_utils.py index db98e6bfc3f..777f8aa29a3 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_utils.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_utils.py @@ -162,11 +162,11 @@ def make_json_safe(obj: Any) -> Any: # noqa: ANN401 # asdict may return nested non-dataclass objects, so recursively make them safe return make_json_safe(asdict(obj)) # type: ignore[arg-type] if hasattr(obj, "model_dump"): - return make_json_safe(obj.model_dump()) # type: ignore[no-any-return] + return make_json_safe(obj.model_dump()) if hasattr(obj, "to_dict"): - return make_json_safe(obj.to_dict()) # type: ignore[no-any-return] + return make_json_safe(obj.to_dict()) if hasattr(obj, "dict"): - return make_json_safe(obj.dict()) # type: ignore[no-any-return] + return make_json_safe(obj.dict()) if hasattr(obj, "__dict__"): return {key: make_json_safe(value) for key, value in vars(obj).items()} # type: ignore[misc] if isinstance(obj, (list, tuple)): diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_workflow_run.py b/python/packages/ag-ui/agent_framework_ag_ui/_workflow_run.py index fab6bb210cf..44d571aef8d 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_workflow_run.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_workflow_run.py @@ -142,12 +142,12 @@ def _extract_responses_from_messages(messages: list[Message]) -> dict[str, Any]: elif content.type == "function_approval_response" and getattr(content, "id", None): approval_value: dict[str, Any] = { "approved": getattr(content, "approved", False), - "id": str(content.id), # type: ignore[union-attr] + "id": str(content.id), } func_call = getattr(content, "function_call", None) if func_call is not None: approval_value["function_call"] = make_json_safe(func_call.to_dict()) - responses[str(content.id)] = approval_value # type: ignore[union-attr] + responses[str(content.id)] = approval_value return responses diff --git a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/task_steps_agent.py b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/task_steps_agent.py index 2b82b7706dd..b3c02e57a0b 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/task_steps_agent.py +++ b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/task_steps_agent.py @@ -212,7 +212,7 @@ async def run(self, input_data: dict[str, Any]) -> AsyncGenerator[Any]: logger.info("Making SECOND LLM call to generate summary after step execution") # Get the underlying chat agent and client - chat_agent = self._base_agent.agent # type: ignore + chat_agent = self._base_agent.agent client = chat_agent.client # type: ignore # Build messages for summary call diff --git a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/ui_generator_agent.py b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/ui_generator_agent.py index 9ea92dd24e1..b772d91c7c8 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/ui_generator_agent.py +++ b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/ui_generator_agent.py @@ -11,13 +11,13 @@ from agent_framework.ag_ui import AgentFrameworkAgent if sys.version_info >= (3, 13): - from typing import TypeVar # type: ignore # pragma: no cover + from typing import TypeVar # pragma: no cover else: - from typing_extensions import TypeVar # type: ignore # pragma: no cover + from typing_extensions import TypeVar # pragma: no cover if sys.version_info >= (3, 11): - from typing import TypedDict # type: ignore # pragma: no cover + from typing import TypedDict # pragma: no cover else: - from typing_extensions import TypedDict # type: ignore # pragma: no cover + from typing_extensions import TypedDict # pragma: no cover if TYPE_CHECKING: from agent_framework import ChatOptions diff --git a/python/packages/ag-ui/agent_framework_ag_ui_examples/server/main.py b/python/packages/ag-ui/agent_framework_ag_ui_examples/server/main.py index 4b7d56fba54..6425f00a29a 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui_examples/server/main.py +++ b/python/packages/ag-ui/agent_framework_ag_ui_examples/server/main.py @@ -124,7 +124,7 @@ # Agentic Generative UI - task steps agent with streaming state updates add_agent_framework_fastapi_endpoint( app=app, - agent=task_steps_agent_wrapped(client), # type: ignore[arg-type] + agent=task_steps_agent_wrapped(client), path="/agentic_generative_ui", ) diff --git a/python/packages/ag-ui/tests/ag_ui/conftest.py b/python/packages/ag-ui/tests/ag_ui/conftest.py index 64ac8e9d663..e68b395bf22 100644 --- a/python/packages/ag-ui/tests/ag_ui/conftest.py +++ b/python/packages/ag-ui/tests/ag_ui/conftest.py @@ -6,7 +6,7 @@ from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable, Mapping, MutableSequence, Sequence from pathlib import Path from types import SimpleNamespace -from typing import Any, Generic, Literal, cast, overload +from typing import Any, Generic, Literal, TypedDict, cast, overload # noqa: F401 import pytest from agent_framework import ( @@ -104,14 +104,18 @@ def get_response( else: self.last_session = None self.last_service_session_id = self.last_session.service_session_id if self.last_session else None - return cast( - Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]], - super().get_response( + if stream: + return super().get_response( messages=messages, - stream=cast(Literal[True, False], stream), + stream=True, options=options, **kwargs, - ), + ) + return super().get_response( + messages=messages, + stream=False, + options=options, + **kwargs, ) @override @@ -216,7 +220,12 @@ def run( if stream: async def _stream() -> AsyncIterator[AgentResponseUpdate]: - self.messages_received = [] if messages is None else list(messages) # type: ignore[arg-type] + if messages is None: + self.messages_received = [] + elif isinstance(messages, (str, Content, Message)): + self.messages_received = [messages] + else: + self.messages_received = list(messages) self.last_session = session self.tools_received = kwargs.get("tools") for update in self.updates: @@ -233,7 +242,10 @@ async def _get_response() -> AgentResponse[Any]: return _get_response() def create_session(self, **kwargs: Any) -> AgentSession: - return AgentSession() + return AgentSession(session_id=kwargs.get("session_id")) + + def get_session(self, service_session_id: str, *, session_id: str | None = None) -> AgentSession: + return AgentSession(session_id=session_id, service_session_id=service_session_id) # Fixtures diff --git a/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_agentic_chat.py b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_agentic_chat.py index 00516171c25..1e5828425d7 100644 --- a/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_agentic_chat.py +++ b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_agentic_chat.py @@ -7,8 +7,8 @@ from typing import Any from agent_framework import AgentResponseUpdate, Content -from conftest import StubAgent -from event_stream import EventStream +from conftest import StubAgent # pyrefly: ignore[missing-import] # pyright: ignore[reportMissingImports] +from event_stream import EventStream # pyrefly: ignore[missing-import] # pyright: ignore[reportMissingImports] from agent_framework_ag_ui import AgentFrameworkAgent diff --git a/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_backend_tools.py b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_backend_tools.py index 7b48740cad9..c16bfe1fd11 100644 --- a/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_backend_tools.py +++ b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_backend_tools.py @@ -7,8 +7,8 @@ from typing import Any from agent_framework import AgentResponseUpdate, Content -from conftest import StubAgent -from event_stream import EventStream +from conftest import StubAgent # pyrefly: ignore[missing-import] # pyright: ignore[reportMissingImports] +from event_stream import EventStream # pyrefly: ignore[missing-import] # pyright: ignore[reportMissingImports] from agent_framework_ag_ui import AgentFrameworkAgent diff --git a/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_deterministic_state.py b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_deterministic_state.py index 141116c3484..7c25147dcb8 100644 --- a/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_deterministic_state.py +++ b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_deterministic_state.py @@ -14,8 +14,8 @@ from typing import Any from agent_framework import AgentResponseUpdate, Content -from conftest import StubAgent -from event_stream import EventStream +from conftest import StubAgent # pyrefly: ignore[missing-import] # pyright: ignore[reportMissingImports] +from event_stream import EventStream # pyrefly: ignore[missing-import] # pyright: ignore[reportMissingImports] from agent_framework_ag_ui import AgentFrameworkAgent, state_update diff --git a/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_generative_ui_agent.py b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_generative_ui_agent.py index 211bbeedc64..60c5ba4ac5b 100644 --- a/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_generative_ui_agent.py +++ b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_generative_ui_agent.py @@ -7,8 +7,7 @@ from typing import Any from agent_framework import WorkflowBuilder, WorkflowContext, executor -from event_stream import EventStream -from typing_extensions import Never +from event_stream import EventStream # pyrefly: ignore[missing-import] # pyright: ignore[reportMissingImports] from agent_framework_ag_ui import AgentFrameworkWorkflow @@ -31,7 +30,7 @@ async def test_workflow_agent_golden_sequence() -> None: """Workflow-as-agent: emits step events and text content.""" @executor(id="generator") - async def generator(message: Any, ctx: WorkflowContext[Never, str]) -> None: + async def generator(message: Any, ctx: WorkflowContext[Any, str]) -> None: await ctx.yield_output("Here is your generated UI content!") workflow = WorkflowBuilder(start_executor=generator).build() @@ -54,7 +53,7 @@ async def test_workflow_agent_step_names_match() -> None: """Step started/finished events reference the executor name.""" @executor(id="my_executor") - async def my_executor(message: Any, ctx: WorkflowContext[Never, str]) -> None: + async def my_executor(message: Any, ctx: WorkflowContext[Any, str]) -> None: await ctx.yield_output("Done!") workflow = WorkflowBuilder(start_executor=my_executor).build() @@ -71,7 +70,7 @@ async def test_workflow_agent_ordered_events() -> None: """Workflow events follow expected ordering: RUN_STARTED → STEP_STARTED → content → STEP_FINISHED → RUN_FINISHED.""" @executor(id="my_step") - async def my_step(message: Any, ctx: WorkflowContext[Never, str]) -> None: + async def my_step(message: Any, ctx: WorkflowContext[Any, str]) -> None: await ctx.yield_output("Generated content") workflow = WorkflowBuilder(start_executor=my_step).build() diff --git a/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_generative_ui_tool.py b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_generative_ui_tool.py index b154b53236e..fbf2e92b1b7 100644 --- a/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_generative_ui_tool.py +++ b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_generative_ui_tool.py @@ -7,8 +7,8 @@ from typing import Any from agent_framework import AgentResponseUpdate, Content -from conftest import StubAgent -from event_stream import EventStream +from conftest import StubAgent # pyrefly: ignore[missing-import] # pyright: ignore[reportMissingImports] +from event_stream import EventStream # pyrefly: ignore[missing-import] # pyright: ignore[reportMissingImports] from agent_framework_ag_ui import AgentFrameworkAgent diff --git a/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_hitl.py b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_hitl.py index 7af256f625a..2570bd5c313 100644 --- a/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_hitl.py +++ b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_hitl.py @@ -8,8 +8,8 @@ from typing import Any from agent_framework import AgentResponseUpdate, Content -from conftest import StubAgent -from event_stream import EventStream +from conftest import StubAgent # pyrefly: ignore[missing-import] # pyright: ignore[reportMissingImports] +from event_stream import EventStream # pyrefly: ignore[missing-import] # pyright: ignore[reportMissingImports] from agent_framework_ag_ui import AgentFrameworkAgent diff --git a/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_predictive_state.py b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_predictive_state.py index 3870e007285..4ca663c337e 100644 --- a/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_predictive_state.py +++ b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_predictive_state.py @@ -7,8 +7,8 @@ from typing import Any from agent_framework import AgentResponseUpdate, Content -from conftest import StubAgent -from event_stream import EventStream +from conftest import StubAgent # pyrefly: ignore[missing-import] # pyright: ignore[reportMissingImports] +from event_stream import EventStream # pyrefly: ignore[missing-import] # pyright: ignore[reportMissingImports] from agent_framework_ag_ui import AgentFrameworkAgent diff --git a/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_shared_state.py b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_shared_state.py index efbe34ed8f1..584d0552802 100644 --- a/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_shared_state.py +++ b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_shared_state.py @@ -7,8 +7,8 @@ from typing import Any from agent_framework import AgentResponseUpdate, Content -from conftest import StubAgent -from event_stream import EventStream +from conftest import StubAgent # pyrefly: ignore[missing-import] # pyright: ignore[reportMissingImports] +from event_stream import EventStream # pyrefly: ignore[missing-import] # pyright: ignore[reportMissingImports] from pydantic import BaseModel from agent_framework_ag_ui import AgentFrameworkAgent diff --git a/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_subgraphs.py b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_subgraphs.py index 61e89057fbe..594bfea7763 100644 --- a/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_subgraphs.py +++ b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_subgraphs.py @@ -11,7 +11,7 @@ import json from typing import Any -from event_stream import EventStream +from event_stream import EventStream # pyrefly: ignore[missing-import] # pyright: ignore[reportMissingImports] from agent_framework_ag_ui_examples.agents.subgraphs_agent import subgraphs_agent diff --git a/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_workflow.py b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_workflow.py index 5f13b8e67fa..81669654e5a 100644 --- a/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_workflow.py +++ b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_workflow.py @@ -30,8 +30,7 @@ handler, response_handler, ) -from event_stream import EventStream -from typing_extensions import Never +from event_stream import EventStream # pyrefly: ignore[missing-import] # pyright: ignore[reportMissingImports] from agent_framework_ag_ui import AgentFrameworkWorkflow @@ -59,7 +58,7 @@ async def test_workflow_text_output_golden_sequence() -> None: """Simple text output: RUN_STARTED → STEP_STARTED → TEXT_* → STEP_FINISHED → TEXT_MESSAGE_END → RUN_FINISHED.""" @executor(id="greeter") - async def greeter(message: Any, ctx: WorkflowContext[Never, str]) -> None: + async def greeter(message: Any, ctx: WorkflowContext[Any, str]) -> None: await ctx.yield_output("Hello from workflow!") workflow = WorkflowBuilder(start_executor=greeter).build() @@ -82,7 +81,7 @@ async def test_workflow_text_output_message_id_consistency() -> None: """All text events for a single output share the same message_id.""" @executor(id="echo") - async def echo(message: Any, ctx: WorkflowContext[Never, str]) -> None: + async def echo(message: Any, ctx: WorkflowContext[Any, str]) -> None: await ctx.yield_output("echo reply") workflow = WorkflowBuilder(start_executor=echo).build() @@ -101,7 +100,7 @@ async def test_workflow_executor_lifecycle_events() -> None: """Executor invocation produces STEP_STARTED, ACTIVITY_SNAPSHOT, STEP_FINISHED.""" @executor(id="worker") - async def worker(message: Any, ctx: WorkflowContext[Never, str]) -> None: + async def worker(message: Any, ctx: WorkflowContext[Any, str]) -> None: await ctx.yield_output("done") workflow = WorkflowBuilder(start_executor=worker).build() @@ -126,7 +125,7 @@ async def test_workflow_executor_step_ordering() -> None: """STEP_STARTED comes before content, STEP_FINISHED comes after.""" @executor(id="orderer") - async def orderer(message: Any, ctx: WorkflowContext[Never, str]) -> None: + async def orderer(message: Any, ctx: WorkflowContext[Any, str]) -> None: await ctx.yield_output("ordered output") workflow = WorkflowBuilder(start_executor=orderer).build() @@ -154,7 +153,7 @@ async def test_workflow_dict_output_maps_to_custom_event() -> None: """Non-chat dict output is emitted as CUSTOM workflow_output event.""" @executor(id="structured") - async def structured(message: Any, ctx: WorkflowContext[Never, dict[str, int]]) -> None: + async def structured(message: Any, ctx: WorkflowContext[Any, dict[str, int]]) -> None: await ctx.yield_output({"count": 42, "status": 1}) workflow = WorkflowBuilder(start_executor=structured).build() @@ -181,7 +180,7 @@ async def test_workflow_base_event_passthrough() -> None: """AG-UI BaseEvent outputs are yielded directly, not wrapped.""" @executor(id="stateful") - async def stateful(message: Any, ctx: WorkflowContext[Never, StateSnapshotEvent]) -> None: + async def stateful(message: Any, ctx: WorkflowContext[Any, StateSnapshotEvent]) -> None: await ctx.yield_output(StateSnapshotEvent(type=EventType.STATE_SNAPSHOT, snapshot={"active_agent": "flights"})) workflow = WorkflowBuilder(start_executor=stateful).build() @@ -203,7 +202,7 @@ async def test_workflow_agent_response_output_extracts_latest_assistant() -> Non """AgentResponse output uses only the latest assistant message, not full history.""" @executor(id="responder") - async def responder(message: Any, ctx: WorkflowContext[Never, AgentResponse]) -> None: + async def responder(message: Any, ctx: WorkflowContext[Any, AgentResponse]) -> None: response = AgentResponse( messages=[ Message(role="user", contents=[Content.from_text("My order is damaged")]), @@ -232,14 +231,14 @@ class ProgressEvent(WorkflowEvent): """Custom workflow event for testing CUSTOM event mapping.""" def __init__(self, progress: int) -> None: - super().__init__("custom_progress", data={"progress": progress}) + super().__init__(cast(Any, "custom_progress"), data={"progress": progress}) async def test_workflow_custom_events() -> None: """Custom workflow events are mapped to CUSTOM AG-UI events.""" @executor(id="progress_tracker") - async def progress_tracker(message: Any, ctx: WorkflowContext[Never, str]) -> None: + async def progress_tracker(message: Any, ctx: WorkflowContext[Any, str]) -> None: await ctx.add_event(ProgressEvent(25)) await ctx.yield_output("In progress...") await ctx.add_event(ProgressEvent(100)) @@ -355,7 +354,7 @@ async def test_workflow_text_drained_before_request_info() -> None: @executor(id="text_then_request") async def text_then_request(message: Any, ctx: WorkflowContext) -> None: - await ctx.yield_output("Please confirm this action.") + await ctx.yield_output("Please confirm this action.") # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] await ctx.request_info("Need approval", str, request_id="approval-1") workflow = WorkflowBuilder(start_executor=text_then_request).build() @@ -383,7 +382,7 @@ async def test_workflow_skips_duplicate_text_from_snapshot() -> None: """Duplicate text from AgentResponse snapshot is not re-emitted.""" @executor(id="deduper") - async def deduper(message: Any, ctx: WorkflowContext[Never, Any]) -> None: + async def deduper(message: Any, ctx: WorkflowContext[Any, Any]) -> None: text = "Order processed successfully." await ctx.yield_output(text) # Snapshot repeats the same text @@ -410,7 +409,7 @@ async def test_workflow_skips_consecutive_duplicate_outputs() -> None: """Consecutive identical text outputs are deduplicated.""" @executor(id="repeater") - async def repeater(message: Any, ctx: WorkflowContext[Never, Any]) -> None: + async def repeater(message: Any, ctx: WorkflowContext[Any, Any]) -> None: text = "Done!" await ctx.yield_output(text) await ctx.yield_output(text) @@ -428,7 +427,7 @@ async def test_workflow_emits_distinct_consecutive_outputs() -> None: """Distinct text outputs are all emitted, not incorrectly deduplicated.""" @executor(id="multisayer") - async def multisayer(message: Any, ctx: WorkflowContext[Never, str]) -> None: + async def multisayer(message: Any, ctx: WorkflowContext[Any, str]) -> None: await ctx.yield_output("First part. ") await ctx.yield_output("Second part.") @@ -505,7 +504,7 @@ async def start(self, message: Any, ctx: WorkflowContext) -> None: @response_handler async def handle_choice(self, original: str, response: str, ctx: WorkflowContext) -> None: - await ctx.yield_output(f"You chose: {response}") + await ctx.yield_output(f"You chose: {response}") # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] workflow = WorkflowBuilder(start_executor=RequesterExecutor()).build() wrapper = AgentFrameworkWorkflow(workflow=workflow) @@ -620,7 +619,7 @@ async def test_workflow_empty_turn_no_pending_requests() -> None: """Empty turn with no pending requests produces clean bookends.""" @executor(id="noop") - async def noop(message: Any, ctx: WorkflowContext[Never, str]) -> None: + async def noop(message: Any, ctx: WorkflowContext[Any, str]) -> None: await ctx.yield_output("done") workflow = WorkflowBuilder(start_executor=noop).build() @@ -651,7 +650,7 @@ async def test_workflow_usage_output_maps_to_custom_event() -> None: """Usage Content outputs are surfaced as custom usage events.""" @executor(id="usage_reporter") - async def usage_reporter(message: Any, ctx: WorkflowContext[Never, Content]) -> None: + async def usage_reporter(message: Any, ctx: WorkflowContext[Any, Content]) -> None: await ctx.yield_output( Content.from_usage({"input_token_count": 100, "output_token_count": 50, "total_token_count": 150}) ) @@ -694,7 +693,7 @@ async def start(self, message: Any, ctx: WorkflowContext) -> None: @response_handler async def handle_approval(self, original_request: Content, response: Content, ctx: WorkflowContext) -> None: status = "approved" if bool(response.approved) else "rejected" - await ctx.yield_output(f"Refund {status}.") + await ctx.yield_output(f"Refund {status}.") # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] workflow = WorkflowBuilder(start_executor=ApprovalExecutor()).build() wrapper = AgentFrameworkWorkflow(workflow=workflow) @@ -762,7 +761,7 @@ async def start(self, message: Any, ctx: WorkflowContext) -> None: @response_handler async def handle_input(self, original: dict, response: list[Message], ctx: WorkflowContext) -> None: user_text = response[0].text if response else "" - await ctx.yield_output(f"Got: {user_text}") + await ctx.yield_output(f"Got: {user_text}") # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] workflow = WorkflowBuilder(start_executor=MessageRequestExecutor()).build() wrapper = AgentFrameworkWorkflow(workflow=workflow) @@ -861,7 +860,7 @@ async def test_workflow_factory_thread_scoping() -> None: def make_workflow(thread_id: str): @executor(id="echo") - async def echo(message: Any, ctx: WorkflowContext[Never, str]) -> None: + async def echo(message: Any, ctx: WorkflowContext[Any, str]) -> None: await ctx.yield_output(f"Thread: {thread_id}") return WorkflowBuilder(start_executor=echo).build() @@ -914,7 +913,7 @@ async def start(self, message: str, ctx: WorkflowContext[str]) -> None: @response_handler async def handle_dest(self, original: str, response: str, ctx: WorkflowContext[str]) -> None: - await ctx.yield_output(f"Booking for {self._name} to {response}") + await ctx.yield_output(f"Booking for {self._name} to {response}") # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] name_requester = NameRequester() dest_requester = DestRequester() diff --git a/python/packages/ag-ui/tests/ag_ui/sse_helpers.py b/python/packages/ag-ui/tests/ag_ui/sse_helpers.py index 8a71dd9afba..acb6f46dd9a 100644 --- a/python/packages/ag-ui/tests/ag_ui/sse_helpers.py +++ b/python/packages/ag-ui/tests/ag_ui/sse_helpers.py @@ -7,7 +7,7 @@ import json from typing import Any -from event_stream import EventStream +from event_stream import EventStream # pyrefly: ignore[missing-import] # pyright: ignore[reportMissingImports] def parse_sse_response(response_content: bytes) -> list[dict[str, Any]]: diff --git a/python/packages/ag-ui/tests/ag_ui/test_ag_ui_client.py b/python/packages/ag-ui/tests/ag_ui/test_ag_ui_client.py index d473e5beccc..a04d13ab1c1 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_ag_ui_client.py +++ b/python/packages/ag-ui/tests/ag_ui/test_ag_ui_client.py @@ -37,15 +37,19 @@ def convert_messages_to_agui_format(self, messages: list[Message]) -> list[dict[ """Expose message conversion helper.""" return self._convert_messages_to_agui_format(messages) - def get_thread_id(self, options: dict[str, Any]) -> str: + def get_thread_id(self, options: ChatOptions[Any] | dict[str, Any] | None) -> str: """Expose thread id helper.""" - return self._get_thread_id(options) + return self._get_thread_id(options) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] def inner_get_response( - self, *, messages: MutableSequence[Message], options: dict[str, Any], stream: bool = False + self, + *, + messages: MutableSequence[Message], + options: ChatOptions[Any] | dict[str, Any] | None, + stream: bool = False, ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]: """Proxy to protected response call.""" - return self._inner_get_response(messages=messages, options=options, stream=stream) + return self._inner_get_response(messages=messages, options=options, stream=stream) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] class TestAGUIChatClient: @@ -177,7 +181,9 @@ async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str chat_options = ChatOptions() updates: list[ChatResponseUpdate] = [] - async for update in client._inner_get_response(messages=messages, stream=True, options=chat_options): + stream = client.inner_get_response(messages=messages, stream=True, options=chat_options) + assert isinstance(stream, ResponseStream) + async for update in stream: updates.append(update) assert len(updates) == 4 @@ -207,7 +213,7 @@ async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str monkeypatch.setattr(client.http_service, "post_run", mock_post_run) messages = [Message(role="user", contents=["Test message"])] - chat_options = {} + chat_options: dict[str, Any] = {} response = await client.inner_get_response(messages=messages, options=chat_options) @@ -418,7 +424,9 @@ async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str messages = [Message(role="user", contents=["Test"])] updates: list[ChatResponseUpdate] = [] - async for update in client._inner_get_response(messages=messages, stream=True, options={"tools": [my_tool]}): + stream = client.inner_get_response(messages=messages, stream=True, options={"tools": [my_tool]}) + assert isinstance(stream, ResponseStream) + async for update in stream: updates.append(update) # Find the function_call content - it should have agui_thread_id diff --git a/python/packages/ag-ui/tests/ag_ui/test_agent_wrapper_comprehensive.py b/python/packages/ag-ui/tests/ag_ui/test_agent_wrapper_comprehensive.py index b4b8fa04a70..f6272ab86e8 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_agent_wrapper_comprehensive.py +++ b/python/packages/ag-ui/tests/ag_ui/test_agent_wrapper_comprehensive.py @@ -695,7 +695,7 @@ async def stream_fn( events: list[Any] = [] async for event in wrapper.run(input_data): events.append(event) - assert request_service_session_id is None # type: ignore[attr-defined] (service_session_id should be set) + assert request_service_session_id is None # type: ignore[attr-defined] # service_session_id should be set async def test_agent_with_use_service_session_is_true(streaming_chat_client_stub): @@ -724,7 +724,7 @@ def capturing_run(*args: Any, **kwargs: Any) -> Any: captured_service_session_id = session.service_session_id if session else None return original_run(*args, **kwargs) - agent.run = capturing_run # type: ignore[assignment, method-assign] + agent.run = capturing_run # type: ignore[assignment, method-assign] # ty: ignore[invalid-assignment] events: list[Any] = [] async for event in wrapper.run(input_data): diff --git a/python/packages/ag-ui/tests/ag_ui/test_approval_result_event.py b/python/packages/ag-ui/tests/ag_ui/test_approval_result_event.py index b83eec10d24..ab9a10d8498 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_approval_result_event.py +++ b/python/packages/ag-ui/tests/ag_ui/test_approval_result_event.py @@ -8,7 +8,7 @@ from typing import Any from agent_framework import AgentResponseUpdate, Content, FunctionTool -from conftest import StubAgent +from conftest import StubAgent # pyrefly: ignore[missing-import] # pyright: ignore[reportMissingImports] from agent_framework_ag_ui._agent import AgentConfig from agent_framework_ag_ui._agent_run import run_agent_stream diff --git a/python/packages/ag-ui/tests/ag_ui/test_endpoint.py b/python/packages/ag-ui/tests/ag_ui/test_endpoint.py index 51ab468b84c..1d1a63b53a3 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_endpoint.py +++ b/python/packages/ag-ui/tests/ag_ui/test_endpoint.py @@ -70,7 +70,7 @@ async def test_add_endpoint_with_workflow_protocol(): @executor(id="start") async def start(message: Any, ctx: WorkflowContext) -> None: - await ctx.yield_output("Workflow response") + await ctx.yield_output("Workflow response") # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] app = FastAPI() workflow = WorkflowBuilder(start_executor=start).build() @@ -209,7 +209,7 @@ async def test_endpoint_with_workflow_as_agent_stream_output(build_chat_client): reviewer_agent = Agent(name="reviewer", instructions="Review ideas", client=build_chat_client("Review")) agent = SequentialBuilder(participants=[brainstorm_agent, reviewer_agent]).build().as_agent() - add_agent_framework_fastapi_endpoint(app, agent, path="/workflow-like") + add_agent_framework_fastapi_endpoint(app, agent, path="/workflow-like") # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] client = TestClient(app) response = client.post("/workflow-like", json={"messages": [{"role": "user", "content": "Hello"}]}) @@ -557,7 +557,7 @@ async def test_endpoint_invalid_agent_type_raises_typeerror(): app = FastAPI() with pytest.raises(TypeError, match="must be SupportsAgentRun"): - add_agent_framework_fastapi_endpoint(app, agent="not_an_agent") # type: ignore[arg-type] + add_agent_framework_fastapi_endpoint(app, agent="not_an_agent") # type: ignore[arg-type] # ty: ignore[invalid-argument-type] async def test_endpoint_encoding_failure_emits_run_error(): diff --git a/python/packages/ag-ui/tests/ag_ui/test_event_converters.py b/python/packages/ag-ui/tests/ag_ui/test_event_converters.py index 70bd4a0f04b..4b207c721cc 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_event_converters.py +++ b/python/packages/ag-ui/tests/ag_ui/test_event_converters.py @@ -2,6 +2,8 @@ """Tests for AG-UI event converter.""" +from typing import Any, cast + from agent_framework_ag_ui._event_converters import AGUIEventConverter @@ -20,6 +22,8 @@ def test_run_started_event(self) -> None: update = converter.convert_event(event) assert update is not None + assert update.additional_properties is not None + assert update.additional_properties is not None assert update.role == "assistant" assert update.additional_properties["thread_id"] == "thread_123" assert update.additional_properties["run_id"] == "run_456" @@ -67,7 +71,7 @@ def test_text_message_streaming(self) -> None: {"type": "TEXT_MESSAGE_CONTENT", "messageId": "msg_1", "delta": "!"}, ] - updates = [converter.convert_event(event) for event in events] + updates = cast(list[Any], [converter.convert_event(event) for event in events]) assert all(update is not None for update in updates) assert all(update.message_id == "msg_1" for update in updates) @@ -148,7 +152,7 @@ def test_tool_call_args_streaming(self) -> None: {"type": "TOOL_CALL_ARGS", "delta": 'latest news"}'}, ] - updates = [converter.convert_event(event) for event in events] + updates = cast(list[Any], [converter.convert_event(event) for event in events]) assert all(update is not None for update in updates) assert updates[0].contents[0].arguments == '{"query": "' @@ -204,8 +208,8 @@ def test_run_finished_event(self) -> None: assert update is not None assert update.role == "assistant" assert update.finish_reason == "stop" - assert update.additional_properties["thread_id"] == "thread_123" - assert update.additional_properties["run_id"] == "run_456" + assert update.additional_properties["thread_id"] == "thread_123" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert update.additional_properties["run_id"] == "run_456" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] def test_run_finished_event_with_interrupt(self) -> None: """RUN_FINISHED interrupt metadata is preserved in additional_properties.""" @@ -224,6 +228,7 @@ def test_run_finished_event_with_interrupt(self) -> None: update = converter.convert_event(event) assert update is not None + assert update.additional_properties is not None assert update.additional_properties["interrupt"] == [{"id": "req_1", "value": {"question": "Continue?"}}] assert update.additional_properties["result"] == {"status": "paused"} @@ -271,6 +276,7 @@ def test_custom_event_conversion(self) -> None: update = converter.convert_event(event) assert update is not None + assert update.additional_properties is not None assert update.additional_properties["ag_ui_custom_event"]["name"] == "progress" assert update.additional_properties["ag_ui_custom_event"]["value"] == {"percent": 10} assert update.additional_properties["ag_ui_custom_event"]["raw_type"] == "CUSTOM" @@ -283,7 +289,7 @@ def test_custom_event_alias_conversion(self) -> None: {"type": "custom_event", "name": "alias_lower", "value": {"v": 2}}, ] - updates = [converter.convert_event(event) for event in events] + updates = cast(list[Any], [converter.convert_event(event) for event in events]) assert updates[0] is not None assert updates[1] is not None @@ -310,7 +316,7 @@ def test_full_conversation_flow(self) -> None: {"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"}, ] - updates = [converter.convert_event(event) for event in events] + updates = cast(list[Any], [converter.convert_event(event) for event in events]) non_none_updates = [u for u in updates if u is not None] assert len(non_none_updates) == 10 @@ -330,7 +336,7 @@ def test_multiple_tool_calls(self) -> None: {"type": "TOOL_CALL_END", "toolCallId": "call_2"}, ] - updates = [converter.convert_event(event) for event in events] + updates = cast(list[Any], [converter.convert_event(event) for event in events]) non_none_updates = [u for u in updates if u is not None] assert len(non_none_updates) == 4 diff --git a/python/packages/ag-ui/tests/ag_ui/test_http_round_trip.py b/python/packages/ag-ui/tests/ag_ui/test_http_round_trip.py index 5a86a6ff59e..115bb46b5d6 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_http_round_trip.py +++ b/python/packages/ag-ui/tests/ag_ui/test_http_round_trip.py @@ -11,11 +11,13 @@ from typing import Any from agent_framework import AgentResponseUpdate, Content, WorkflowBuilder, WorkflowContext, executor -from conftest import StubAgent +from conftest import StubAgent # pyrefly: ignore[missing-import] # pyright: ignore[reportMissingImports] from fastapi import FastAPI from fastapi.testclient import TestClient -from sse_helpers import parse_sse_response, parse_sse_to_event_stream -from typing_extensions import Never +from sse_helpers import ( # pyrefly: ignore[missing-import] # pyright: ignore[reportMissingImports] + parse_sse_response, + parse_sse_to_event_stream, +) from agent_framework_ag_ui import AgentFrameworkAgent, AgentFrameworkWorkflow, add_agent_framework_fastapi_endpoint @@ -168,7 +170,7 @@ def test_workflow_sse_round_trip() -> None: """Workflow events survive SSE encoding/parsing.""" @executor(id="greeter") - async def greeter(message: Any, ctx: WorkflowContext[Never, str]) -> None: + async def greeter(message: Any, ctx: WorkflowContext[Any, str]) -> None: await ctx.yield_output("Hello from workflow!") app = _build_app_with_workflow(WorkflowBuilder(start_executor=greeter)) diff --git a/python/packages/ag-ui/tests/ag_ui/test_message_adapters.py b/python/packages/ag-ui/tests/ag_ui/test_message_adapters.py index 69f7b7bdb35..75d0e1b4da2 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_message_adapters.py +++ b/python/packages/ag-ui/tests/ag_ui/test_message_adapters.py @@ -5,6 +5,7 @@ import base64 import json import logging +from typing import Any import pytest from agent_framework import Content, Message @@ -108,7 +109,7 @@ def test_agui_tool_approval_updates_tool_call_arguments(): The raw messages (for MESSAGES_SNAPSHOT) should contain all steps with status, so the UI can show which steps were enabled/disabled. """ - messages_input = [ + messages_input: list[dict[str, Any]] = [ { "role": "assistant", "content": "", @@ -171,6 +172,7 @@ def test_agui_tool_approval_updates_tool_call_arguments(): approval_content = next( content for content in approval_msg.contents if content.type == "function_approval_response" ) + assert approval_content.function_call is not None assert approval_content.function_call.parse_arguments() == { "steps": [ {"description": "Boil water", "status": "enabled"}, @@ -189,7 +191,7 @@ def test_agui_tool_approval_updates_tool_call_arguments(): def test_agui_tool_approval_from_confirm_changes_maps_to_function_call(): """Confirm_changes approvals map back to the original tool call when metadata is present.""" - messages_input = [ + messages_input: list[dict[str, Any]] = [ { "role": "assistant", "content": "", @@ -224,15 +226,18 @@ def test_agui_tool_approval_from_confirm_changes_maps_to_function_call(): content for content in approval_msg.contents if content.type == "function_approval_response" ) + assert approval_content.function_call is not None assert approval_content.function_call.call_id == "call_tool" + assert approval_content.function_call is not None assert approval_content.function_call.name == "get_datetime" + assert approval_content.function_call is not None assert approval_content.function_call.parse_arguments() == {} assert messages_input[0]["tool_calls"][0]["function"]["arguments"] == {} def test_agui_tool_approval_from_confirm_changes_falls_back_to_sibling_call(): """Confirm_changes approvals map to the only sibling tool call when metadata is missing.""" - messages_input = [ + messages_input: list[dict[str, Any]] = [ { "role": "assistant", "content": "", @@ -269,15 +274,18 @@ def test_agui_tool_approval_from_confirm_changes_falls_back_to_sibling_call(): content for content in approval_msg.contents if content.type == "function_approval_response" ) + assert approval_content.function_call is not None assert approval_content.function_call.call_id == "call_tool" + assert approval_content.function_call is not None assert approval_content.function_call.name == "get_datetime" + assert approval_content.function_call is not None assert approval_content.function_call.parse_arguments() == {} assert messages_input[0]["tool_calls"][0]["function"]["arguments"] == {} def test_agui_tool_approval_from_generate_task_steps_maps_to_function_call(): """Approval tool payloads map to the referenced function call when function_call_id is present.""" - messages_input = [ + messages_input: list[dict[str, Any]] = [ { "role": "assistant", "content": "", @@ -322,14 +330,17 @@ def test_agui_tool_approval_from_generate_task_steps_maps_to_function_call(): content for content in approval_msg.contents if content.type == "function_approval_response" ) + assert approval_content.function_call is not None assert approval_content.function_call.call_id == "call_tool" + assert approval_content.function_call is not None assert approval_content.function_call.name == "get_datetime" + assert approval_content.function_call is not None assert approval_content.function_call.parse_arguments() == {} def test_agui_multiple_messages_to_agent_framework(): """Test converting multiple AG-UI messages.""" - messages_input = [ + messages_input: list[dict[str, Any]] = [ {"role": "user", "content": "First message", "id": "msg-1"}, {"role": "assistant", "content": "Second message", "id": "msg-2"}, {"role": "user", "content": "Third message", "id": "msg-3"}, @@ -382,7 +393,9 @@ def test_agui_function_approvals(): assert msg.contents[0].type == "function_approval_response" assert msg.contents[0].approved is True assert msg.contents[0].id == "approval-1" + assert msg.contents[0].function_call is not None assert msg.contents[0].function_call.name == "search" + assert msg.contents[0].function_call is not None assert msg.contents[0].function_call.call_id == "call-1" assert msg.contents[1].type == "function_approval_response" @@ -405,7 +418,7 @@ def test_agui_non_string_content(): assert len(messages) == 1 assert len(messages[0].contents) == 1 assert messages[0].contents[0].type == "text" - assert "nested" in messages[0].contents[0].text + assert "nested" in messages[0].contents[0].text # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] def test_agui_multimodal_legacy_binary_to_agent_framework(): @@ -907,7 +920,7 @@ def test_agui_to_agent_framework_tool_result(): }, ] - result = agui_messages_to_agent_framework(messages) + result = agui_messages_to_agent_framework(messages) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] assert len(result) == 2 # Second message should be tool result @@ -1305,7 +1318,7 @@ def test_convert_agui_content_unknown_part_type_without_text(): result = _convert_agui_content_to_framework([{"type": "widget", "data": 42}]) assert len(result) == 1 - assert "widget" in result[0].text + assert "widget" in result[0].text # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] def test_convert_agui_content_none(): @@ -1706,7 +1719,7 @@ def test_agui_fresh_approval_is_still_processed(): On Turn 2, the approval is fresh (no subsequent assistant message), so it must be processed normally to execute the tool. """ - messages_input = [ + messages_input: list[dict[str, Any]] = [ # Turn 1: user asks something {"role": "user", "content": "What time is it?", "id": "msg_1"}, # Turn 1: assistant calls a tool @@ -1739,6 +1752,7 @@ def test_agui_fresh_approval_is_still_processed(): ] assert len(approval_contents) == 1, "Fresh approval should produce function_approval_response" assert approval_contents[0].approved is True + assert approval_contents[0].function_call is not None assert approval_contents[0].function_call.name == "get_datetime" @@ -1747,7 +1761,7 @@ class TestReasoningRoundTrip: def test_reasoning_skipped_on_inbound(self): """Reasoning messages from prior snapshot are not forwarded to the LLM.""" - messages_input = [ + messages_input: list[dict[str, Any]] = [ {"id": "u1", "role": "user", "content": "Hello"}, {"id": "r1", "role": "reasoning", "content": "Thinking..."}, {"id": "a1", "role": "assistant", "content": "Hi there"}, @@ -1761,7 +1775,7 @@ def test_reasoning_skipped_on_inbound(self): def test_reasoning_preserved_in_snapshot_format(self): """Reasoning messages retain their role through snapshot normalization.""" - messages_input = [ + messages_input: list[dict[str, Any]] = [ {"id": "u1", "role": "user", "content": "Hello"}, {"id": "r1", "role": "reasoning", "content": "Thinking about this..."}, {"id": "a1", "role": "assistant", "content": "Answer"}, @@ -1775,7 +1789,7 @@ def test_reasoning_preserved_in_snapshot_format(self): def test_reasoning_with_encrypted_value_in_snapshot_format(self): """Reasoning with encryptedValue passes through snapshot normalization.""" - messages_input = [ + messages_input: list[dict[str, Any]] = [ { "id": "r1", "role": "reasoning", @@ -1792,7 +1806,7 @@ def test_reasoning_with_encrypted_value_in_snapshot_format(self): def test_reasoning_encrypted_value_snake_case_normalized(self): """Snake-case encrypted_value is normalized to encryptedValue in snapshot format.""" - messages_input = [ + messages_input: list[dict[str, Any]] = [ { "id": "r1", "role": "reasoning", @@ -1809,7 +1823,7 @@ def test_reasoning_encrypted_value_snake_case_normalized(self): def test_multi_turn_with_reasoning_in_prior_snapshot(self): """Second turn with reasoning from prior snapshot does not corrupt messages.""" - messages_input = [ + messages_input: list[dict[str, Any]] = [ {"id": "u1", "role": "user", "content": "First question"}, {"id": "r1", "role": "reasoning", "content": "Prior reasoning"}, {"id": "a1", "role": "assistant", "content": "First answer"}, @@ -1841,7 +1855,7 @@ def test_parse_multimodal_media_part_base64_value_field(): {"type": "image", "source": {"type": "base64", "value": "aGVsbG8=", "mimeType": "image/png"}} ) assert result is not None - assert "aGVsbG8=" in result.uri + assert "aGVsbG8=" in result.uri # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] def test_parse_multimodal_media_part_data_source_value_field(): @@ -1852,7 +1866,7 @@ def test_parse_multimodal_media_part_data_source_value_field(): {"type": "image", "source": {"type": "data", "value": "aGVsbG8=", "mimeType": "image/png"}} ) assert result is not None - assert "aGVsbG8=" in result.uri + assert "aGVsbG8=" in result.uri # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] def test_parse_multimodal_media_part_base64_data_field_backward_compat(): @@ -1863,7 +1877,7 @@ def test_parse_multimodal_media_part_base64_data_field_backward_compat(): {"type": "image", "source": {"type": "base64", "data": "aGVsbG8=", "mimeType": "image/png"}} ) assert result is not None - assert "aGVsbG8=" in result.uri + assert "aGVsbG8=" in result.uri # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] def test_parse_multimodal_media_part_value_preferred_over_data(): @@ -1883,7 +1897,7 @@ def test_parse_multimodal_media_part_value_preferred_over_data(): ) assert result is not None # 'value' field content should be used (base64 of "value") - assert "dmFsdWU=" in result.uri + assert "dmFsdWU=" in result.uri # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] def test_parse_multimodal_media_part_unknown_source_value_fallback(): @@ -1894,4 +1908,4 @@ def test_parse_multimodal_media_part_unknown_source_value_fallback(): {"type": "image", "source": {"type": "custom", "value": "aGVsbG8=", "mimeType": "image/png"}} ) assert result is not None - assert "aGVsbG8=" in result.uri + assert "aGVsbG8=" in result.uri # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] diff --git a/python/packages/ag-ui/tests/ag_ui/test_message_hygiene.py b/python/packages/ag-ui/tests/ag_ui/test_message_hygiene.py index a3ccf26d1ac..fba1b91fb09 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_message_hygiene.py +++ b/python/packages/ag-ui/tests/ag_ui/test_message_hygiene.py @@ -1,5 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. +from typing import Any + from agent_framework import Content, Message from agent_framework_ag_ui._message_adapters import _deduplicate_messages, _sanitize_tool_history @@ -276,7 +278,7 @@ def test_clean_resolved_approvals_from_snapshot() -> None: from agent_framework_ag_ui._agent_run import _clean_resolved_approvals_from_snapshot # Snapshot still has the approval payload - snapshot_messages = [ + snapshot_messages: list[dict[str, Any]] = [ # type: ignore[name-defined] {"role": "user", "content": "What time is it?", "id": "msg_1"}, { "role": "assistant", @@ -318,7 +320,7 @@ def test_clean_resolved_approvals_from_snapshot_no_approvals() -> None: """When there are no approval payloads, snapshot should be unchanged.""" from agent_framework_ag_ui._agent_run import _clean_resolved_approvals_from_snapshot # type: ignore - snapshot_messages = [ + snapshot_messages: list[dict[str, Any]] = [ # type: ignore[name-defined] {"role": "user", "content": "Hello", "id": "msg_1"}, {"role": "assistant", "content": "Hi there", "id": "msg_2"}, ] @@ -349,7 +351,7 @@ def test_cleaned_snapshot_prevents_approval_reprocessing() -> None: from agent_framework_ag_ui._message_adapters import normalize_agui_input_messages # Turn 2 snapshot: still has the raw approval payload - snapshot_messages = [ + snapshot_messages: list[dict[str, Any]] = [ # type: ignore[name-defined] {"role": "user", "content": "What time is it?", "id": "msg_1"}, { "role": "assistant", @@ -387,7 +389,7 @@ def test_cleaned_snapshot_prevents_approval_reprocessing() -> None: assert snapshot_messages[2]["content"] == "2024-01-01 12:00:00" # Simulate Turn 3: CopilotKit re-sends the cleaned snapshot + new messages - turn3_messages = list(snapshot_messages) + [ + turn3_messages: list[dict[str, Any]] = list(snapshot_messages) + [ # type: ignore[name-defined] {"role": "assistant", "content": "It is 12:00 PM.", "id": "msg_4"}, {"role": "user", "content": "Thanks!", "id": "msg_5"}, ] diff --git a/python/packages/ag-ui/tests/ag_ui/test_multi_turn.py b/python/packages/ag-ui/tests/ag_ui/test_multi_turn.py index 714ce2ce500..c12c59a7e57 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_multi_turn.py +++ b/python/packages/ag-ui/tests/ag_ui/test_multi_turn.py @@ -13,10 +13,13 @@ from typing import Any from agent_framework import AgentResponseUpdate, Content -from conftest import StubAgent +from conftest import StubAgent # pyrefly: ignore[missing-import] # pyright: ignore[reportMissingImports] from fastapi import FastAPI from fastapi.testclient import TestClient -from sse_helpers import parse_sse_response, parse_sse_to_event_stream +from sse_helpers import ( # pyrefly: ignore[missing-import] # pyright: ignore[reportMissingImports] + parse_sse_response, + parse_sse_to_event_stream, +) from agent_framework_ag_ui import AgentFrameworkAgent, add_agent_framework_fastapi_endpoint @@ -152,7 +155,7 @@ async def test_approval_interrupt_resume_round_trip() -> None: The confirm_changes flow uses a specific message format that bypasses the agent and directly emits a confirmation text message. """ - from event_stream import EventStream + from event_stream import EventStream # pyrefly: ignore[missing-import] # pyright: ignore[reportMissingImports] steps = [{"description": "Execute task", "status": "enabled"}] @@ -268,7 +271,7 @@ async def test_approval_interrupt_resume_round_trip() -> None: async def test_workflow_interrupt_resume_round_trip() -> None: """Turn 1: workflow request_info → interrupt. Turn 2: resume → completion.""" - from event_stream import EventStream + from event_stream import EventStream # pyrefly: ignore[missing-import] # pyright: ignore[reportMissingImports] from agent_framework_ag_ui_examples.agents.subgraphs_agent import subgraphs_agent diff --git a/python/packages/ag-ui/tests/ag_ui/test_run.py b/python/packages/ag-ui/tests/ag_ui/test_run.py index a901b61e6e9..11fad20be35 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_run.py +++ b/python/packages/ag-ui/tests/ag_ui/test_run.py @@ -2,6 +2,8 @@ """Tests for _agent_run.py helper functions and FlowState.""" +from typing import cast + import pytest from ag_ui.core import ( CustomEvent, @@ -44,6 +46,12 @@ ) +def _message_role(message: object) -> object: + if isinstance(message, dict): + return cast(dict[str, object], message).get("role") + return getattr(message, "role", None) + + class TestBuildSafeMetadata: """Tests for _build_safe_metadata function.""" @@ -276,8 +284,8 @@ def test_creates_message(self): assert result is not None assert result.role == "system" assert len(result.contents) == 1 - assert "Hello world" in result.contents[0].text - assert "Current state" in result.contents[0].text + assert "Hello world" in result.contents[0].text # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] + assert "Current state" in result.contents[0].text # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] class TestInjectStateContext: @@ -321,13 +329,13 @@ def test_injects_before_last_user_message(self): assert len(result) == 3 # System message first assert result[0].role == "system" - assert "helpful" in result[0].contents[0].text + assert "helpful" in result[0].contents[0].text # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] # State context second assert result[1].role == "system" - assert "Current state" in result[1].contents[0].text + assert "Current state" in result[1].contents[0].text # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] # User message last assert result[2].role == "user" - assert "Hello" in result[2].contents[0].text + assert "Hello" in result[2].contents[0].text # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] # Additional tests for _agent_run.py functions @@ -517,9 +525,9 @@ def test_emit_tool_result_serializes_non_string_result(): events = _emit_tool_result(content, flow, predictive_handler=None) result_event = next(event for event in events if getattr(event, "type", None) == "TOOL_CALL_RESULT") - assert isinstance(result_event.content, str) - assert '"ok": true' in result_event.content - assert flow.tool_results[0]["content"] == result_event.content + assert isinstance(result_event.content, str) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert '"ok": true' in result_event.content # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert flow.tool_results[0]["content"] == result_event.content # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] def test_emit_content_usage_emits_custom_usage_event(): @@ -531,8 +539,8 @@ def test_emit_content_usage_emits_custom_usage_event(): assert len(events) == 1 assert events[0].type == "CUSTOM" - assert events[0].name == "usage" - assert events[0].value["total_token_count"] == 5 + assert events[0].name == "usage" # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert events[0].value["total_token_count"] == 5 # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] def test_emit_approval_request_populates_interrupt_metadata(): @@ -772,6 +780,7 @@ def test_malformed_json_in_confirm_args_skips_confirmation(): valid_arguments = '{"content": "hello"}' tool_call_valid = {"function": {"name": "write_doc", "arguments": valid_arguments}} should_skip_confirmation = False + function_arguments: dict[str, object] | None = None try: function_arguments = json.loads(tool_call_valid.get("function", {}).get("arguments", "{}")) except json.JSONDecodeError: @@ -915,7 +924,7 @@ async def test_run_agent_stream_accumulates_multiple_confirm_interrupts(): """ import json - from conftest import StubAgent + from conftest import StubAgent # pyrefly: ignore[missing-import] # pyright: ignore[reportMissingImports] from agent_framework_ag_ui import AgentFrameworkAgent @@ -1026,11 +1035,11 @@ def test_produces_start_and_args_events(self): assert len(events) == 2 assert events[0].type == "TOOL_CALL_START" - assert events[0].tool_call_id == "mcp_call_1" - assert events[0].tool_call_name == "search" + assert events[0].tool_call_id == "mcp_call_1" # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert events[0].tool_call_name == "search" # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert events[1].type == "TOOL_CALL_ARGS" - assert events[1].tool_call_id == "mcp_call_1" - assert "weather" in events[1].delta + assert events[1].tool_call_id == "mcp_call_1" # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert "weather" in events[1].delta # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] def test_tracks_in_flow_state(self): """MCP tool call is tracked in flow.pending_tool_calls and tool_calls_by_id.""" @@ -1059,7 +1068,7 @@ def test_no_server_name_uses_tool_name_only(self): events = _emit_mcp_tool_call(content, flow) - assert events[0].tool_call_name == "list_files" + assert events[0].tool_call_name == "list_files" # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] def test_no_arguments_skips_args_event(self): """No arguments produces only ToolCallStart, no ToolCallArgs.""" @@ -1082,9 +1091,9 @@ def test_generates_id_when_missing(self): events = _emit_mcp_tool_call(content, flow) assert len(events) >= 1 - assert events[0].tool_call_id is not None - assert events[0].tool_call_id != "" - assert events[0].tool_call_name == "test_tool" + assert events[0].tool_call_id is not None # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert events[0].tool_call_id != "" # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert events[0].tool_call_name == "test_tool" # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] def test_missing_tool_name_falls_back_to_mcp_tool(self): """When tool_name is None, the fallback 'mcp_tool' is used.""" @@ -1094,7 +1103,7 @@ def test_missing_tool_name_falls_back_to_mcp_tool(self): events = _emit_mcp_tool_call(content, flow) assert len(events) >= 1 - assert events[0].tool_call_name == "mcp_tool" + assert events[0].tool_call_name == "mcp_tool" # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] class TestEmitMcpToolResult: @@ -1112,10 +1121,10 @@ def test_produces_end_and_result_events(self): assert len(events) == 2 assert events[0].type == "TOOL_CALL_END" - assert events[0].tool_call_id == "mcp_call_1" + assert events[0].tool_call_id == "mcp_call_1" # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert events[1].type == "TOOL_CALL_RESULT" - assert events[1].tool_call_id == "mcp_call_1" - assert "Weather" in events[1].content + assert events[1].tool_call_id == "mcp_call_1" # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert "Weather" in events[1].content # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] def test_tracks_in_flow_state(self): """MCP tool result is tracked in flow.tool_results and tool_calls_ended.""" @@ -1152,8 +1161,8 @@ def test_serializes_non_string_output(self): events = _emit_mcp_tool_result(content, flow) result_event = events[1] - assert isinstance(result_event.content, str) - assert '"key": "value"' in result_event.content + assert isinstance(result_event.content, str) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert '"key": "value"' in result_event.content # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] def test_output_none_falls_back_to_empty_string(self): """When output is None (default), the result content is an empty string.""" @@ -1164,7 +1173,7 @@ def test_output_none_falls_back_to_empty_string(self): assert len(events) == 2 assert events[1].type == "TOOL_CALL_RESULT" - assert events[1].content == "" + assert events[1].content == "" # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] def test_resets_flow_state_like_emit_tool_result(self): """MCP tool result performs same FlowState cleanup as _emit_tool_result.""" @@ -1301,10 +1310,10 @@ def test_generates_message_id_when_missing(self): events = _emit_text_reasoning(content) assert len(events) == 5 - assert events[0].message_id is not None - assert events[0].message_id != "" + assert events[0].message_id is not None # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert events[0].message_id != "" # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] # All events share the same message_id - assert events[1].message_id == events[0].message_id + assert events[1].message_id == events[0].message_id # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] class TestEmitContentMcpRouting: @@ -1323,7 +1332,7 @@ def test_routes_mcp_server_tool_call(self): assert len(events) >= 1 assert events[0].type == "TOOL_CALL_START" - assert events[0].tool_call_name == "test_tool" + assert events[0].tool_call_name == "test_tool" # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] def test_routes_mcp_server_tool_result(self): """_emit_content dispatches mcp_server_tool_result to _emit_mcp_tool_result.""" @@ -1398,7 +1407,7 @@ def test_snapshot_includes_reasoning(self): snapshot = _build_messages_snapshot(flow, []) - roles = [m.get("role") if isinstance(m, dict) else getattr(m, "role", None) for m in snapshot.messages] + roles = [_message_role(m) for m in snapshot.messages] assert "reasoning" in roles def test_snapshot_preserves_reasoning_encrypted_value(self): @@ -1418,11 +1427,7 @@ def test_snapshot_preserves_reasoning_encrypted_value(self): snapshot = _build_messages_snapshot(flow, []) - reasoning_msgs = [ - m - for m in snapshot.messages - if (m.get("role") if isinstance(m, dict) else getattr(m, "role", None)) == "reasoning" - ] + reasoning_msgs = [m for m in snapshot.messages if _message_role(m) == "reasoning"] assert len(reasoning_msgs) == 1 msg = reasoning_msgs[0] if isinstance(msg, dict): @@ -1463,7 +1468,7 @@ def test_snapshot_reasoning_ordering(self): # user -> assistant text -> reasoning assert len(snapshot.messages) == 3 - roles = [m.get("role") if isinstance(m, dict) else getattr(m, "role", None) for m in snapshot.messages] + roles = [_message_role(m) for m in snapshot.messages] assert roles == ["user", "assistant", "reasoning"] def test_reasoning_accumulates_incremental_deltas(self): @@ -1629,7 +1634,7 @@ def test_reasoning_distinct_ids_close_previous_block(self): close = _close_reasoning_block(flow) # events1: Start(block1) + MsgStart(block1) + Content(block1) - assert events1[0].message_id == "block1" + assert events1[0].message_id == "block1" # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] # events2: MsgEnd(block1) + End(block1) + Start(block2) + MsgStart(block2) + Content(block2) assert isinstance(events2[0], ReasoningMessageEndEvent) assert events2[0].message_id == "block1" @@ -1675,7 +1680,7 @@ def test_reasoning_role_with_flow(self): async def test_session_id_matches_thread_id(): """Session created by run_agent_stream uses the client thread_id as session_id.""" - from conftest import StubAgent + from conftest import StubAgent # pyrefly: ignore[missing-import] # pyright: ignore[reportMissingImports] from agent_framework_ag_ui import AgentFrameworkAgent @@ -1696,7 +1701,7 @@ async def test_session_id_matches_thread_id(): async def test_session_id_matches_camel_case_thread_id(): """Session uses threadId (camelCase) as session_id when snake_case is absent.""" - from conftest import StubAgent + from conftest import StubAgent # pyrefly: ignore[missing-import] # pyright: ignore[reportMissingImports] from agent_framework_ag_ui import AgentFrameworkAgent @@ -1717,7 +1722,7 @@ async def test_session_id_matches_camel_case_thread_id(): async def test_session_id_matches_thread_id_with_service_session(): """Session uses thread_id as session_id even when use_service_session is enabled.""" - from conftest import StubAgent + from conftest import StubAgent # pyrefly: ignore[missing-import] # pyright: ignore[reportMissingImports] from agent_framework_ag_ui import AgentFrameworkAgent @@ -1741,7 +1746,7 @@ async def test_session_id_generated_when_no_thread_id(): """Session gets a generated UUID as session_id when no thread_id is provided.""" import uuid - from conftest import StubAgent + from conftest import StubAgent # pyrefly: ignore[missing-import] # pyright: ignore[reportMissingImports] from agent_framework_ag_ui import AgentFrameworkAgent @@ -1764,7 +1769,7 @@ async def test_service_session_no_thread_id_generates_uuid(): """With use_service_session=True and no thread_id, session_id is a UUID and service_session_id is None.""" import uuid - from conftest import StubAgent + from conftest import StubAgent # pyrefly: ignore[missing-import] # pyright: ignore[reportMissingImports] from agent_framework_ag_ui import AgentFrameworkAgent diff --git a/python/packages/ag-ui/tests/ag_ui/test_run_common.py b/python/packages/ag-ui/tests/ag_ui/test_run_common.py index 16ee1520332..7571f87abb1 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_run_common.py +++ b/python/packages/ag-ui/tests/ag_ui/test_run_common.py @@ -2,6 +2,7 @@ """Tests for _run_common.py edge cases.""" +import pytest from ag_ui.core import EventType from agent_framework import Content @@ -109,7 +110,7 @@ class TestEmitToolResult: def test_tool_result_without_call_id_returns_empty(self): """Tool result Content without call_id returns empty event list.""" - content = Content.from_function_result(call_id=None, result="some result") + content = Content.from_function_result(call_id=None, result="some result") # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] flow = FlowState() events = _emit_tool_result(content, flow) assert events == [] @@ -157,10 +158,9 @@ def test_empty_text_is_allowed(self): def test_non_mapping_state_raises(self): """Passing a non-mapping value for state raises TypeError.""" - import pytest with pytest.raises(TypeError): - state_update(text="t", state=["not", "a", "mapping"]) # type: ignore[arg-type] + state_update(text="t", state=["not", "a", "mapping"]) # type: ignore[arg-type] # ty: ignore[invalid-argument-type] def test_state_is_copied_defensively(self): """Mutating the caller's dict after ``state_update`` must not mutate the content.""" @@ -245,7 +245,7 @@ def test_emits_state_snapshot_after_tool_call_result(self): assert event_types[1] == EventType.TOOL_CALL_RESULT state_idx = event_types.index(EventType.STATE_SNAPSHOT) assert state_idx == 2 - assert events[state_idx].snapshot == {"weather": {"temp": 14, "conditions": "foggy"}} + assert events[state_idx].snapshot == {"weather": {"temp": 14, "conditions": "foggy"}} # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] def test_updates_flow_current_state(self): tool_return = state_update(text="", state={"a": 1}) @@ -283,8 +283,8 @@ def test_tool_result_content_text_unchanged(self): events = _emit_tool_result(content, flow) result_events = [e for e in events if e.type == EventType.TOOL_CALL_RESULT] assert len(result_events) == 1 - assert result_events[0].content == "Weather: 14°C" - assert TOOL_RESULT_STATE_KEY not in result_events[0].content + assert result_events[0].content == "Weather: 14°C" # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert TOOL_RESULT_STATE_KEY not in result_events[0].content # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] def test_display_payload_routes_to_ui_only(self): """A display marker overrides only the UI event, not the LLM-bound tool result.""" @@ -299,9 +299,9 @@ def test_display_payload_routes_to_ui_only(self): result_events = [e for e in events if e.type == EventType.TOOL_CALL_RESULT] assert len(result_events) == 1 - assert result_events[0].content == '{"temp": 14, "conditions": "foggy"}' + assert result_events[0].content == '{"temp": 14, "conditions": "foggy"}' # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert flow.tool_results[-1]["content"] == "Weather: 14°C" - assert TOOL_RESULT_DISPLAY_KEY not in result_events[0].content + assert TOOL_RESULT_DISPLAY_KEY not in result_events[0].content # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert TOOL_RESULT_DISPLAY_KEY not in flow.tool_results[-1]["content"] def test_plain_tool_result_uses_existing_content_for_both_channels(self): @@ -313,7 +313,7 @@ def test_plain_tool_result_uses_existing_content_for_both_channels(self): result_events = [e for e in events if e.type == EventType.TOOL_CALL_RESULT] assert len(result_events) == 1 - assert result_events[0].content == "plain result" + assert result_events[0].content == "plain result" # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert flow.tool_results[-1]["content"] == "plain result" def test_display_only_payload_falls_back_to_llm_content(self): @@ -325,7 +325,7 @@ def test_display_only_payload_falls_back_to_llm_content(self): events = _emit_tool_result(content, flow) result_events = [e for e in events if e.type == EventType.TOOL_CALL_RESULT] - assert result_events[0].content == '{"temp": 14}' + assert result_events[0].content == '{"temp": 14}' # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert flow.tool_results[-1]["content"] == '{"temp": 14}' def test_pre_serialized_display_string_routes_verbatim(self): @@ -337,7 +337,7 @@ def test_pre_serialized_display_string_routes_verbatim(self): events = _emit_tool_result(content, flow) result_events = [e for e in events if e.type == EventType.TOOL_CALL_RESULT] - assert result_events[0].content == '{"temp":14}' + assert result_events[0].content == '{"temp":14}' # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert flow.tool_results[-1]["content"] == "Weather summary" def test_coexists_with_active_predictive_state_handler(self): @@ -362,8 +362,8 @@ def test_coexists_with_active_predictive_state_handler(self): # Exactly one coalesced snapshot must be emitted containing all merged keys. snapshots = [e for e in events if e.type == EventType.STATE_SNAPSHOT] assert len(snapshots) == 1 - assert snapshots[0].snapshot["draft_final"] is True - assert snapshots[0].snapshot["preexisting"] == "value" + assert snapshots[0].snapshot["draft_final"] is True # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert snapshots[0].snapshot["preexisting"] == "value" # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert flow.current_state["draft_final"] is True assert flow.current_state["preexisting"] == "value" @@ -382,7 +382,7 @@ def test_predictive_and_deterministic_emit_single_snapshot(self): snapshots = [e for e in events if e.type == EventType.STATE_SNAPSHOT] assert len(snapshots) == 1, f"Expected 1 coalesced snapshot, got {len(snapshots)}" - assert snapshots[0].snapshot == {"existing": "yes", "new_key": 42} + assert snapshots[0].snapshot == {"existing": "yes", "new_key": 42} # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] class TestEmitMcpToolResultWithState: @@ -446,6 +446,6 @@ def test_mcp_tool_result_routes_display_payload_to_ui_only(self): assert len(result_events) == 1 # UI event carries the structured display payload. - assert _json.loads(result_events[0].content) == display_payload + assert _json.loads(result_events[0].content) == display_payload # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] # LLM-side accumulator keeps the short text. assert flow.tool_results[-1]["content"] == "2 rows returned" diff --git a/python/packages/ag-ui/tests/ag_ui/test_structured_output.py b/python/packages/ag-ui/tests/ag_ui/test_structured_output.py index 39e8b1779cf..df612ef68c2 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_structured_output.py +++ b/python/packages/ag-ui/tests/ag_ui/test_structured_output.py @@ -42,7 +42,7 @@ async def stream_fn( ) agent = Agent(name="test", instructions="Test", client=streaming_chat_client_stub(stream_fn)) - agent.default_options = ChatOptions(response_format=RecipeOutput) + agent.default_options = {"response_format": RecipeOutput} wrapper = AgentFrameworkAgent( agent=agent, @@ -84,7 +84,7 @@ async def stream_fn( yield ChatResponseUpdate(contents=[Content.from_text(text=json.dumps(steps_data))]) agent = Agent(name="test", instructions="Test", client=streaming_chat_client_stub(stream_fn)) - agent.default_options = ChatOptions(response_format=StepsOutput) + agent.default_options = {"response_format": StepsOutput} wrapper = AgentFrameworkAgent( agent=agent, @@ -119,7 +119,7 @@ async def test_structured_output_with_no_schema_match(streaming_chat_client_stub agent = Agent( name="test", instructions="Test", client=streaming_chat_client_stub(stream_from_updates_fixture(updates)) ) - agent.default_options = ChatOptions(response_format=GenericOutput) + agent.default_options = {"response_format": GenericOutput} wrapper = AgentFrameworkAgent( agent=agent, @@ -154,7 +154,7 @@ async def stream_fn( yield ChatResponseUpdate(contents=[Content.from_text(text='{"data": {"key": "value"}, "info": "processed"}')]) agent = Agent(name="test", instructions="Test", client=streaming_chat_client_stub(stream_fn)) - agent.default_options = ChatOptions(response_format=DataOutput) + agent.default_options = {"response_format": DataOutput} wrapper = AgentFrameworkAgent( agent=agent, @@ -214,7 +214,7 @@ async def stream_fn( yield ChatResponseUpdate(contents=[Content.from_text(text=json.dumps(output_data))]) agent = Agent(name="test", instructions="Test", client=streaming_chat_client_stub(stream_fn)) - agent.default_options = ChatOptions(response_format=RecipeOutput) + agent.default_options = {"response_format": RecipeOutput} wrapper = AgentFrameworkAgent( agent=agent, @@ -249,7 +249,7 @@ async def stream_fn( yield ChatResponseUpdate(contents=[]) agent = Agent(name="test", instructions="Test", client=streaming_chat_client_stub(stream_fn)) - agent.default_options = ChatOptions(response_format=RecipeOutput) + agent.default_options = {"response_format": RecipeOutput} wrapper = AgentFrameworkAgent(agent=agent) diff --git a/python/packages/ag-ui/tests/ag_ui/test_tooling.py b/python/packages/ag-ui/tests/ag_ui/test_tooling.py index 890ae445415..26fc57cd650 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_tooling.py +++ b/python/packages/ag-ui/tests/ag_ui/test_tooling.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. +from typing import Any, cast from unittest.mock import MagicMock import pytest @@ -61,7 +62,7 @@ def test_register_additional_client_tools_assigns_when_configured() -> None: agent = Agent(client=mock_chat_client) tools = [DummyTool("x")] - register_additional_client_tools(agent, tools) + register_additional_client_tools(cast(Any, agent), tools) assert mock_chat_client.function_invocation_configuration["additional_tools"] == tools @@ -73,7 +74,7 @@ def test_collect_server_tools_includes_mcp_tools_when_connected() -> None: mock_mcp = MockMCPTool([mcp_function1, mcp_function2], is_connected=True) agent = _create_chat_agent_with_tool("regular_tool") - agent.mcp_tools = [mock_mcp] + agent.mcp_tools = [cast(Any, mock_mcp)] tools = collect_server_tools(agent) @@ -90,7 +91,7 @@ def test_collect_server_tools_excludes_mcp_tools_when_not_connected() -> None: mock_mcp = MockMCPTool([mcp_function], is_connected=False) agent = _create_chat_agent_with_tool("regular_tool") - agent.mcp_tools = [mock_mcp] + agent.mcp_tools = [cast(Any, mock_mcp)] tools = collect_server_tools(agent) @@ -117,7 +118,7 @@ def test_collect_server_tools_with_mcp_tools_via_public_property() -> None: mock_mcp = MockMCPTool([mcp_function], is_connected=True) agent = _create_chat_agent_with_tool("regular_tool") - agent.mcp_tools = [mock_mcp] + agent.mcp_tools = [cast(Any, mock_mcp)] # Verify the public property works assert agent.mcp_tools == [mock_mcp] @@ -135,7 +136,7 @@ def test_collect_server_tools_raises_on_duplicate_agent_and_mcp_tool_names() -> mock_mcp = MockMCPTool([duplicate_tool], is_connected=True, name="docs-mcp") agent = _create_chat_agent_with_tool("regular_tool") - agent.mcp_tools = [mock_mcp] + agent.mcp_tools = [cast(Any, mock_mcp)] with pytest.raises(ValueError, match="Duplicate tool name 'regular_tool'"): collect_server_tools(agent) @@ -151,7 +152,7 @@ class MockAgent: pass agent = MockAgent() - tools = collect_server_tools(agent) + tools = collect_server_tools(cast(Any, agent)) assert tools == [] @@ -175,7 +176,7 @@ class MockAgent: tools = [DummyTool("x")] # Should not raise - register_additional_client_tools(agent, tools) + register_additional_client_tools(cast(Any, agent), tools) def test_merge_tools_no_client_tools() -> None: diff --git a/python/packages/ag-ui/tests/ag_ui/test_types.py b/python/packages/ag-ui/tests/ag_ui/test_types.py index b0117ca6cd7..65615a41ea5 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_types.py +++ b/python/packages/ag-ui/tests/ag_ui/test_types.py @@ -150,7 +150,7 @@ class TestAGUIRequest: def test_agui_request_minimal(self) -> None: """Test creating AGUIRequest with only required fields.""" - request = AGUIRequest(messages=[{"role": "user", "content": "Hello"}]) + request = AGUIRequest.model_validate({"messages": [{"role": "user", "content": "Hello"}]}) assert len(request.messages) == 1 assert request.messages[0]["content"] == "Hello" @@ -164,15 +164,17 @@ def test_agui_request_minimal(self) -> None: def test_agui_request_all_fields(self) -> None: """Test creating AGUIRequest with all fields populated.""" - request = AGUIRequest( - messages=[{"role": "user", "content": "Hello"}], - run_id="run-123", - thread_id="thread-456", - state={"counter": 0}, - tools=[{"name": "search", "description": "Search tool"}], - context=[{"type": "document", "content": "Some context"}], - forwarded_props={"custom_key": "custom_value"}, - parent_run_id="parent-run-789", + request = AGUIRequest.model_validate( + { + "messages": [{"role": "user", "content": "Hello"}], + "run_id": "run-123", + "thread_id": "thread-456", + "state": {"counter": 0}, + "tools": [{"name": "search", "description": "Search tool"}], + "context": [{"type": "document", "content": "Some context"}], + "forwarded_props": {"custom_key": "custom_value"}, + "parent_run_id": "parent-run-789", + } ) assert request.run_id == "run-123" @@ -185,12 +187,14 @@ def test_agui_request_all_fields(self) -> None: def test_agui_request_camel_case_aliases(self) -> None: """Test AGUIRequest accepts camelCase aliases from AG-UI HTTP clients.""" - request = AGUIRequest( - messages=[{"role": "user", "content": "Hello"}], - runId="run-camel-1", - threadId="thread-camel-1", - forwardedProps={"k": "v"}, - parentRunId="parent-camel-1", + request = AGUIRequest.model_validate( + { + "messages": [{"role": "user", "content": "Hello"}], + "runId": "run-camel-1", + "threadId": "thread-camel-1", + "forwardedProps": {"k": "v"}, + "parentRunId": "parent-camel-1", + } ) assert request.run_id == "run-camel-1" @@ -200,10 +204,12 @@ def test_agui_request_camel_case_aliases(self) -> None: def test_agui_request_model_dump_excludes_none(self) -> None: """Test that model_dump(exclude_none=True) excludes None fields.""" - request = AGUIRequest( - messages=[{"role": "user", "content": "test"}], - tools=[{"name": "my_tool"}], - context=[{"id": "ctx1"}], + request = AGUIRequest.model_validate( + { + "messages": [{"role": "user", "content": "test"}], + "tools": [{"name": "my_tool"}], + "context": [{"id": "ctx1"}], + } ) dumped = request.model_dump(exclude_none=True) @@ -223,12 +229,14 @@ def test_agui_request_model_dump_includes_all_set_fields(self) -> None: This is critical for the fix - ensuring tools, context, forwarded_props, and parent_run_id are not stripped during request validation. """ - request = AGUIRequest( - messages=[{"role": "user", "content": "test"}], - tools=[{"name": "client_tool", "parameters": {"type": "object"}}], - context=[{"type": "snippet", "content": "code here"}], - forwarded_props={"auth_token": "secret", "user_id": "user-1"}, - parent_run_id="parent-456", + request = AGUIRequest.model_validate( + { + "messages": [{"role": "user", "content": "test"}], + "tools": [{"name": "client_tool", "parameters": {"type": "object"}}], + "context": [{"type": "snippet", "content": "code here"}], + "forwarded_props": {"auth_token": "secret", "user_id": "user-1"}, + "parent_run_id": "parent-456", + } ) dumped = request.model_dump(exclude_none=True) @@ -241,9 +249,11 @@ def test_agui_request_model_dump_includes_all_set_fields(self) -> None: def test_agui_request_available_interrupts_alias_round_trip(self) -> None: """availableInterrupts should deserialize, while dumps remain snake_case.""" - request = AGUIRequest( - messages=[{"role": "user", "content": "Hello"}], - availableInterrupts=[{"id": "req_1", "value": {"choice": "A"}}], + request = AGUIRequest.model_validate( + { + "messages": [{"role": "user", "content": "Hello"}], + "availableInterrupts": [{"id": "req_1", "value": {"choice": "A"}}], + } ) assert request.available_interrupts == [{"id": "req_1", "value": {"choice": "A"}}] diff --git a/python/packages/ag-ui/tests/ag_ui/test_utils.py b/python/packages/ag-ui/tests/ag_ui/test_utils.py index f353d2f0a72..57f97de8159 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_utils.py +++ b/python/packages/ag-ui/tests/ag_ui/test_utils.py @@ -64,7 +64,7 @@ def test_merge_state_deep_copy(): result["recipe"]["ingredients"].append("eggs") - assert "eggs" not in current["recipe"]["ingredients"] + assert "eggs" not in current["recipe"]["ingredients"] # type: ignore[operator] # pyrefly: ignore[not-iterable] assert current["recipe"]["ingredients"] == ["flour", "sugar"] assert result["recipe"]["ingredients"] == ["flour", "sugar", "eggs"] diff --git a/python/packages/ag-ui/tests/ag_ui/test_workflow_agent.py b/python/packages/ag-ui/tests/ag_ui/test_workflow_agent.py index 80bd21b2fa9..858d10370f0 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_workflow_agent.py +++ b/python/packages/ag-ui/tests/ag_ui/test_workflow_agent.py @@ -22,7 +22,7 @@ async def test_workflow_wrapper_rejects_workflow_and_factory_at_once() -> None: @executor(id="start") async def start(message: Any, ctx: WorkflowContext) -> None: del message - await ctx.yield_output("ok") + await ctx.yield_output("ok") # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] workflow = WorkflowBuilder(start_executor=start).build() with pytest.raises(ValueError, match="workflow_factory"): diff --git a/python/packages/ag-ui/tests/ag_ui/test_workflow_run.py b/python/packages/ag-ui/tests/ag_ui/test_workflow_run.py index 235a16c6e1c..7c25d58ccd3 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_workflow_run.py +++ b/python/packages/ag-ui/tests/ag_ui/test_workflow_run.py @@ -21,7 +21,6 @@ handler, response_handler, ) -from typing_extensions import Never from agent_framework_ag_ui._workflow_run import ( _coerce_content, @@ -52,14 +51,14 @@ class ProgressEvent(WorkflowEvent): """Custom workflow event used to validate CUSTOM mapping.""" def __init__(self, progress: int) -> None: - super().__init__("custom_progress", data={"progress": progress}) + super().__init__(cast(Any, "custom_progress"), data={"progress": progress}) async def test_workflow_run_maps_custom_and_text_events(): """Custom workflow events and yielded text are mapped to AG-UI events.""" @executor(id="start") - async def start(message: Any, ctx: WorkflowContext[Never, str]) -> None: + async def start(message: Any, ctx: WorkflowContext[Any, str]) -> None: await ctx.add_event(ProgressEvent(10)) await ctx.yield_output("Hello workflow") @@ -76,9 +75,9 @@ async def start(message: Any, ctx: WorkflowContext[Never, str]) -> None: assert "STEP_FINISHED" in event_types assert "RUN_FINISHED" in event_types - custom_events = [event for event in events if event.type == "CUSTOM" and event.name == "custom_progress"] + custom_events = [event for event in events if event.type == "CUSTOM" and event.name == "custom_progress"] # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert len(custom_events) == 1 - assert custom_events[0].value == {"progress": 10} + assert custom_events[0].value == {"progress": 10} # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] async def test_workflow_run_request_info_emits_interrupt_and_resume_works(): @@ -90,7 +89,7 @@ async def requester(message: Any, ctx: WorkflowContext) -> None: workflow = WorkflowBuilder(start_executor=requester).build() - first_run_events = [ + first_run_events: list[Any] = [ event async for event in run_workflow_stream({"messages": [{"role": "user", "content": "go"}]}, workflow) ] @@ -103,7 +102,7 @@ async def requester(message: Any, ctx: WorkflowContext) -> None: request_id = str(interrupt_payload[0]["id"]) assert request_id - resumed_events = [ + resumed_events: list[Any] = [ event async for event in run_workflow_stream( {"messages": [], "resume": {"interrupts": [{"id": request_id, "value": "approved"}]}}, @@ -123,7 +122,7 @@ async def test_workflow_run_request_info_closes_open_text_message() -> None: @executor(id="requester") async def requester(message: Any, ctx: WorkflowContext) -> None: del message - await ctx.yield_output("Please confirm this action.") + await ctx.yield_output("Please confirm this action.") # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] await ctx.request_info("Need approval", str, request_id="approval-1") workflow = WorkflowBuilder(start_executor=requester).build() @@ -178,7 +177,7 @@ async def requester(message: Any, ctx: WorkflowContext) -> None: workflow = WorkflowBuilder(start_executor=requester).build() _ = [event async for event in run_workflow_stream({"messages": [{"role": "user", "content": "go"}]}, workflow)] - resumed_events = [ + resumed_events: list[Any] = [ event async for event in run_workflow_stream( { @@ -208,7 +207,7 @@ async def requester(message: Any, ctx: WorkflowContext) -> None: workflow = WorkflowBuilder(start_executor=requester).build() _ = [event async for event in run_workflow_stream({"messages": [{"role": "user", "content": "go"}]}, workflow)] - resumed_events = [ + resumed_events: list[Any] = [ event async for event in run_workflow_stream( { @@ -246,7 +245,7 @@ async def start(self, message: Any, ctx: WorkflowContext) -> None: async def handle_approval(self, original_request: Content, response: Content, ctx: WorkflowContext) -> None: del original_request status = "approved" if bool(response.approved) else "rejected" - await ctx.yield_output(f"Refund tool call {status}.") + await ctx.yield_output(f"Refund tool call {status}.") # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] workflow = WorkflowBuilder(start_executor=ApprovalExecutor()).build() first_events = [ @@ -256,7 +255,7 @@ async def handle_approval(self, original_request: Content, response: Content, ct interrupt_payload = cast(list[dict[str, Any]], first_finished.get("interrupt")) interrupt_value = cast(dict[str, Any], interrupt_payload[0]["value"]) - resumed_events = [ + resumed_events: list[Any] = [ event async for event in run_workflow_stream( { @@ -306,12 +305,12 @@ async def handle_user_input( ) -> None: del original_request user_text = response[0].text if response else "" - await ctx.yield_output(f"Captured response: {user_text}") + await ctx.yield_output(f"Captured response: {user_text}") # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] workflow = WorkflowBuilder(start_executor=MessageRequestExecutor()).build() _ = [event async for event in run_workflow_stream({"messages": [{"role": "user", "content": "start"}]}, workflow)] - resumed_events = [ + resumed_events: list[Any] = [ event async for event in run_workflow_stream( { @@ -347,22 +346,22 @@ async def test_workflow_run_non_chat_output_maps_to_custom_output_event(): """Non-chat workflow outputs are emitted as CUSTOM workflow_output events.""" @executor(id="structured") - async def structured(message: Any, ctx: WorkflowContext[Never, dict[str, int]]) -> None: + async def structured(message: Any, ctx: WorkflowContext[Any, dict[str, int]]) -> None: await ctx.yield_output({"count": 3}) workflow = WorkflowBuilder(start_executor=structured).build() events = [event async for event in run_workflow_stream({"messages": [{"role": "user", "content": "go"}]}, workflow)] - output_custom = [event for event in events if event.type == "CUSTOM" and event.name == "workflow_output"] + output_custom = [event for event in events if event.type == "CUSTOM" and event.name == "workflow_output"] # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert len(output_custom) == 1 - assert output_custom[0].value == {"count": 3} + assert output_custom[0].value == {"count": 3} # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] async def test_workflow_run_passthroughs_ag_ui_base_events(): """Workflow outputs that are AG-UI BaseEvent instances should be emitted directly.""" @executor(id="stateful") - async def stateful(message: Any, ctx: WorkflowContext[Never, StateSnapshotEvent]) -> None: + async def stateful(message: Any, ctx: WorkflowContext[Any, StateSnapshotEvent]) -> None: await ctx.yield_output(StateSnapshotEvent(type=EventType.STATE_SNAPSHOT, snapshot={"active_agent": "flights"})) workflow = WorkflowBuilder(start_executor=stateful).build() @@ -370,7 +369,7 @@ async def stateful(message: Any, ctx: WorkflowContext[Never, StateSnapshotEvent] snapshots = [event for event in events if event.type == "STATE_SNAPSHOT"] assert len(snapshots) == 1 - assert snapshots[0].snapshot["active_agent"] == "flights" + assert snapshots[0].snapshot["active_agent"] == "flights" # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] async def test_workflow_run_plain_text_follow_up_does_not_infer_interrupt_response(): @@ -453,7 +452,7 @@ async def test_workflow_run_agent_response_output_uses_latest_assistant_message_ """Conversation payload outputs should not flatten full history into one assistant message.""" @executor(id="responder") - async def responder(message: Any, ctx: WorkflowContext[Never, AgentResponse]) -> None: + async def responder(message: Any, ctx: WorkflowContext[Any, AgentResponse]) -> None: del message response = AgentResponse( messages=[ @@ -469,7 +468,7 @@ async def responder(message: Any, ctx: WorkflowContext[Never, AgentResponse]) -> workflow = WorkflowBuilder(start_executor=responder).build() events = [event async for event in run_workflow_stream({"messages": [{"role": "user", "content": "go"}]}, workflow)] - text_deltas = [event.delta for event in events if event.type == "TEXT_MESSAGE_CONTENT"] + text_deltas = [event.delta for event in events if event.type == "TEXT_MESSAGE_CONTENT"] # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert text_deltas == ["Order Agent: Got it. I submitted the replacement request."] @@ -477,7 +476,7 @@ async def test_workflow_run_skips_duplicate_text_from_conversation_snapshot() -> """Do not emit duplicate assistant text when a snapshot repeats the latest output.""" @executor(id="responder") - async def responder(message: Any, ctx: WorkflowContext[Never, Any]) -> None: + async def responder(message: Any, ctx: WorkflowContext[Any, Any]) -> None: del message duplicate_text = "Order Agent: Got it. I submitted the replacement request." await ctx.yield_output(duplicate_text) @@ -493,7 +492,7 @@ async def responder(message: Any, ctx: WorkflowContext[Never, Any]) -> None: workflow = WorkflowBuilder(start_executor=responder).build() events = [event async for event in run_workflow_stream({"messages": [{"role": "user", "content": "go"}]}, workflow)] - text_deltas = [event.delta for event in events if event.type == "TEXT_MESSAGE_CONTENT"] + text_deltas = [event.delta for event in events if event.type == "TEXT_MESSAGE_CONTENT"] # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert text_deltas == ["Order Agent: Got it. I submitted the replacement request."] @@ -501,7 +500,7 @@ async def test_workflow_run_skips_consecutive_duplicate_text_outputs() -> None: """Do not emit duplicate assistant text when consecutive outputs are identical.""" @executor(id="responder") - async def responder(message: Any, ctx: WorkflowContext[Never, Any]) -> None: + async def responder(message: Any, ctx: WorkflowContext[Any, Any]) -> None: del message duplicate_text = "Order Agent: Replacement processed. Case complete." await ctx.yield_output(duplicate_text) @@ -510,7 +509,7 @@ async def responder(message: Any, ctx: WorkflowContext[Never, Any]) -> None: workflow = WorkflowBuilder(start_executor=responder).build() events = [event async for event in run_workflow_stream({"messages": [{"role": "user", "content": "go"}]}, workflow)] - text_deltas = [event.delta for event in events if event.type == "TEXT_MESSAGE_CONTENT"] + text_deltas = [event.delta for event in events if event.type == "TEXT_MESSAGE_CONTENT"] # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert text_deltas == ["Order Agent: Replacement processed. Case complete."] @@ -518,7 +517,7 @@ async def test_workflow_run_skips_final_snapshot_when_streamed_chunks_already_ma """Do not append full snapshot text when prior chunk outputs already formed the same message.""" @executor(id="responder") - async def responder(message: Any, ctx: WorkflowContext[Never, Any]) -> None: + async def responder(message: Any, ctx: WorkflowContext[Any, Any]) -> None: del message full_text = ( "Your replacement request for order 28939393 has been submitted with expedited shipping, " @@ -540,7 +539,7 @@ async def responder(message: Any, ctx: WorkflowContext[Never, Any]) -> None: workflow = WorkflowBuilder(start_executor=responder).build() events = [event async for event in run_workflow_stream({"messages": [{"role": "user", "content": "go"}]}, workflow)] - text_deltas = [event.delta for event in events if event.type == "TEXT_MESSAGE_CONTENT"] + text_deltas = [event.delta for event in events if event.type == "TEXT_MESSAGE_CONTENT"] # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert text_deltas == [ "Your replacement request for order 28939393 has been submitted with expedited shipping, ", "as you requested.\n\nCase complete.", @@ -551,7 +550,7 @@ async def test_workflow_run_usage_content_emits_custom_usage_event() -> None: """Usage output from workflows should be surfaced as a custom usage event.""" @executor(id="usage") - async def usage(message: Any, ctx: WorkflowContext[Never, Content]) -> None: + async def usage(message: Any, ctx: WorkflowContext[Any, Content]) -> None: del message await ctx.yield_output( Content.from_usage( @@ -566,11 +565,11 @@ async def usage(message: Any, ctx: WorkflowContext[Never, Content]) -> None: workflow = WorkflowBuilder(start_executor=usage).build() events = [event async for event in run_workflow_stream({"messages": [{"role": "user", "content": "go"}]}, workflow)] - usage_events = [event for event in events if event.type == "CUSTOM" and event.name == "usage"] + usage_events = [event for event in events if event.type == "CUSTOM" and event.name == "usage"] # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert len(usage_events) == 1 - assert usage_events[0].value["input_token_count"] == 12 - assert usage_events[0].value["output_token_count"] == 6 - assert usage_events[0].value["total_token_count"] == 18 + assert usage_events[0].value["input_token_count"] == 12 # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert usage_events[0].value["output_token_count"] == 6 # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert usage_events[0].value["total_token_count"] == 18 # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] async def test_workflow_run_accepts_multimodal_input_messages() -> None: @@ -589,7 +588,7 @@ async def _stream(): return _stream() workflow = CapturingWorkflow() - events = [ + events: list[Any] = [ event async for event in run_workflow_stream( { @@ -684,7 +683,7 @@ async def _stream(): return _stream() - events = [ + events: list[Any] = [ event async for event in run_workflow_stream( {"messages": [{"role": "user", "content": "go"}]}, @@ -963,7 +962,7 @@ def test_failed_coercion_skipped(self): def test_unknown_request_id_preserved(self): """Responses for unknown request IDs are preserved as-is.""" responses = {"unknown_id": "value"} - pending = {} + pending = {} # type: ignore[var-annotated] result = _coerce_responses_for_pending_requests(responses, pending) assert result == {"unknown_id": "value"} @@ -1054,7 +1053,7 @@ def test_string_data(self): def test_dict_data_serialized(self): """Dict data is JSON-serialized.""" result = _workflow_interrupt_event_value({"data": {"key": "val"}}) - assert json.loads(result) == {"key": "val"} + assert json.loads(result) == {"key": "val"} # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] class TestWorkflowPayloadToContents: @@ -1306,7 +1305,7 @@ async def start(self, message: Any, ctx: WorkflowContext) -> None: async def handle_approval(self, original_request: Content, response: Content, ctx: WorkflowContext) -> None: del original_request status = "approved" if bool(response.approved) else "rejected" - await ctx.yield_output(f"Refund {status}.") + await ctx.yield_output(f"Refund {status}.") # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] workflow = WorkflowBuilder(start_executor=ApprovalExecutor()).build() first_events = [ @@ -1317,7 +1316,7 @@ async def handle_approval(self, original_request: Content, response: Content, ct assert isinstance(interrupt_payload, list) and len(interrupt_payload) == 1 # Second turn: send approval via function_approvals on a message (not resume.interrupts) - resumed_events = [ + resumed_events: list[Any] = [ event async for event in run_workflow_stream( { @@ -1377,7 +1376,7 @@ async def handle_approval(self, original_request: Content, response: Content, ct del original_request if response.function_call is not None: handled_responses.append(response.function_call.parse_arguments() or {}) - await ctx.yield_output("handled") + await ctx.yield_output("handled") # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] workflow = WorkflowBuilder(start_executor=ApprovalExecutor()).build() first_events = [ @@ -1387,7 +1386,7 @@ async def handle_approval(self, original_request: Content, response: Content, ct interrupt_payload = cast(list[dict[str, Any]], first_finished.get("interrupt")) assert isinstance(interrupt_payload, list) and len(interrupt_payload) == 1 - resumed_events = [ + resumed_events: list[Any] = [ event async for event in run_workflow_stream( { @@ -1438,7 +1437,7 @@ async def start(self, message: Any, ctx: WorkflowContext) -> None: async def handle_approval(self, original_request: Content, response: Content, ctx: WorkflowContext) -> None: del original_request status = "approved" if bool(response.approved) else "rejected" - await ctx.yield_output(f"Delete {status}.") + await ctx.yield_output(f"Delete {status}.") # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] workflow = WorkflowBuilder(start_executor=ApprovalExecutor()).build() first_events = [ @@ -1449,7 +1448,7 @@ async def handle_approval(self, original_request: Content, response: Content, ct assert isinstance(interrupt_payload, list) and len(interrupt_payload) == 1 # Second turn: send denial via function_approvals on a message (not resume.interrupts) - resumed_events = [ + resumed_events: list[Any] = [ event async for event in run_workflow_stream( { @@ -1517,7 +1516,7 @@ async def _stream(): return _stream() - events = [ + events: list[Any] = [ event async for event in run_workflow_stream( {"messages": [{"role": "user", "content": "go"}]}, cast(Any, FailingWorkflow()) @@ -1546,7 +1545,7 @@ async def _stream(): return _stream() - events = [ + events: list[Any] = [ event async for event in run_workflow_stream( {"messages": [{"role": "user", "content": "go"}]}, cast(Any, StatusWorkflow()) @@ -1571,7 +1570,7 @@ async def _stream(): return _stream() - events = [ + events: list[Any] = [ event async for event in run_workflow_stream( {"messages": [{"role": "user", "content": "go"}]}, cast(Any, ExecutorWorkflow()) @@ -1599,7 +1598,7 @@ async def _stream(): return _stream() - events = [ + events: list[Any] = [ event async for event in run_workflow_stream( {"messages": [{"role": "user", "content": "go"}]}, cast(Any, ExecutorFailWorkflow()) @@ -1629,7 +1628,7 @@ async def _stream(): return _stream() - events = [ + events: list[Any] = [ event async for event in run_workflow_stream( {"messages": [{"role": "user", "content": "go"}]}, cast(Any, ListEventWorkflow()) @@ -1653,7 +1652,7 @@ async def _stream(): return _stream() - events = [ + events: list[Any] = [ event async for event in run_workflow_stream( {"messages": [{"role": "user", "content": "go"}]}, cast(Any, EmptyWorkflow()) @@ -1676,7 +1675,7 @@ async def _stream(): return _stream() - events = [ + events: list[Any] = [ event async for event in run_workflow_stream( {"messages": [{"role": "user", "content": "go"}]}, cast(Any, DualTextWorkflow()) @@ -1700,7 +1699,7 @@ async def _stream(): return _stream() - events = [ + events: list[Any] = [ event async for event in run_workflow_stream( {"messages": [{"role": "user", "content": "go"}]}, cast(Any, SuperstepWorkflow()) @@ -1726,7 +1725,7 @@ async def _stream(): return _stream() - events = [ + events: list[Any] = [ event async for event in run_workflow_stream( {"messages": [{"role": "user", "content": "go"}]}, cast(Any, StatusWorkflow()) @@ -1754,7 +1753,7 @@ async def _stream(): return _stream() workflow = CapturingWorkflow() - events = [ + events: list[Any] = [ event async for event in run_workflow_stream( { @@ -1792,7 +1791,7 @@ async def _stream(): return _stream() workflow = CapturingWorkflow() - events = [ + events: list[Any] = [ event async for event in run_workflow_stream( {"messages": [{"role": "user", "content": "hello"}]}, @@ -1822,7 +1821,7 @@ async def _stream(): return _stream() workflow = CapturingWorkflow() - events = [ + events: list[Any] = [ event async for event in run_workflow_stream( { @@ -1859,7 +1858,7 @@ async def _stream(): return _stream() workflow = CapturingWorkflow() - events = [ + events: list[Any] = [ event async for event in run_workflow_stream( { @@ -1927,7 +1926,7 @@ async def _stream(): return _stream() workflow = StrictWorkflow() - events = [ + events: list[Any] = [ event async for event in run_workflow_stream( { diff --git a/python/packages/anthropic/agent_framework_anthropic/_bedrock_client.py b/python/packages/anthropic/agent_framework_anthropic/_bedrock_client.py index 6c4c32738c3..af675dcb7fd 100644 --- a/python/packages/anthropic/agent_framework_anthropic/_bedrock_client.py +++ b/python/packages/anthropic/agent_framework_anthropic/_bedrock_client.py @@ -34,7 +34,7 @@ class AnthropicBedrockSettings(TypedDict, total=False): class RawAnthropicBedrockClient(RawAnthropicClient[AnthropicOptionsT], Generic[AnthropicOptionsT]): """Raw Anthropic Bedrock chat client without middleware, telemetry, or function invocation support.""" - OTEL_PROVIDER_NAME: ClassVar[str] = "aws.bedrock" # type: ignore[reportIncompatibleVariableOverride, misc] + OTEL_PROVIDER_NAME: ClassVar[str] = "aws.bedrock" def __init__( self, @@ -105,7 +105,7 @@ def __init__( ) -class AnthropicBedrockClient( # type: ignore[misc] +class AnthropicBedrockClient( FunctionInvocationLayer[AnthropicOptionsT], ChatMiddlewareLayer[AnthropicOptionsT], ChatTelemetryLayer[AnthropicOptionsT], diff --git a/python/packages/anthropic/agent_framework_anthropic/_chat_client.py b/python/packages/anthropic/agent_framework_anthropic/_chat_client.py index c90b061b4ff..ed7c11ebdd4 100644 --- a/python/packages/anthropic/agent_framework_anthropic/_chat_client.py +++ b/python/packages/anthropic/agent_framework_anthropic/_chat_client.py @@ -52,17 +52,17 @@ from pydantic import BaseModel if sys.version_info >= (3, 11): - from typing import TypedDict # type: ignore # pragma: no cover + from typing import TypedDict # pragma: no cover else: - from typing_extensions import TypedDict # type: ignore # pragma: no cover + from typing_extensions import TypedDict # pragma: no cover if sys.version_info >= (3, 13): - from typing import TypeVar # type: ignore # pragma: no cover + from typing import TypeVar # pragma: no cover else: - from typing_extensions import TypeVar # type: ignore # pragma: no cover + from typing_extensions import TypeVar # pragma: no cover if sys.version_info >= (3, 12): - from typing import override # type: ignore # pragma: no cover + from typing import override # pragma: no cover else: - from typing_extensions import override # type: ignore # pragma: no cover + from typing_extensions import override # pragma: no cover __all__ = [ @@ -243,7 +243,7 @@ class RawAnthropicClient( Use ``AnthropicClient`` instead for a fully-featured client with all layers applied. """ - OTEL_PROVIDER_NAME: ClassVar[str] = "anthropic" # type: ignore[reportIncompatibleVariableOverride, misc] + OTEL_PROVIDER_NAME: ClassVar[str] = "anthropic" def __init__( self, @@ -539,7 +539,7 @@ def _inner_get_response( if stream: # Streaming mode async def _stream() -> AsyncIterable[ChatResponseUpdate]: - async for chunk in await self.anthropic_client.beta.messages.create(**run_options, stream=True): # type: ignore[misc] + async for chunk in await self.anthropic_client.beta.messages.create(**run_options, stream=True): parsed_chunk = self._process_stream_event(chunk) if parsed_chunk: yield parsed_chunk @@ -548,7 +548,7 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]: # Non-streaming mode async def _get_response() -> ChatResponse: - message = await self.anthropic_client.beta.messages.create(**run_options, stream=False) # type: ignore[misc] + message = await self.anthropic_client.beta.messages.create(**run_options, stream=False) return self._process_message(message, options) return _get_response() @@ -723,7 +723,7 @@ def _prepare_message_for_anthropic(self, message: Message) -> dict[str, Any]: a_content.append({ "type": "image", "source": { - "data": _get_data_bytes_as_str(content), # type: ignore[attr-defined] + "data": _get_data_bytes_as_str(content), "media_type": content.media_type, "type": "base64", }, @@ -755,7 +755,7 @@ def _prepare_message_for_anthropic(self, message: Message) -> dict[str, Any]: tool_content.append({ "type": "image", "source": { - "data": _get_data_bytes_as_str(item), # type: ignore[attr-defined] + "data": _get_data_bytes_as_str(item), "media_type": item.media_type, "type": "base64", }, @@ -1023,9 +1023,9 @@ def _parse_usage_from_anthropic(self, usage: BetaUsage | BetaMessageDeltaUsage | if usage.input_tokens is not None: usage_details["input_token_count"] = usage.input_tokens if usage.cache_creation_input_tokens is not None: - usage_details["anthropic.cache_creation_input_tokens"] = usage.cache_creation_input_tokens # type: ignore[typeddict-unknown-key] + usage_details["anthropic.cache_creation_input_tokens"] = usage.cache_creation_input_tokens if usage.cache_read_input_tokens is not None: - usage_details["anthropic.cache_read_input_tokens"] = usage.cache_read_input_tokens # type: ignore[typeddict-unknown-key] + usage_details["anthropic.cache_read_input_tokens"] = usage.cache_read_input_tokens return usage_details def _parse_contents_from_anthropic( diff --git a/python/packages/anthropic/agent_framework_anthropic/_foundry_client.py b/python/packages/anthropic/agent_framework_anthropic/_foundry_client.py index 997aac0e442..2a9db1177ff 100644 --- a/python/packages/anthropic/agent_framework_anthropic/_foundry_client.py +++ b/python/packages/anthropic/agent_framework_anthropic/_foundry_client.py @@ -33,7 +33,7 @@ class AnthropicFoundrySettings(TypedDict, total=False): class RawAnthropicFoundryClient(RawAnthropicClient[AnthropicOptionsT], Generic[AnthropicOptionsT]): """Raw Anthropic Foundry chat client without middleware, telemetry, or function invocation support.""" - OTEL_PROVIDER_NAME: ClassVar[str] = "azure.ai.foundry" # type: ignore[reportIncompatibleVariableOverride, misc] + OTEL_PROVIDER_NAME: ClassVar[str] = "azure.ai.foundry" def __init__( self, @@ -109,7 +109,7 @@ def __init__( ) -class AnthropicFoundryClient( # type: ignore[misc] +class AnthropicFoundryClient( FunctionInvocationLayer[AnthropicOptionsT], ChatMiddlewareLayer[AnthropicOptionsT], ChatTelemetryLayer[AnthropicOptionsT], diff --git a/python/packages/anthropic/agent_framework_anthropic/_vertex_client.py b/python/packages/anthropic/agent_framework_anthropic/_vertex_client.py index 7af22be0f00..d95a8e84a1e 100644 --- a/python/packages/anthropic/agent_framework_anthropic/_vertex_client.py +++ b/python/packages/anthropic/agent_framework_anthropic/_vertex_client.py @@ -34,7 +34,7 @@ class AnthropicVertexSettings(TypedDict, total=False): class RawAnthropicVertexClient(RawAnthropicClient[AnthropicOptionsT], Generic[AnthropicOptionsT]): """Raw Anthropic Vertex chat client without middleware, telemetry, or function invocation support.""" - OTEL_PROVIDER_NAME: ClassVar[str] = "google.vertex.ai" # type: ignore[reportIncompatibleVariableOverride, misc] + OTEL_PROVIDER_NAME: ClassVar[str] = "google.vertex.ai" def __init__( self, @@ -100,7 +100,7 @@ def __init__( ) -class AnthropicVertexClient( # type: ignore[misc] +class AnthropicVertexClient( FunctionInvocationLayer[AnthropicOptionsT], ChatMiddlewareLayer[AnthropicOptionsT], ChatTelemetryLayer[AnthropicOptionsT], diff --git a/python/packages/anthropic/tests/test_anthropic_client.py b/python/packages/anthropic/tests/test_anthropic_client.py index abad158b8c0..6ae6bcc68b5 100644 --- a/python/packages/anthropic/tests/test_anthropic_client.py +++ b/python/packages/anthropic/tests/test_anthropic_client.py @@ -2,7 +2,7 @@ import os import re from pathlib import Path -from typing import Annotated, Any +from typing import Annotated, Any, cast from unittest.mock import MagicMock, patch import pytest @@ -64,7 +64,7 @@ def create_test_anthropic_client( client._last_call_id_name = None client._tool_name_aliases = {} client.additional_properties = {} - client.middleware = None + cast(Any, client).middleware = None client.additional_beta_flags = [] client.chat_middleware = [] client.function_middleware = [] @@ -1417,7 +1417,7 @@ async def mock_stream(): chat_options = ChatOptions(max_tokens=10) chunks: list[ChatResponseUpdate] = [] - async for chunk in client._inner_get_response( # type: ignore[attr-defined] + async for chunk in client._inner_get_response( # type: ignore[attr-defined] # ty: ignore[not-iterable] messages=messages, options=chat_options, stream=True ): if chunk: @@ -1443,7 +1443,7 @@ async def mock_stream(): messages = [Message(role="user", contents=["Hi"])] options: dict[str, Any] = {"max_tokens": 10, "stream": False} - async for _ in client._inner_get_response( # type: ignore[attr-defined] + async for _ in client._inner_get_response( # type: ignore[attr-defined] # ty: ignore[not-iterable] messages=messages, options=options, stream=True, @@ -2262,7 +2262,7 @@ def test_prepare_options_missing_model(mock_anthropic_client: MagicMock) -> None client.model = "" # Set empty model messages = [Message(role="user", contents=[Content.from_text("Hello")])] - options = {} + options: dict[str, Any] = {} try: client._prepare_options(messages, options) @@ -2352,8 +2352,8 @@ def test_parse_usage_with_cache_tokens(mock_anthropic_client: MagicMock) -> None assert result is not None assert result["output_token_count"] == 50 assert result["input_token_count"] == 100 - assert result["anthropic.cache_creation_input_tokens"] == 20 - assert result["anthropic.cache_read_input_tokens"] == 30 + assert result["anthropic.cache_creation_input_tokens"] == 20 # ty: ignore[invalid-key] + assert result["anthropic.cache_read_input_tokens"] == 30 # ty: ignore[invalid-key] # Code Execution Result Tests @@ -2722,6 +2722,7 @@ def test_parse_citations_char_location(mock_anthropic_client: MagicMock) -> None result = client._parse_citations_from_anthropic(mock_block) + assert result is not None assert len(result) > 0 @@ -2745,6 +2746,7 @@ def test_parse_citations_page_location(mock_anthropic_client: MagicMock) -> None result = client._parse_citations_from_anthropic(mock_block) + assert result is not None assert len(result) > 0 @@ -2770,6 +2772,7 @@ def test_parse_citations_content_block_location( result = client._parse_citations_from_anthropic(mock_block) + assert result is not None assert len(result) > 0 @@ -2792,6 +2795,7 @@ def test_parse_citations_web_search_location(mock_anthropic_client: MagicMock) - result = client._parse_citations_from_anthropic(mock_block) + assert result is not None assert len(result) > 0 @@ -2818,6 +2822,7 @@ def test_parse_citations_search_result_location( result = client._parse_citations_from_anthropic(mock_block) + assert result is not None assert len(result) > 0 diff --git a/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py b/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py index 5a0b79f29db..9a7ced525a9 100644 --- a/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py +++ b/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py @@ -701,7 +701,7 @@ async def _semantic_search(self, query: str) -> list[Message]: embeddings = await self.embedding_function.get_embeddings([query]) # type: ignore[reportUnknownVariableType] query_vector = embeddings[0].vector # type: ignore[reportUnknownVariableType] else: - query_vector = await self.embedding_function(query) # type: ignore[reportUnknownVariableType] + query_vector = await self.embedding_function(query) vector_queries = [VectorizedQuery(vector=query_vector, k=vector_k, fields=self.vector_field_name)] # type: ignore[reportUnknownArgumentType] search_params: dict[str, Any] = {"search_text": query, "top": self.top_k} @@ -721,7 +721,7 @@ async def _semantic_search(self, query: str) -> list[Message]: doc_id = doc.get("id") or doc.get("@search.id") # type: ignore[reportUnknownVariableType] doc_text: str = self._extract_document_text(doc, doc_id=doc_id) # type: ignore[reportUnknownArgumentType] if doc_text: - result_messages.append(Message(role="user", contents=[doc_text])) # type: ignore[reportUnknownArgumentType] + result_messages.append(Message(role="user", contents=[doc_text])) return result_messages async def _ensure_knowledge_base(self) -> None: diff --git a/python/packages/azure-ai-search/tests/test_aisearch_context_provider.py b/python/packages/azure-ai-search/tests/test_aisearch_context_provider.py index 64c12e0724c..bf3c48d1674 100644 --- a/python/packages/azure-ai-search/tests/test_aisearch_context_provider.py +++ b/python/packages/azure-ai-search/tests/test_aisearch_context_provider.py @@ -3,6 +3,7 @@ import os from types import SimpleNamespace +from typing import Any, cast from unittest.mock import AsyncMock, Mock, patch import pytest @@ -82,9 +83,9 @@ async def _search(**kwargs): return client -def _make_provider(**overrides) -> AzureAISearchContextProvider: +def _make_provider(**overrides: Any) -> AzureAISearchContextProvider: """Create a semantic-mode provider with mocked internals (skips auto-discovery).""" - defaults = { + defaults: dict[str, Any] = { "source_id": AzureAISearchContextProvider.DEFAULT_SOURCE_ID, "endpoint": "https://test.search.windows.net", "index_name": "test-index", @@ -207,7 +208,7 @@ class TestInitAgenticValidation: def test_both_index_and_kb_raises(self) -> None: with pytest.raises(SettingNotFoundError, match="multiple were set"): - AzureAISearchContextProvider( + cast(Any, AzureAISearchContextProvider)( source_id="s", endpoint="https://test.search.windows.net", index_name="idx", @@ -229,7 +230,7 @@ def test_neither_index_nor_kb_raises(self) -> None: def test_missing_model_raises(self) -> None: with pytest.raises(ValueError, match="model"): - AzureAISearchContextProvider( + cast(Any, AzureAISearchContextProvider)( source_id="s", endpoint="https://test.search.windows.net", index_name="idx", @@ -250,7 +251,7 @@ def test_vector_field_without_embedding_raises(self) -> None: def test_agentic_missing_aoai_url_with_index_raises(self) -> None: with pytest.raises(ValueError, match="azure_openai_resource_url"): - AzureAISearchContextProvider( + cast(Any, AzureAISearchContextProvider)( source_id="s", endpoint="https://test.search.windows.net", index_name="idx", @@ -359,7 +360,10 @@ async def test_results_added_to_context(self, mock_search_client: AsyncMock) -> session_id="s1", ) await provider.before_run( - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), + session=session, + context=ctx, + state=session.state.setdefault(provider.source_id, {}), ) # type: ignore[arg-type] mock_search_client.search.assert_awaited_once() @@ -374,7 +378,10 @@ async def test_empty_input_no_search(self, mock_search_client: AsyncMock) -> Non session = AgentSession(session_id="test-session") ctx = SessionContext(input_messages=[], session_id="s1") await provider.before_run( - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), + session=session, + context=ctx, + state=session.state.setdefault(provider.source_id, {}), ) # type: ignore[arg-type] mock_search_client.search.assert_not_awaited() @@ -390,7 +397,10 @@ async def test_no_results_no_messages(self, mock_search_client_empty: AsyncMock) session_id="s1", ) await provider.before_run( - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), + session=session, + context=ctx, + state=session.state.setdefault(provider.source_id, {}), ) # type: ignore[arg-type] mock_search_client_empty.search.assert_awaited_once() @@ -407,7 +417,10 @@ async def test_context_prompt_prepended(self, mock_search_client: AsyncMock) -> session_id="s1", ) await provider.before_run( - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), + session=session, + context=ctx, + state=session.state.setdefault(provider.source_id, {}), ) # type: ignore[arg-type] msgs = ctx.context_messages[provider.source_id] @@ -433,7 +446,10 @@ async def test_filters_non_user_assistant(self, mock_search_client: AsyncMock) - session_id="s1", ) await provider.before_run( - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), + session=session, + context=ctx, + state=session.state.setdefault(provider.source_id, {}), ) # type: ignore[arg-type] mock_search_client.search.assert_awaited_once() @@ -452,7 +468,10 @@ async def test_only_system_messages_no_search(self, mock_search_client: AsyncMoc session_id="s1", ) await provider.before_run( - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), + session=session, + context=ctx, + state=session.state.setdefault(provider.source_id, {}), ) # type: ignore[arg-type] mock_search_client.search.assert_not_awaited() @@ -467,7 +486,10 @@ async def test_whitespace_only_messages_filtered(self, mock_search_client: Async session_id="s1", ) await provider.before_run( - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), + session=session, + context=ctx, + state=session.state.setdefault(provider.source_id, {}), ) # type: ignore[arg-type] mock_search_client.search.assert_not_awaited() @@ -486,7 +508,10 @@ async def test_assistant_messages_included(self, mock_search_client: AsyncMock) session_id="s1", ) await provider.before_run( - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), + session=session, + context=ctx, + state=session.state.setdefault(provider.source_id, {}), ) # type: ignore[arg-type] call_kwargs = mock_search_client.search.call_args[1] @@ -1373,6 +1398,9 @@ def test_fallback_to_msg_text_when_no_contents(self) -> None: msg = Message(role="user", contents=["fallback text"]) result = AzureAISearchContextProvider._prepare_messages_for_kb_search([msg]) assert len(result) == 1 + from azure.search.documents.knowledgebases.models import KnowledgeBaseMessageTextContent + + assert isinstance(result[0].content[0], KnowledgeBaseMessageTextContent) assert result[0].content[0].text == "fallback text" def test_data_uri_image(self) -> None: @@ -1487,7 +1515,7 @@ def test_multiple_references(self) -> None: KnowledgeBaseSearchIndexReference(id="ref-a", activity_source=0), KnowledgeBaseWebReference(id="ref-b", activity_source=1, url="https://example.com"), ] - result = AzureAISearchContextProvider._parse_references_to_annotations(refs) + result = AzureAISearchContextProvider._parse_references_to_annotations(refs) # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] assert len(result) == 2 assert result[0]["additional_properties"]["activity_source"] == 0 assert result[1]["additional_properties"]["activity_source"] == 1 @@ -1680,7 +1708,10 @@ async def test_agentic_mode_calls_agentic_search(self) -> None: type(mock_content), ): await provider.before_run( - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), + session=session, + context=ctx, + state=session.state.setdefault(provider.source_id, {}), ) # type: ignore[arg-type] msgs = ctx.context_messages.get(provider.source_id, []) diff --git a/python/packages/azure-contentunderstanding/agent_framework_azure_contentunderstanding/_context_provider.py b/python/packages/azure-contentunderstanding/agent_framework_azure_contentunderstanding/_context_provider.py index 3271d2a3acb..11b1d9f034e 100644 --- a/python/packages/azure-contentunderstanding/agent_framework_azure_contentunderstanding/_context_provider.py +++ b/python/packages/azure-contentunderstanding/agent_framework_azure_contentunderstanding/_context_provider.py @@ -632,7 +632,7 @@ async def _resolve_pending_tokens( try: poller = await self._client.begin_analyze( # type: ignore[call-overload, reportUnknownVariableType] token_info["analyzer_id"], - continuation_token=token_info["continuation_token"], # pyright: ignore[reportCallIssue] + continuation_token=token_info["continuation_token"], ) # Use wait_for to avoid blocking before_run indefinitely. # poller.done() always returns False for resumed pollers (stale @@ -649,7 +649,7 @@ async def _resolve_pending_tokens( result: AnalysisResult = await asyncio.wait_for( poller.result(), # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType] timeout=resolution_timeout, - ) # pyright: ignore[reportUnknownVariableType] + ) except asyncio.TimeoutError: # Still running — update token and keep for next turn new_token: str = poller.continuation_token() # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] @@ -658,7 +658,7 @@ async def _resolve_pending_tokens( continue completed_keys.append(doc_key) - extracted = self._extract_sections(result) # pyright: ignore[reportUnknownArgumentType] + extracted = self._extract_sections(result) entry["status"] = DocumentStatus.READY entry["analyzed_at"] = datetime.now(tz=timezone.utc).isoformat() entry["result"] = extracted diff --git a/python/packages/azure-contentunderstanding/agent_framework_azure_contentunderstanding/_detection.py b/python/packages/azure-contentunderstanding/agent_framework_azure_contentunderstanding/_detection.py index 75ee88d7d32..c3454f6eea4 100644 --- a/python/packages/azure-contentunderstanding/agent_framework_azure_contentunderstanding/_detection.py +++ b/python/packages/azure-contentunderstanding/agent_framework_azure_contentunderstanding/_detection.py @@ -143,7 +143,7 @@ def sniff_media_type(binary_data: bytes | None, content: Content) -> str | None: if binary_data: kind = filetype.guess(binary_data[:262]) # type: ignore[reportUnknownMemberType] if kind: - mime: str = kind.mime # type: ignore[reportUnknownMemberType] + mime: str = kind.mime return MIME_ALIASES.get(mime, mime) # 2. Filename extension fallback — try additional_properties first, diff --git a/python/packages/azure-contentunderstanding/agent_framework_azure_contentunderstanding/_file_search.py b/python/packages/azure-contentunderstanding/agent_framework_azure_contentunderstanding/_file_search.py index a9526f6ebcd..51ec31a5497 100644 --- a/python/packages/azure-contentunderstanding/agent_framework_azure_contentunderstanding/_file_search.py +++ b/python/packages/azure-contentunderstanding/agent_framework_azure_contentunderstanding/_file_search.py @@ -66,7 +66,7 @@ async def upload_file(self, vector_store_id: str, filename: str, content: bytes) vector_store_id=vector_store_id, file_id=uploaded.id, ) - return uploaded.id # type: ignore[no-any-return] + return uploaded.id async def delete_file(self, file_id: str) -> None: await self._client.files.delete(file_id) diff --git a/python/packages/azure-contentunderstanding/tests/cu/test_context_provider.py b/python/packages/azure-contentunderstanding/tests/cu/test_context_provider.py index 0e0dae439fe..fd4a86831b2 100644 --- a/python/packages/azure-contentunderstanding/tests/cu/test_context_provider.py +++ b/python/packages/azure-contentunderstanding/tests/cu/test_context_provider.py @@ -5,7 +5,7 @@ import asyncio import base64 import json -from typing import Any +from typing import Any, cast from unittest.mock import AsyncMock, MagicMock from agent_framework import Content, Message, SessionContext @@ -508,7 +508,7 @@ async def test_returns_all_docs_with_status( class TestOutputFiltering: def test_default_markdown_and_fields(self, pdf_analysis_result: AnalysisResult) -> None: provider = _make_provider() - result = provider._extract_sections(pdf_analysis_result) + result = cast("dict[str, Any]", provider._extract_sections(pdf_analysis_result)) assert "markdown" in result assert "fields" in result @@ -516,14 +516,14 @@ def test_default_markdown_and_fields(self, pdf_analysis_result: AnalysisResult) def test_markdown_only(self, pdf_analysis_result: AnalysisResult) -> None: provider = _make_provider(output_sections=["markdown"]) - result = provider._extract_sections(pdf_analysis_result) + result = cast("dict[str, Any]", provider._extract_sections(pdf_analysis_result)) assert "markdown" in result assert "fields" not in result def test_fields_only(self, invoice_analysis_result: AnalysisResult) -> None: provider = _make_provider(output_sections=["fields"]) - result = provider._extract_sections(invoice_analysis_result) + result = cast("dict[str, Any]", provider._extract_sections(invoice_analysis_result)) assert "markdown" not in result assert "fields" in result @@ -533,9 +533,9 @@ def test_fields_only(self, invoice_analysis_result: AnalysisResult) -> None: def test_field_values_extracted(self, invoice_analysis_result: AnalysisResult) -> None: provider = _make_provider() - result = provider._extract_sections(invoice_analysis_result) + result = cast("dict[str, Any]", provider._extract_sections(invoice_analysis_result)) - fields = result.get("fields") + fields = cast("dict[str, Any]", result.get("fields")) assert isinstance(fields, dict) assert "VendorName" in fields assert fields["VendorName"]["value"] is not None @@ -549,8 +549,8 @@ def test_invoice_field_extraction_matches_expected(self, invoice_analysis_result a glance. Confidence is only present when the CU service provides it. """ provider = _make_provider() - result = provider._extract_sections(invoice_analysis_result) - fields = result.get("fields") + result = cast("dict[str, Any]", provider._extract_sections(invoice_analysis_result)) + fields = cast("dict[str, Any]", result.get("fields")) expected_fields = { "VendorName": { @@ -1029,19 +1029,19 @@ async def test_lazy_initialization_on_before_run(self) -> None: class TestMultiModalFixtures: def test_pdf_fixture_loads(self, pdf_analysis_result: AnalysisResult) -> None: provider = _make_provider() - result = provider._extract_sections(pdf_analysis_result) + result = cast("dict[str, Any]", provider._extract_sections(pdf_analysis_result)) assert "markdown" in result assert "Contoso" in str(result["markdown"]) def test_audio_fixture_loads(self, audio_analysis_result: AnalysisResult) -> None: provider = _make_provider() - result = provider._extract_sections(audio_analysis_result) + result = cast("dict[str, Any]", provider._extract_sections(audio_analysis_result)) assert "markdown" in result assert "Call Center" in str(result["markdown"]) def test_video_fixture_loads(self, video_analysis_result: AnalysisResult) -> None: provider = _make_provider() - result = provider._extract_sections(video_analysis_result) + result = cast("dict[str, Any]", provider._extract_sections(video_analysis_result)) assert "markdown" in result # All 3 segments should be concatenated at top level (for file_search) md = str(result["markdown"]) @@ -1073,12 +1073,12 @@ def test_video_fixture_loads(self, video_analysis_result: AnalysisResult) -> Non def test_image_fixture_loads(self, image_analysis_result: AnalysisResult) -> None: provider = _make_provider() - result = provider._extract_sections(image_analysis_result) + result = cast("dict[str, Any]", provider._extract_sections(image_analysis_result)) assert "markdown" in result def test_invoice_fixture_loads(self, invoice_analysis_result: AnalysisResult) -> None: provider = _make_provider() - result = provider._extract_sections(invoice_analysis_result) + result = cast("dict[str, Any]", provider._extract_sections(invoice_analysis_result)) assert "markdown" in result assert "fields" in result fields = result["fields"] @@ -1625,7 +1625,7 @@ async def test_close_cleans_up(self) -> None: await provider.close() # Client should be closed (no tasks to cancel — tokens are just strings) - provider._client.close.assert_called_once() + cast(Any, provider._client.close).assert_called_once() class TestSessionIsolation: @@ -1895,7 +1895,7 @@ def test_warnings_included_when_present(self) -> None: ], } result_obj = AnalysisResult(fixture) - extracted = provider._extract_sections(result_obj) + extracted = cast("dict[str, Any]", provider._extract_sections(result_obj)) assert "warnings" in extracted warnings = extracted["warnings"] assert isinstance(warnings, list) @@ -1912,7 +1912,7 @@ def test_warnings_included_when_present(self) -> None: def test_warnings_omitted_when_empty(self, pdf_analysis_result: AnalysisResult) -> None: """Empty/None warnings should not appear in extracted result.""" provider = _make_provider() - extracted = provider._extract_sections(pdf_analysis_result) + extracted = cast("dict[str, Any]", provider._extract_sections(pdf_analysis_result)) assert "warnings" not in extracted @@ -1933,7 +1933,7 @@ def test_category_included_single_segment(self) -> None: ], } result_obj = AnalysisResult(fixture) - extracted = provider._extract_sections(result_obj) + extracted = cast("dict[str, Any]", provider._extract_sections(result_obj)) assert extracted.get("category") == "Legal Contract" def test_category_in_multi_segment_video(self) -> None: @@ -1972,7 +1972,7 @@ def test_category_in_multi_segment_video(self) -> None: ], } result_obj = AnalysisResult(fixture) - extracted = provider._extract_sections(result_obj) + extracted = cast("dict[str, Any]", provider._extract_sections(result_obj)) # Top-level metadata assert extracted["kind"] == "audioVisual" @@ -2003,7 +2003,7 @@ def test_category_in_multi_segment_video(self) -> None: def test_category_omitted_when_none(self, pdf_analysis_result: AnalysisResult) -> None: """No category should be in output when analyzer doesn't classify.""" provider = _make_provider() - extracted = provider._extract_sections(pdf_analysis_result) + extracted = cast("dict[str, Any]", provider._extract_sections(pdf_analysis_result)) assert "category" not in extracted diff --git a/python/packages/azure-contentunderstanding/tests/cu/test_integration.py b/python/packages/azure-contentunderstanding/tests/cu/test_integration.py index 0e204e25077..b16bfa5c101 100644 --- a/python/packages/azure-contentunderstanding/tests/cu/test_integration.py +++ b/python/packages/azure-contentunderstanding/tests/cu/test_integration.py @@ -14,6 +14,7 @@ import json import os from pathlib import Path +from typing import Any, cast import pytest @@ -43,11 +44,12 @@ async def test_analyze_pdf_binary() -> None: assert pdf_path.exists(), f"Test fixture not found: {pdf_path}" pdf_bytes = pdf_path.read_bytes() - async with DefaultAzureCredential() as credential, ContentUnderstandingClient(endpoint, credential) as client: + async with DefaultAzureCredential() as credential, ContentUnderstandingClient(endpoint, credential) as client: # pyrefly: ignore[bad-argument-type] poller = await client.begin_analyze_binary( analyzer_id, binary_input=pdf_bytes, content_type="application/pdf", + string_encoding="utf-8", ) result = await poller.result() @@ -83,7 +85,7 @@ async def test_before_run_e2e() -> None: async with DefaultAzureCredential() as credential: cu = ContentUnderstandingContextProvider( endpoint=endpoint, - credential=credential, + credential=credential, # pyrefly: ignore[bad-argument-type] max_wait=None, # wait until analysis completes (no background deferral) ) async with cu: @@ -106,7 +108,7 @@ async def test_before_run_e2e() -> None: await cu.before_run(agent=MagicMock(), session=session, context=context, state=state) - docs = state.get("documents", {}) + docs = cast("dict[str, Any]", state.get("documents", {})) assert isinstance(docs, dict) assert "invoice.pdf" in docs doc_entry = docs["invoice.pdf"] @@ -143,7 +145,7 @@ async def test_before_run_uri_content() -> None: async with DefaultAzureCredential() as credential: cu = ContentUnderstandingContextProvider( endpoint=endpoint, - credential=credential, + credential=credential, # pyrefly: ignore[bad-argument-type] max_wait=None, # wait until analysis completes (no background deferral) ) async with cu: @@ -166,7 +168,7 @@ async def test_before_run_uri_content() -> None: await cu.before_run(agent=MagicMock(), session=session, context=context, state=state) - docs = state.get("documents", {}) + docs = cast("dict[str, Any]", state.get("documents", {})) assert isinstance(docs, dict) assert "invoice.pdf" in docs @@ -206,7 +208,7 @@ async def test_before_run_data_uri_content() -> None: async with DefaultAzureCredential() as credential: cu = ContentUnderstandingContextProvider( endpoint=endpoint, - credential=credential, + credential=credential, # pyrefly: ignore[bad-argument-type] max_wait=None, # wait until analysis completes ) async with cu: @@ -229,7 +231,7 @@ async def test_before_run_data_uri_content() -> None: await cu.before_run(agent=MagicMock(), session=session, context=context, state=state) - docs = state.get("documents", {}) + docs = cast("dict[str, Any]", state.get("documents", {})) assert isinstance(docs, dict) assert "invoice_b64.pdf" in docs @@ -264,7 +266,7 @@ async def test_before_run_background_analysis() -> None: async with DefaultAzureCredential() as credential: cu = ContentUnderstandingContextProvider( endpoint=endpoint, - credential=credential, + credential=credential, # pyrefly: ignore[bad-argument-type] max_wait=0.5, # short timeout to force background deferral ) async with cu: @@ -288,7 +290,7 @@ async def test_before_run_background_analysis() -> None: await cu.before_run(agent=MagicMock(), session=session, context=context, state=state) - docs = state.get("documents", {}) + docs = cast("dict[str, Any]", state.get("documents", {})) assert isinstance(docs, dict) assert "invoice.pdf" in docs assert docs["invoice.pdf"]["status"] == "analyzing", ( diff --git a/python/packages/azure-cosmos/tests/test_cosmos_history_provider.py b/python/packages/azure-cosmos/tests/test_cosmos_history_provider.py index e3ac636aa65..bcfa7430841 100644 --- a/python/packages/azure-cosmos/tests/test_cosmos_history_provider.py +++ b/python/packages/azure-cosmos/tests/test_cosmos_history_provider.py @@ -6,7 +6,7 @@ import uuid from collections.abc import AsyncIterator from contextlib import suppress -from typing import Any +from typing import Any, cast from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -282,8 +282,11 @@ async def test_before_run_loads_history(self, mock_container: MagicMock) -> None context = SessionContext(input_messages=[Message(role="user", contents=["new msg"])], session_id="s1") await provider.before_run( - agent=None, session=session, context=context, state=session.state.setdefault(provider.source_id, {}) - ) # type: ignore[arg-type] + agent=cast(Any, None), + session=session, + context=context, + state=session.state.setdefault(provider.source_id, {}), + ) assert "mem" in context.context_messages assert context.context_messages["mem"][0].text == "old msg" @@ -295,8 +298,11 @@ async def test_after_run_stores_input_and_response(self, mock_container: MagicMo context._response = AgentResponse(messages=[Message(role="assistant", contents=["hello"])]) await provider.after_run( - agent=None, session=session, context=context, state=session.state.setdefault(provider.source_id, {}) - ) # type: ignore[arg-type] + agent=cast(Any, None), + session=session, + context=context, + state=session.state.setdefault(provider.source_id, {}), + ) mock_container.execute_item_batch.assert_awaited_once() batch_operations = mock_container.execute_item_batch.await_args.kwargs["batch_operations"] diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py index be6c2d015ea..5e2ee5938c2 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py @@ -29,7 +29,7 @@ if TYPE_CHECKING: from azure.durable_functions import DurableOrchestrationContext - class _TypedCompoundTask(CompoundTask): # type: ignore[misc] + class _TypedCompoundTask(CompoundTask): _first_error: Any def __init__( @@ -44,7 +44,7 @@ def __init__( _TypedCompoundTask = CompoundTask -class PreCompletedTask(TaskBase): # type: ignore[misc] +class PreCompletedTask(TaskBase): """A simple task that is already completed with a result. Used for fire-and-forget mode where we want to return immediately diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py index 27730fb441e..c6d62f8892d 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py @@ -100,7 +100,7 @@ def strip_pickle_markers(data: Any) -> Any: return {k: strip_pickle_markers(v) for k, v in typed_dict.items()} if isinstance(data, list): - typed_list = cast(list[Any], data) # type: ignore[redundant-cast] + typed_list = cast(list[Any], data) return [strip_pickle_markers(item) for item in typed_list] return data diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py index 6fbdaf44f7e..31077f55eb3 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py @@ -357,7 +357,7 @@ def _process_agent_response( if isinstance(dumped, dict): structured_response = dumped # type: ignore[assignment] elif isinstance(agent_response.value, dict): - structured_response = agent_response.value # type: ignore[assignment] + structured_response = agent_response.value output_message = build_agent_executor_response( executor_id=executor_id, @@ -916,7 +916,7 @@ async def execute_hitl_response_handler( response = _deserialize_hitl_response(response_data, response_type_str) # Find the matching response handler - handler = executor._find_response_handler(original_request, response) # pyright: ignore[reportPrivateUsage] + handler = executor._find_response_handler(original_request, response) if handler is None: logger.warning( diff --git a/python/packages/azurefunctions/tests/integration_tests/test_09_workflow_shared_state.py b/python/packages/azurefunctions/tests/integration_tests/test_09_workflow_shared_state.py index 66527134a5f..ddf6bc0f469 100644 --- a/python/packages/azurefunctions/tests/integration_tests/test_09_workflow_shared_state.py +++ b/python/packages/azurefunctions/tests/integration_tests/test_09_workflow_shared_state.py @@ -45,7 +45,7 @@ def test_workflow_with_spam_email(self) -> None: spam_content = "URGENT! You have won $1,000,000! Click here to claim your prize now before it expires!" # Start orchestration with spam email - response = self.helper.post_json(f"{self.base_url}/api/workflow/run", spam_content) + response = self.helper.post_text(f"{self.base_url}/api/workflow/run", spam_content) assert response.status_code == 202 data = response.json() assert "instanceId" in data @@ -64,7 +64,7 @@ def test_workflow_with_legitimate_email(self) -> None: ) # Start orchestration with legitimate email - response = self.helper.post_json(f"{self.base_url}/api/workflow/run", legitimate_content) + response = self.helper.post_text(f"{self.base_url}/api/workflow/run", legitimate_content) assert response.status_code == 202 data = response.json() assert "instanceId" in data @@ -83,7 +83,7 @@ def test_workflow_with_phishing_email(self) -> None: ) # Start orchestration with phishing email - response = self.helper.post_json(f"{self.base_url}/api/workflow/run", phishing_content) + response = self.helper.post_text(f"{self.base_url}/api/workflow/run", phishing_content) assert response.status_code == 202 data = response.json() assert "instanceId" in data diff --git a/python/packages/azurefunctions/tests/test_app.py b/python/packages/azurefunctions/tests/test_app.py index 61518fa44b7..ff0975947b6 100644 --- a/python/packages/azurefunctions/tests/test_app.py +++ b/python/packages/azurefunctions/tests/test_app.py @@ -1182,11 +1182,11 @@ def test_init_with_invalid_max_poll_retries(self) -> None: mock_agent.name = "TestAgent" # Test with invalid type - app = AgentFunctionApp(agents=[mock_agent], max_poll_retries="invalid") + app = AgentFunctionApp(agents=[mock_agent], max_poll_retries="invalid") # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] assert app.max_poll_retries >= 1 # Should use default # Test with None - app2 = AgentFunctionApp(agents=[mock_agent], max_poll_retries=None) + app2 = AgentFunctionApp(agents=[mock_agent], max_poll_retries=None) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] assert app2.max_poll_retries >= 1 # Should use default def test_init_with_invalid_poll_interval_seconds(self) -> None: @@ -1195,11 +1195,11 @@ def test_init_with_invalid_poll_interval_seconds(self) -> None: mock_agent.name = "TestAgent" # Test with invalid type - app = AgentFunctionApp(agents=[mock_agent], poll_interval_seconds="invalid") + app = AgentFunctionApp(agents=[mock_agent], poll_interval_seconds="invalid") # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] assert app.poll_interval_seconds > 0 # Should use default # Test with None - app2 = AgentFunctionApp(agents=[mock_agent], poll_interval_seconds=None) + app2 = AgentFunctionApp(agents=[mock_agent], poll_interval_seconds=None) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] assert app2.poll_interval_seconds > 0 # Should use default def test_get_agent_raises_for_unregistered_agent(self) -> None: diff --git a/python/packages/azurefunctions/tests/test_orchestration.py b/python/packages/azurefunctions/tests/test_orchestration.py index a4ff77d89fb..385d66c3873 100644 --- a/python/packages/azurefunctions/tests/test_orchestration.py +++ b/python/packages/azurefunctions/tests/test_orchestration.py @@ -29,7 +29,7 @@ class _FakeTask(TaskBase): def __init__(self, task_id: int = 1): super().__init__(task_id, []) self._set_is_scheduled(False) - self.action_repr = [] + self.action_repr: list[Any] = [] # pyrefly: ignore[bad-override-mutable-attribute] self.state = TaskState.RUNNING @@ -90,7 +90,7 @@ def executor_with_uuid() -> tuple[Any, Mock, str]: executor = AzureFunctionsAgentExecutor(context) test_uuid_hex = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" - executor.generate_unique_id = Mock(return_value=test_uuid_hex) + executor.generate_unique_id = Mock(return_value=test_uuid_hex) # type: ignore[method-assign] # ty: ignore[invalid-assignment] return executor, context, test_uuid_hex @@ -112,7 +112,7 @@ def executor_with_multiple_uuids() -> tuple[Any, Mock, list[str]]: "dddddddd-dddd-dddd-dddd-dddddddddddd", "eeeeeeee-eeee-eeee-eeee-eeeeeeeeeeee", ] - executor.generate_unique_id = Mock(side_effect=uuid_hexes) + executor.generate_unique_id = Mock(side_effect=uuid_hexes) # type: ignore[method-assign] # ty: ignore[invalid-assignment] return executor, context, uuid_hexes diff --git a/python/packages/bedrock/agent_framework_bedrock/__init__.py b/python/packages/bedrock/agent_framework_bedrock/__init__.py index b2dc5115599..3fbf5c15cf5 100644 --- a/python/packages/bedrock/agent_framework_bedrock/__init__.py +++ b/python/packages/bedrock/agent_framework_bedrock/__init__.py @@ -2,8 +2,8 @@ import importlib.metadata -from ._chat_client import BedrockChatClient, BedrockChatOptions, BedrockGuardrailConfig, BedrockSettings # type: ignore -from ._embedding_client import BedrockEmbeddingClient, BedrockEmbeddingOptions, BedrockEmbeddingSettings # type: ignore +from ._chat_client import BedrockChatClient, BedrockChatOptions, BedrockGuardrailConfig, BedrockSettings +from ._embedding_client import BedrockEmbeddingClient, BedrockEmbeddingOptions, BedrockEmbeddingSettings try: __version__ = importlib.metadata.version(__name__) diff --git a/python/packages/bedrock/agent_framework_bedrock/_chat_client.py b/python/packages/bedrock/agent_framework_bedrock/_chat_client.py index 2fd78877217..cf8b3c562ae 100644 --- a/python/packages/bedrock/agent_framework_bedrock/_chat_client.py +++ b/python/packages/bedrock/agent_framework_bedrock/_chat_client.py @@ -41,17 +41,17 @@ from pydantic import BaseModel if sys.version_info >= (3, 13): - from typing import TypeVar # type: ignore # pragma: no cover + from typing import TypeVar # pragma: no cover else: - from typing_extensions import TypeVar # type: ignore # pragma: no cover + from typing_extensions import TypeVar # pragma: no cover if sys.version_info >= (3, 12): - from typing import override # type: ignore # pragma: no cover + from typing import override # pragma: no cover else: - from typing_extensions import override # type: ignore # pragma: no cover + from typing_extensions import override # pragma: no cover if sys.version_info >= (3, 11): - from typing import TypedDict # type: ignore # pragma: no cover + from typing import TypedDict # pragma: no cover else: - from typing_extensions import TypedDict # type: ignore # pragma: no cover + from typing_extensions import TypedDict # pragma: no cover logger = logging.getLogger("agent_framework.bedrock") @@ -230,7 +230,7 @@ class BedrockChatClient( ): """Async chat client for Amazon Bedrock's Converse API with middleware, telemetry, and function invocation.""" - OTEL_PROVIDER_NAME: ClassVar[str] = "aws.bedrock" # type: ignore[reportIncompatibleVariableOverride, misc] + OTEL_PROVIDER_NAME: ClassVar[str] = "aws.bedrock" def __init__( self, @@ -371,7 +371,7 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]: parsed_response = self._process_converse_response(response, options) contents = list(parsed_response.messages[0].contents if parsed_response.messages else []) if parsed_response.usage_details: - contents.append(Content.from_usage(usage_details=parsed_response.usage_details)) # type: ignore[arg-type] + contents.append(Content.from_usage(usage_details=parsed_response.usage_details)) raw_finish_reason = ( parsed_response.finish_reason if isinstance(parsed_response.finish_reason, str) else None ) @@ -745,7 +745,7 @@ def _parse_message_contents(self, content_blocks: Sequence[dict[str, Any]]) -> l Content.from_function_result( call_id=tool_use_id if isinstance(tool_use_id, str) else self._generate_tool_call_id(), result=result_value, - exception=str(exception) if exception else None, # type: ignore[arg-type] + exception=str(exception) if exception else None, raw_representation=block, ) ) diff --git a/python/packages/bedrock/agent_framework_bedrock/_embedding_client.py b/python/packages/bedrock/agent_framework_bedrock/_embedding_client.py index 52f5126e3da..4b666dbc4f2 100644 --- a/python/packages/bedrock/agent_framework_bedrock/_embedding_client.py +++ b/python/packages/bedrock/agent_framework_bedrock/_embedding_client.py @@ -26,9 +26,9 @@ from botocore.config import Config as BotoConfig if sys.version_info >= (3, 13): - from typing import TypeVar # type: ignore # pragma: no cover + from typing import TypeVar # pragma: no cover else: - from typing_extensions import TypeVar # type: ignore # pragma: no cover + from typing_extensions import TypeVar # pragma: no cover logger = logging.getLogger("agent_framework.bedrock") @@ -143,7 +143,7 @@ def __init__( config=BotoConfig(user_agent_extra=get_user_agent()), ) - self.model: str = settings["embedding_model"] # type: ignore[assignment] # pyright: ignore[reportTypedDictNotRequiredAccess] + self.model: str = settings["embedding_model"] # type: ignore[assignment] self.region = resolved_region super().__init__(additional_properties=additional_properties) @@ -261,7 +261,7 @@ class BedrockEmbeddingClient( print(result[0].vector) """ - OTEL_PROVIDER_NAME: ClassVar[str] = "aws.bedrock" # type: ignore[reportIncompatibleVariableOverride, misc] + OTEL_PROVIDER_NAME: ClassVar[str] = "aws.bedrock" def __init__( self, diff --git a/python/packages/bedrock/tests/bedrock/test_bedrock_embedding_client.py b/python/packages/bedrock/tests/bedrock/test_bedrock_embedding_client.py index afb32c41584..b16ec44790a 100644 --- a/python/packages/bedrock/tests/bedrock/test_bedrock_embedding_client.py +++ b/python/packages/bedrock/tests/bedrock/test_bedrock_embedding_client.py @@ -41,7 +41,7 @@ async def test_bedrock_embedding_construction() -> None: client = BedrockEmbeddingClient( model="amazon.titan-embed-text-v2:0", region="us-west-2", - client=stub, + client=stub, # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] # pyright: ignore[reportArgumentType] ) assert client.model == "amazon.titan-embed-text-v2:0" assert client.region == "us-west-2" @@ -62,7 +62,7 @@ async def test_bedrock_embedding_get_embeddings() -> None: client = BedrockEmbeddingClient( model="amazon.titan-embed-text-v2:0", region="us-west-2", - client=stub, + client=stub, # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] # pyright: ignore[reportArgumentType] ) result = await client.get_embeddings(["hello", "world"]) @@ -86,7 +86,7 @@ async def test_bedrock_embedding_get_embeddings_empty_input() -> None: client = BedrockEmbeddingClient( model="amazon.titan-embed-text-v2:0", region="us-west-2", - client=stub, + client=stub, # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] # pyright: ignore[reportArgumentType] ) result = await client.get_embeddings([]) @@ -102,14 +102,14 @@ async def test_bedrock_embedding_get_embeddings_with_options() -> None: client = BedrockEmbeddingClient( model="amazon.titan-embed-text-v2:0", region="us-west-2", - client=stub, + client=stub, # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] # pyright: ignore[reportArgumentType] ) options: BedrockEmbeddingOptions = { "dimensions": 5, "normalize": True, } - result = await client.get_embeddings(["hello"], options=options) + result = await client.get_embeddings(["hello"], options=options) # ty: ignore[invalid-argument-type] assert len(result) == 1 assert len(result[0].vector) == 5 @@ -125,9 +125,9 @@ async def test_bedrock_embedding_get_embeddings_no_model_raises() -> None: client = BedrockEmbeddingClient( model="amazon.titan-embed-text-v2:0", region="us-west-2", - client=stub, + client=stub, # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] # pyright: ignore[reportArgumentType] ) - client.model = None # type: ignore[assignment] + client.model = None # type: ignore[assignment] # ty: ignore[invalid-assignment] with pytest.raises(ValueError, match="model is required"): await client.get_embeddings(["hello"]) @@ -138,7 +138,7 @@ async def test_bedrock_embedding_default_region() -> None: stub = _StubBedrockEmbeddingRuntime() client = BedrockEmbeddingClient( model="amazon.titan-embed-text-v2:0", - client=stub, + client=stub, # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] # pyright: ignore[reportArgumentType] ) assert client.region == "us-east-1" diff --git a/python/packages/bedrock/tests/test_bedrock_client.py b/python/packages/bedrock/tests/test_bedrock_client.py index d226943256c..9e1b42ea251 100644 --- a/python/packages/bedrock/tests/test_bedrock_client.py +++ b/python/packages/bedrock/tests/test_bedrock_client.py @@ -36,7 +36,7 @@ def _make_client() -> BedrockChatClient: return BedrockChatClient( model="amazon.titan-text", region="us-west-2", - client=_StubBedrockRuntime(), + client=_StubBedrockRuntime(), # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] # pyright: ignore[reportArgumentType] ) @@ -45,7 +45,7 @@ async def test_get_response_invokes_bedrock_runtime() -> None: client = BedrockChatClient( model="amazon.titan-text", region="us-west-2", - client=stub, + client=stub, # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] # pyright: ignore[reportArgumentType] ) messages = [ @@ -67,7 +67,7 @@ def test_build_request_requires_non_system_messages() -> None: client = BedrockChatClient( model="amazon.titan-text", region="us-west-2", - client=_StubBedrockRuntime(), + client=_StubBedrockRuntime(), # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] # pyright: ignore[reportArgumentType] ) messages = [Message(role="system", contents=[Content.from_text(text="Only system text")])] diff --git a/python/packages/bedrock/tests/test_bedrock_structured_output.py b/python/packages/bedrock/tests/test_bedrock_structured_output.py index 7b39f67d69e..1c4f2aeca65 100644 --- a/python/packages/bedrock/tests/test_bedrock_structured_output.py +++ b/python/packages/bedrock/tests/test_bedrock_structured_output.py @@ -5,7 +5,7 @@ import copy import json from typing import Any -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest from agent_framework import Content, Message @@ -70,7 +70,7 @@ def _make_client(response_text: str = "Bedrock says hi") -> tuple[BedrockChatCli client = BedrockChatClient( model="us.anthropic.claude-haiku-4-5-v1:0", region="us-east-1", - client=stub, + client=stub, # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] # pyright: ignore[reportArgumentType] ) return client, stub @@ -248,7 +248,7 @@ def converse(self, **kwargs: Any) -> dict[str, Any]: client = BedrockChatClient( model="us.anthropic.claude-v2", region="us-east-1", - client=_FailingStubBedrockRuntime(), + client=_FailingStubBedrockRuntime(), # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] # pyright: ignore[reportArgumentType] ) with pytest.raises(ValueError) as exc: @@ -366,14 +366,9 @@ async def test_non_outputconfig_validation_exception_propagates() -> None: "Message": "Invalid message format", } } - with ( - patch.object( - client, - "_bedrock_client", - **{"converse.side_effect": ClientError(error_response, "Converse")}, - ), - pytest.raises(ClientError), - ): + failing_client = MagicMock() + failing_client.converse.side_effect = ClientError(error_response, "Converse") + with patch.object(client, "_bedrock_client", failing_client), pytest.raises(ClientError): await client.get_response( messages=_user_messages(), options={"max_tokens": 100}, diff --git a/python/packages/chatkit/tests/test_converter.py b/python/packages/chatkit/tests/test_converter.py index 907a1ad0a9b..a630062d37e 100644 --- a/python/packages/chatkit/tests/test_converter.py +++ b/python/packages/chatkit/tests/test_converter.py @@ -6,7 +6,8 @@ import pytest from agent_framework import Message -from chatkit.types import UserMessageTextContent +from chatkit.types import InferenceOptions, UserMessageTextContent +from pydantic import AnyUrl from agent_framework_chatkit import ThreadItemConverter, simple_to_agent_input @@ -37,7 +38,7 @@ async def test_to_agent_input_with_text(self, converter): type="user_message", content=[UserMessageTextContent(text="Hello, how can you help me?")], attachments=[], - inference_options={}, + inference_options=InferenceOptions(), ) result = await converter.to_agent_input(input_item) @@ -60,7 +61,7 @@ async def test_to_agent_input_empty_text(self, converter): type="user_message", content=[UserMessageTextContent(text=" ")], attachments=[], - inference_options={}, + inference_options=InferenceOptions(), ) result = await converter.to_agent_input(input_item) @@ -79,7 +80,7 @@ async def test_to_agent_input_no_content(self, converter): type="user_message", content=[], attachments=[], - inference_options={}, + inference_options=InferenceOptions(), ) result = await converter.to_agent_input(input_item) @@ -101,7 +102,7 @@ async def test_to_agent_input_multiple_content_parts(self, converter): UserMessageTextContent(text="world!"), ], attachments=[], - inference_options={}, + inference_options=InferenceOptions(), ) result = await converter.to_agent_input(input_item) @@ -176,7 +177,7 @@ async def test_attachment_to_message_content_image_with_preview_url(self, conver name="photo.jpg", mime_type="image/jpeg", type="image", - preview_url="https://example.com/photo.jpg", + preview_url=AnyUrl("https://example.com/photo.jpg"), ) result = await converter.attachment_to_message_content(attachment) @@ -202,6 +203,7 @@ async def fetch_data(attachment_id: str) -> bytes: ) result = await converter.attachment_to_message_content(attachment) + assert result is not None assert result.type == "data" assert result.media_type == "application/pdf" @@ -216,7 +218,7 @@ async def test_to_agent_input_with_image_attachment(self): name="photo.jpg", mime_type="image/jpeg", type="image", - preview_url="https://example.com/photo.jpg", + preview_url=AnyUrl("https://example.com/photo.jpg"), ) input_item = UserMessageItem( @@ -226,7 +228,7 @@ async def test_to_agent_input_with_image_attachment(self): type="user_message", content=[UserMessageTextContent(text="Check out this photo!")], attachments=[attachment], - inference_options={}, + inference_options=InferenceOptions(), ) converter = ThreadItemConverter() @@ -266,7 +268,7 @@ async def test_to_agent_input_with_file_attachment_and_fetcher(self): type="user_message", content=[UserMessageTextContent(text="Here's the document")], attachments=[attachment], - inference_options={}, + inference_options=InferenceOptions(), ) # Create converter with data fetcher @@ -373,14 +375,14 @@ def test_widget_to_input(self, converter): from datetime import datetime from chatkit.types import WidgetItem - from chatkit.widgets import Card, Text + from chatkit.widgets import Card, Text # ty: ignore[deprecated] widget_item = WidgetItem( id="widget_1", thread_id="thread_1", created_at=datetime.now(), type="widget", - widget=Card(key="card1", children=[Text(value="Hello")]), + widget=Card(key="card1", children=[Text(value="Hello")]), # ty: ignore[deprecated] ) result = converter.widget_to_input(widget_item) @@ -411,7 +413,7 @@ async def test_simple_to_agent_input_with_text(self): type="user_message", content=[UserMessageTextContent(text="Test message")], attachments=[], - inference_options={}, + inference_options=InferenceOptions(), ) result = await simple_to_agent_input(input_item) diff --git a/python/packages/chatkit/tests/test_streaming.py b/python/packages/chatkit/tests/test_streaming.py index c26a9cb7acf..2cb81bd1085 100644 --- a/python/packages/chatkit/tests/test_streaming.py +++ b/python/packages/chatkit/tests/test_streaming.py @@ -2,10 +2,14 @@ """Tests for Agent Framework to ChatKit streaming utilities.""" +from collections.abc import AsyncIterator from unittest.mock import Mock from agent_framework import AgentResponseUpdate, Content from chatkit.types import ( + AssistantMessageContent, + AssistantMessageContentPartTextDelta, + AssistantMessageItem, ThreadItemAddedEvent, ThreadItemDoneEvent, ThreadItemUpdated, @@ -20,7 +24,7 @@ class TestStreamAgentResponse: async def test_stream_empty_response(self): """Test streaming empty response.""" - async def empty_stream(): + async def empty_stream() -> AsyncIterator[AgentResponseUpdate]: return yield # Make it a generator @@ -49,10 +53,13 @@ async def single_update_stream(): assert isinstance(events[2], ThreadItemDoneEvent) # Check delta event + assert isinstance(events[1].update, AssistantMessageContentPartTextDelta) assert events[1].update.delta == "Hello world" # Check final message content + assert isinstance(events[2].item, AssistantMessageItem) assert len(events[2].item.content) == 1 + assert isinstance(events[2].item.content[0], AssistantMessageContent) assert events[2].item.content[0].text == "Hello world" async def test_stream_multiple_text_updates(self): @@ -76,12 +83,16 @@ async def multiple_updates_stream(): assert isinstance(events[3], ThreadItemDoneEvent) # Check delta events + assert isinstance(events[1].update, AssistantMessageContentPartTextDelta) + assert isinstance(events[2].update, AssistantMessageContentPartTextDelta) assert events[1].update.delta == "Hello " assert events[2].update.delta == "world!" # Check final accumulated text final_message_event = events[-1] assert isinstance(final_message_event, ThreadItemDoneEvent) + assert isinstance(final_message_event.item, AssistantMessageItem) + assert isinstance(final_message_event.item.content[0], AssistantMessageContent) assert final_message_event.item.content[0].text == "Hello world!" async def test_stream_with_custom_id_generator(self): @@ -101,6 +112,7 @@ async def single_update_stream(): # Check that custom IDs are used message_added_event = events[0] + assert isinstance(message_added_event, ThreadItemAddedEvent) assert message_added_event.item.id == "custom_msg_123" async def test_stream_empty_content_updates(self): @@ -120,6 +132,7 @@ async def empty_content_stream(): assert isinstance(events[1], ThreadItemDoneEvent) # Final message should have empty content + assert isinstance(events[1].item, AssistantMessageItem) assert len(events[1].item.content) == 0 async def test_stream_non_text_content(self): diff --git a/python/packages/claude/agent_framework_claude/_agent.py b/python/packages/claude/agent_framework_claude/_agent.py index c2ac097ae53..5b88d0cf5e1 100644 --- a/python/packages/claude/agent_framework_claude/_agent.py +++ b/python/packages/claude/agent_framework_claude/_agent.py @@ -42,9 +42,9 @@ from claude_agent_sdk.types import StreamEvent, TextBlock if sys.version_info >= (3, 13): - from typing import TypeVar # type: ignore # pragma: no cover + from typing import TypeVar # pragma: no cover else: - from typing_extensions import TypeVar # type: ignore # pragma: no cover + from typing_extensions import TypeVar # pragma: no cover if sys.version_info >= (3, 11): from typing import TypedDict # pragma: no cover else: @@ -392,7 +392,7 @@ def _normalize_tools( non_builtin_tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] = [] if not isinstance(tools, list): - tools = [tools] # type: ignore[assignment, reportUnknownVariableType] + tools = [tools] for tool in tools: # type: ignore[reportUnknownVariableType] if isinstance(tool, str): self._builtin_tools.append(tool) @@ -400,7 +400,7 @@ def _normalize_tools( non_builtin_tools.append(tool) # type: ignore[union-attr, reportUnknownArgumentType] if not non_builtin_tools: return - self._custom_tools.extend(normalize_tools(non_builtin_tools)) # type: ignore[reportUnknownVariableType] + self._custom_tools.extend(normalize_tools(non_builtin_tools)) async def __aenter__(self) -> RawClaudeAgent[OptionsT]: """Start the agent when entering async context.""" @@ -689,7 +689,7 @@ def _finalize_response(self, updates: Sequence[AgentResponseUpdate]) -> AgentRes return AgentResponse.from_updates(updates, value=structured_output) @overload - def run( # type: ignore[override] + def run( self, messages: AgentRunInputs | None = None, *, @@ -700,7 +700,7 @@ def run( # type: ignore[override] ) -> Awaitable[AgentResponse[Any]]: ... @overload - def run( # type: ignore[override] + def run( self, messages: AgentRunInputs | None = None, *, @@ -717,7 +717,7 @@ def run( stream: bool = False, session: AgentSession | None = None, options: OptionsT | None = None, - **kwargs: Any, # type: ignore + **kwargs: Any, ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: """Run the agent with the given messages. @@ -853,7 +853,7 @@ class ClaudeAgent(AgentTelemetryLayer, RawClaudeAgent[OptionsT], Generic[Options print(response.text) """ - @overload # type: ignore[override] + @overload def run( self, messages: AgentRunInputs | None = None, @@ -870,7 +870,7 @@ def run( **kwargs: Any, ) -> Awaitable[AgentResponse[Any]]: ... - @overload # type: ignore[override] + @overload def run( self, messages: AgentRunInputs | None = None, @@ -887,7 +887,7 @@ def run( **kwargs: Any, ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... - def run( # pyright: ignore[reportIncompatibleMethodOverride] # type: ignore[override] + def run( # pyright: ignore[reportIncompatibleMethodOverride] self, messages: AgentRunInputs | None = None, *, diff --git a/python/packages/claude/tests/test_claude_agent.py b/python/packages/claude/tests/test_claude_agent.py index 87816a139c3..11f1f70bbba 100644 --- a/python/packages/claude/tests/test_claude_agent.py +++ b/python/packages/claude/tests/test_claude_agent.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. -from typing import Any +from typing import Any, cast from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -567,10 +567,11 @@ def create_person(person: Person) -> str: # Verify $defs is preserved in the schema assert sdk_tool.input_schema is not None - assert "$defs" in sdk_tool.input_schema # type: ignore[operator] - assert "Address" in sdk_tool.input_schema["$defs"] # type: ignore[index] + input_schema = cast(dict[str, Any], sdk_tool.input_schema) + assert "$defs" in input_schema + assert "Address" in input_schema["$defs"] # Verify the nested reference exists in properties - assert "person" in sdk_tool.input_schema["properties"] # type: ignore[index] + assert "person" in input_schema["properties"] async def test_tool_handler_success(self) -> None: """Test tool handler executes successfully.""" diff --git a/python/packages/copilotstudio/agent_framework_copilotstudio/_acquire_token.py b/python/packages/copilotstudio/agent_framework_copilotstudio/_acquire_token.py index ef7cd728c6c..ef11a8b046d 100644 --- a/python/packages/copilotstudio/agent_framework_copilotstudio/_acquire_token.py +++ b/python/packages/copilotstudio/agent_framework_copilotstudio/_acquire_token.py @@ -62,7 +62,7 @@ def acquire_token( logger.debug("Attempting silent token acquisition") response = pca.acquire_token_silent(scopes=target_scopes, account=accounts[0]) if response and "access_token" in response: - token = str(response["access_token"]) # type: ignore[assignment] + token = str(response["access_token"]) logger.debug("Successfully acquired token silently") elif response and "error" in response: logger.warning( @@ -77,7 +77,7 @@ def acquire_token( logger.debug("Attempting interactive token acquisition") response = pca.acquire_token_interactive(scopes=target_scopes) if response and "access_token" in response: - token = str(response["access_token"]) # type: ignore[assignment] + token = str(response["access_token"]) logger.debug("Successfully acquired token interactively") elif response and "error" in response: logger.error( diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index 7e2376f71a1..f9a7a0c2205 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -56,13 +56,13 @@ from .observability import AgentTelemetryLayer if sys.version_info >= (3, 13): - from typing import TypeVar # type: ignore # pragma: no cover + from typing import TypeVar # pragma: no cover else: - from typing_extensions import TypeVar # type: ignore # pragma: no cover + from typing_extensions import TypeVar # pragma: no cover if sys.version_info >= (3, 12): - pass # type: ignore # pragma: no cover + pass else: - pass # type: ignore[import] # pragma: no cover + pass # pragma: no cover if sys.version_info >= (3, 11): from typing import Self, TypedDict # pragma: no cover else: @@ -582,7 +582,7 @@ async def _agent_wrapper(ctx: FunctionInvocationContext, **kwargs: Any) -> str: # region Agent -class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc] +class RawAgent(BaseAgent, Generic[OptionsCoT]): """A Chat Client Agent without middleware or telemetry layers. This is the core chat agent implementation. For most use cases, @@ -1128,7 +1128,7 @@ def _finalize_response_updates( response_format: Any | None = None, ) -> AgentResponse[Any]: """Finalize response updates into a single AgentResponse.""" - return AgentResponse.from_updates( # pyright: ignore[reportUnknownVariableType] + return AgentResponse.from_updates( updates, output_format_type=response_format, ) @@ -1269,7 +1269,7 @@ async def _prepare_run_context( duplicate_error_message=mcp_duplicate_message, ) else: - _append_unique_tools(final_tools, [tool]) # type: ignore[list-item] + _append_unique_tools(final_tools, [tool]) for mcp_server in self.mcp_tools: if not mcp_server.is_connected: @@ -1477,7 +1477,7 @@ async def _prepare_session_and_messages( if provider_session is None: raise RuntimeError("Provider session must be available when context providers are configured.") await provider.before_run( - agent=self, # type: ignore[arg-type] + agent=self, session=provider_session, context=session_context, state=provider_session.state.setdefault(provider.source_id, {}), @@ -1543,7 +1543,7 @@ def as_mcp_server( if kwargs: server_args.update(kwargs) - server: Server[Any] = Server(**server_args) # type: ignore[call-arg] + server: Server[Any] = Server(**server_args) agent_tool = self.as_tool(name=self._get_agent_name()) @@ -1557,7 +1557,7 @@ async def _log(level: types.LoggingLevel, data: Any) -> None: except Exception as e: logger.error("Failed to send log message to server: %s", e) - @server.list_tools() # type: ignore + @server.list_tools() async def _list_tools() -> list[types.Tool]: # type: ignore """List all tools in the agent.""" schema = agent_tool.parameters() @@ -1571,7 +1571,7 @@ async def _list_tools() -> list[types.Tool]: # type: ignore await _log(level="debug", data=f"Agent tool: {agent_tool}") return [tool] - @server.call_tool() # type: ignore + @server.call_tool() async def _call_tool( # type: ignore name: str, arguments: dict[str, Any] ) -> Sequence[types.TextContent | types.ImageContent | types.AudioContent | types.EmbeddedResource]: @@ -1603,18 +1603,18 @@ async def _call_tool( # type: ignore # Convert result to MCP content. # Currently only text items are forwarded over MCP; rich content # (images, audio) is not yet supported in the MCP server path. - mcp_content: list[types.TextContent | types.ImageContent | types.EmbeddedResource] = [] # type: ignore[attr-defined] + mcp_content: list[types.TextContent | types.ImageContent | types.EmbeddedResource] = [] for c in result: if c.type == "text" and c.text: - mcp_content.append(types.TextContent(type="text", text=c.text)) # type: ignore[attr-defined] + mcp_content.append(types.TextContent(type="text", text=c.text)) elif c.type in ("data", "uri"): logger.warning( "MCP server does not yet forward rich content (images, audio) " "in tool results. Rich content items will be omitted." ) - return mcp_content or [types.TextContent(type="text", text="")] # type: ignore[attr-defined] + return mcp_content or [types.TextContent(type="text", text="")] - @server.set_logging_level() # type: ignore + @server.set_logging_level() async def _set_logging_level(level: types.LoggingLevel) -> None: # type: ignore """Set the logging level for the server.""" logger.setLevel(LOG_LEVEL_MAPPING[level]) @@ -1712,9 +1712,9 @@ def run( """Run the agent.""" super_run = cast( "Callable[..., Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]]", - super().run, # type: ignore[misc] + super().run, ) - return super_run( # type: ignore[no-any-return] + return super_run( messages=messages, stream=stream, session=session, diff --git a/python/packages/core/agent_framework/_clients.py b/python/packages/core/agent_framework/_clients.py index 746427bffda..69a4d44ef92 100644 --- a/python/packages/core/agent_framework/_clients.py +++ b/python/packages/core/agent_framework/_clients.py @@ -46,9 +46,9 @@ ) if sys.version_info >= (3, 13): - from typing import TypeVar # type: ignore # pragma: no cover + from typing import TypeVar # pragma: no cover else: - from typing_extensions import TypeVar # type: ignore # pragma: no cover + from typing_extensions import TypeVar # pragma: no cover if TYPE_CHECKING: @@ -346,7 +346,7 @@ def _finalize_response_updates( response_format: Any | None = None, ) -> ChatResponse[Any]: """Finalize response updates into a single ChatResponse.""" - return ChatResponse.from_updates( # pyright: ignore[reportUnknownVariableType] + return ChatResponse.from_updates( updates, output_format_type=response_format, ) @@ -517,7 +517,7 @@ def get_response( return self._inner_get_response( messages=messages, stream=stream, - options=options or {}, # type: ignore[arg-type] + options=options or {}, **merged_client_kwargs, ) diff --git a/python/packages/core/agent_framework/_evaluation.py b/python/packages/core/agent_framework/_evaluation.py index 52bdf90d0fc..19e6c0f12b8 100644 --- a/python/packages/core/agent_framework/_evaluation.py +++ b/python/packages/core/agent_framework/_evaluation.py @@ -969,7 +969,7 @@ def _extract_agent_eval_data( agent_exec_response: AgentExecutorResponse | None = None if isinstance(completion_data, list): - for cdata_item in cast(list[Any], completion_data): # type: ignore[redundant-cast] + for cdata_item in cast(list[Any], completion_data): if isinstance(cdata_item, AgentExecutorResponse): agent_exec_response = cdata_item break @@ -982,7 +982,7 @@ def _extract_agent_eval_data( query: str | list[Message] if agent_exec_response.full_conversation: user_msgs = [m for m in agent_exec_response.full_conversation if m.role == "user"] - query = user_msgs or agent_exec_response.full_conversation # type: ignore[assignment] + query = user_msgs or agent_exec_response.full_conversation elif executor_id in invoked_data: input_data: Any = invoked_data[executor_id] query = ( # type: ignore[assignment] @@ -1017,7 +1017,7 @@ def _extract_overall_query(workflow_result: WorkflowRunResult) -> str | None: if isinstance(data, str): return data if isinstance(data, list) and data: - items_list = cast(list[Any], data) # type: ignore[redundant-cast] + items_list = cast(list[Any], data) first = items_list[0] if isinstance(first, Message): msgs: list[Message] = [m for m in items_list if isinstance(m, Message)] @@ -1492,7 +1492,7 @@ async def _check(item: EvalItem) -> CheckResult: result = await result return _coerce_result(value=result, check_name=check_name) - _check.__name__ = check_name # type: ignore[attr-defined,assignment] + _check.__name__ = check_name _check.__doc__ = func.__doc__ return _check @@ -2040,7 +2040,7 @@ def _build_overall_item( final_output: Any = outputs[-1] overall_response: AgentResponse[None] if isinstance(final_output, list) and final_output and isinstance(final_output[0], Message): - msgs: list[Message] = [m for m in cast(list[Any], final_output) if isinstance(m, Message)] # type: ignore[redundant-cast] + msgs: list[Message] = [m for m in cast(list[Any], final_output) if isinstance(m, Message)] response_text = " ".join(str(m.text) for m in msgs if m.role == "assistant") overall_response = AgentResponse(messages=[Message("assistant", [response_text])]) elif isinstance(final_output, AgentResponse): diff --git a/python/packages/core/agent_framework/_feature_stage.py b/python/packages/core/agent_framework/_feature_stage.py index f757258fc9c..87392713cc6 100644 --- a/python/packages/core/agent_framework/_feature_stage.py +++ b/python/packages/core/agent_framework/_feature_stage.py @@ -279,7 +279,7 @@ def __new__(cls: type[Any], /, *args: Any, **kwargs: Any) -> Any: raise TypeError(f"{cls.__name__}() takes no arguments") return original_new(cls) - experimental_class.__new__ = staticmethod(__new__) # type: ignore[assignment] + experimental_class.__new__ = staticmethod(__new__) original_init_subclass: Any = experimental_class.__init_subclass__ if isinstance(original_init_subclass, MethodType): @@ -308,7 +308,7 @@ def init_subclass_wrapper(*args: Any, **kwargs: Any) -> Any: ) return original_init_subclass(*args, **kwargs) - experimental_class.__init_subclass__ = init_subclass_wrapper # type: ignore[assignment] + experimental_class.__init_subclass__ = init_subclass_wrapper return cast(FeatureStageT, experimental_class) diff --git a/python/packages/core/agent_framework/_harness/_background_agents.py b/python/packages/core/agent_framework/_harness/_background_agents.py index c5efa1a6fbf..4bd7d65975a 100644 --- a/python/packages/core/agent_framework/_harness/_background_agents.py +++ b/python/packages/core/agent_framework/_harness/_background_agents.py @@ -111,12 +111,8 @@ def from_dict(cls, data: MutableMapping[str, Any], **kwargs: Any) -> BackgroundT class _RuntimeState: """Non-serializable per-session runtime state for background tasks.""" - in_flight_tasks: dict[int, asyncio.Task[AgentResponse[Any]]] = field( - default_factory=lambda: {} # pyright: ignore[reportUnknownLambdaType] - ) - background_sessions: dict[int, AgentSession] = field( - default_factory=lambda: {} # pyright: ignore[reportUnknownLambdaType] - ) + in_flight_tasks: dict[int, asyncio.Task[AgentResponse[Any]]] = field(default_factory=lambda: {}) + background_sessions: dict[int, AgentSession] = field(default_factory=lambda: {}) # --------------------------------------------------------------------------- diff --git a/python/packages/core/agent_framework/_harness/_memory.py b/python/packages/core/agent_framework/_harness/_memory.py index 92e060f4428..c1be1d33527 100644 --- a/python/packages/core/agent_framework/_harness/_memory.py +++ b/python/packages/core/agent_framework/_harness/_memory.py @@ -1066,7 +1066,7 @@ def _chat_client( return override client: object = getattr(agent, "client", None) if isinstance(client, SupportsChatGetResponse): - return cast(SupportsChatGetResponse[Any], client) # type: ignore[redundant-cast] + return cast(SupportsChatGetResponse[Any], client) return None @staticmethod diff --git a/python/packages/core/agent_framework/_mcp.py b/python/packages/core/agent_framework/_mcp.py index 784c618302d..11fbd64c60c 100644 --- a/python/packages/core/agent_framework/_mcp.py +++ b/python/packages/core/agent_framework/_mcp.py @@ -291,7 +291,7 @@ def streamable_http_client(*args: Any, **kwargs: Any) -> _AsyncGeneratorContextM f"The optional dependency `{missing_name}` is not installed. Please update your dependencies." ) from ex - return _streamable_http_client(*args, **kwargs) # type: ignore[return-value] + return _streamable_http_client(*args, **kwargs) def _should_propagate_cancelled_error(ex: BaseException) -> bool: @@ -567,7 +567,7 @@ def _parse_content_from_mcp( mcp_content_types: Sequence[Any] = ( cast(Sequence[Any], mcp_type) if isinstance(mcp_type, Sequence) else [mcp_type] - ) # type: ignore[redundant-cast] + ) return_types: list[Content] = [] for mcp_type in mcp_content_types: match mcp_type: @@ -606,7 +606,7 @@ def _parse_content_from_mcp( result=self._parse_content_from_mcp(mcp_type.content) if mcp_type.content else mcp_type.structuredContent, - exception=str(Exception()) if mcp_type.isError else None, # type: ignore[arg-type] + exception=str(Exception()) if mcp_type.isError else None, raw_representation=mcp_type, ) ) @@ -649,16 +649,16 @@ def _prepare_content_for_mcp( if content.type == "text": return types.TextContent(type="text", text=content.text) # type: ignore[attr-defined] if content.type == "data": - if content.media_type and content.media_type.startswith("image/"): # type: ignore[attr-defined] + if content.media_type and content.media_type.startswith("image/"): return types.ImageContent(type="image", data=content.uri, mimeType=content.media_type) # type: ignore[attr-defined] - if content.media_type and content.media_type.startswith("audio/"): # type: ignore[attr-defined] + if content.media_type and content.media_type.startswith("audio/"): return types.AudioContent(type="audio", data=content.uri, mimeType=content.media_type) # type: ignore[attr-defined] - if content.media_type and content.media_type.startswith("application/"): # type: ignore[attr-defined] + if content.media_type and content.media_type.startswith("application/"): return types.EmbeddedResource( type="resource", resource=types.BlobResourceContents( blob=content.uri, # type: ignore[attr-defined] - mimeType=content.media_type, # type: ignore[attr-defined] + mimeType=content.media_type, uri=( content.additional_properties.get("uri", "af://binary") if content.additional_properties @@ -674,7 +674,7 @@ def _prepare_content_for_mcp( return types.ResourceLink( type="resource_link", uri=content.uri, # type: ignore[arg-type,attr-defined] - mimeType=content.media_type, # type: ignore[attr-defined] + mimeType=content.media_type, name=resource_name, ) return None @@ -1535,7 +1535,7 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> str | list[Content]: OtelAttr.TOOL_NAME: tool_name, OtelAttr.OPERATION: OtelAttr.TOOL_EXECUTION_OPERATION, }) - with create_mcp_client_span("tools/call", target=tool_name, attributes=mcp_span_attrs) as span: # type: ignore + with create_mcp_client_span("tools/call", target=tool_name, attributes=mcp_span_attrs) as span: return await self._call_tool_with_retries(tool_name, filtered_kwargs, meta, parser, span) async def _call_tool_with_retries( @@ -1785,7 +1785,7 @@ async def _call_tool_as_task_create( name=tool_name, arguments=arguments, task=task_metadata, - _meta=request_meta, # type: ignore[call-arg] + _meta=request_meta, ) request = types.ClientRequest(types.CallToolRequest(params=params)) @@ -2517,8 +2517,8 @@ async def _inject_headers(request: Request) -> None: # noqa: RUF029 for key, value in headers.items(): request.headers[key] = value - self._inject_headers_hook = _inject_headers # type: ignore[attr-defined] - http_client.event_hooks["request"].append(self._inject_headers_hook) # type: ignore[attr-defined] + self._inject_headers_hook = _inject_headers + http_client.event_hooks["request"].append(self._inject_headers_hook) return streamable_http_client( url=self.url, diff --git a/python/packages/core/agent_framework/_middleware.py b/python/packages/core/agent_framework/_middleware.py index 2eddb08a070..5315575b6af 100644 --- a/python/packages/core/agent_framework/_middleware.py +++ b/python/packages/core/agent_framework/_middleware.py @@ -25,13 +25,13 @@ from .exceptions import MiddlewareException if sys.version_info >= (3, 13): - from typing import TypeVar # type: ignore # pragma: no cover + from typing import TypeVar # pragma: no cover else: - from typing_extensions import TypeVar # type: ignore # pragma: no cover + from typing_extensions import TypeVar # pragma: no cover if sys.version_info >= (3, 11): - from typing import TypedDict # type: ignore # pragma: no cover + from typing import TypedDict # pragma: no cover else: - from typing_extensions import TypedDict # type: ignore # pragma: no cover + from typing_extensions import TypedDict # pragma: no cover if TYPE_CHECKING: from pydantic import BaseModel @@ -362,16 +362,12 @@ def remove_tools( names_to_remove.add(item) continue for normalized in normalize_tools(item): - if name := _get_tool_name(normalized): # type: ignore[reportPrivateUsage] + if name := _get_tool_name(normalized): names_to_remove.add(name) if not names_to_remove: return - self.tools[:] = [ - tool - for tool in self.tools - if _get_tool_name(tool) not in names_to_remove # type: ignore[reportPrivateUsage] - ] + self.tools[:] = [tool for tool in self.tools if _get_tool_name(tool) not in names_to_remove] class ChatContext: @@ -1190,7 +1186,7 @@ def get_response( context_kwargs["compaction_strategy"] = compaction_strategy if tokenizer is not None: context_kwargs["tokenizer"] = tokenizer - pipeline = self._get_chat_middleware_pipeline(call_middleware) # type: ignore[reportUnknownArgumentType] + pipeline = self._get_chat_middleware_pipeline(call_middleware) if not pipeline.has_middlewares: return super_get_response( # type: ignore[no-any-return] messages=messages, diff --git a/python/packages/core/agent_framework/_serialization.py b/python/packages/core/agent_framework/_serialization.py index d9fdd618f7d..edf9757ac69 100644 --- a/python/packages/core/agent_framework/_serialization.py +++ b/python/packages/core/agent_framework/_serialization.py @@ -606,13 +606,13 @@ def _get_type_identifier(cls, value: Mapping[str, Any] | None = None) -> str: """ # for from_dict if value and (type_ := value.get("type")) and isinstance(type_, str): - return type_ # type:ignore[no-any-return] + return type_ # for todict when defined per instance if (type_ := getattr(cls, "type", None)) and isinstance(type_, str): - return type_ # type:ignore[no-any-return] + return type_ # for both when defined on class. if (type_ := getattr(cls, "TYPE", None)) and isinstance(type_, str): - return type_ # type:ignore[no-any-return] + return type_ # Fallback and default # Convert class name to snake_case return _CAMEL_TO_SNAKE_PATTERN.sub("_", cls.__name__).lower() @@ -636,7 +636,7 @@ def make_json_safe(obj: Any) -> Any: if isinstance(obj, (datetime, date)): return obj.isoformat() if is_dataclass(obj) and not isinstance(obj, type): - return make_json_safe(asdict(obj)) # type: ignore[arg-type] + return make_json_safe(asdict(obj)) if callable(getattr(obj, "model_dump", None)): try: return make_json_safe(obj.model_dump()) # type: ignore[no-any-return] @@ -657,5 +657,5 @@ def make_json_safe(obj: Any) -> Any: if isinstance(obj, (list, tuple)): return [make_json_safe(item) for item in obj] # type: ignore[misc] if hasattr(obj, "__dict__"): - return {key: make_json_safe(value) for key, value in vars(obj).items()} # type: ignore[misc] + return {key: make_json_safe(value) for key, value in vars(obj).items()} return str(obj) diff --git a/python/packages/core/agent_framework/_sessions.py b/python/packages/core/agent_framework/_sessions.py index 32d16995339..64f30083fdb 100644 --- a/python/packages/core/agent_framework/_sessions.py +++ b/python/packages/core/agent_framework/_sessions.py @@ -96,7 +96,7 @@ def register_state_type(cls: type) -> None: def _serialize_value(value: Any) -> Any: """Serialize a single value, handling objects with to_dict() and Pydantic models.""" if hasattr(value, "to_dict") and callable(value.to_dict): - return value.to_dict() # pyright: ignore[reportUnknownMemberType] + return value.to_dict() # Pydantic BaseModel support — import lazily to avoid hard dep at module level with suppress(ImportError): from pydantic import BaseModel diff --git a/python/packages/core/agent_framework/_settings.py b/python/packages/core/agent_framework/_settings.py index 9f9ab29a391..c2d0b359038 100644 --- a/python/packages/core/agent_framework/_settings.py +++ b/python/packages/core/agent_framework/_settings.py @@ -41,9 +41,9 @@ class MySettings(TypedDict, total=False): from .exceptions import SettingNotFoundError if sys.version_info >= (3, 13): - from typing import TypeVar # type: ignore # pragma: no cover + from typing import TypeVar # pragma: no cover else: - from typing_extensions import TypeVar # type: ignore # pragma: no cover + from typing_extensions import TypeVar # pragma: no cover SettingsT = TypeVar("SettingsT", default=dict[str, Any]) diff --git a/python/packages/core/agent_framework/_skills.py b/python/packages/core/agent_framework/_skills.py index 97afe66cea9..91bdb619143 100644 --- a/python/packages/core/agent_framework/_skills.py +++ b/python/packages/core/agent_framework/_skills.py @@ -3516,9 +3516,7 @@ async def get_content(self) -> str: result = await self._client.read_resource(_mcp_any_url(self._skill_md_uri)) text = _mcp_join_text(result) if not text: - raise ValueError( - f"The MCP server returned no text content for SKILL.md resource '{self._skill_md_uri}'." - ) + raise ValueError(f"The MCP server returned no text content for SKILL.md resource '{self._skill_md_uri}'.") self._content = text return text @@ -3572,11 +3570,7 @@ def _validate_resource_name(name: str) -> str | None: or ``None`` if the name is unsafe. """ normalized = name.replace("\\", "/") - if ( - normalized.startswith("/") - or "://" in normalized - or any(seg == ".." for seg in normalized.split("/")) - ): + if normalized.startswith("/") or "://" in normalized or any(seg == ".." for seg in normalized.split("/")): logger.debug("Rejecting resource name with unsafe path components: %r", name) return None return normalized diff --git a/python/packages/core/agent_framework/_telemetry.py b/python/packages/core/agent_framework/_telemetry.py index ab8576a305e..bce78f49015 100644 --- a/python/packages/core/agent_framework/_telemetry.py +++ b/python/packages/core/agent_framework/_telemetry.py @@ -18,14 +18,14 @@ APP_INFO = ( { - "agent-framework-version": f"python/{version_info}", # type: ignore[has-type] + "agent-framework-version": f"python/{version_info}", } if IS_TELEMETRY_ENABLED else None ) USER_AGENT_KEY: Final[str] = "User-Agent" HTTP_USER_AGENT: Final[str] = "agent-framework-python" -AGENT_FRAMEWORK_USER_AGENT = f"{HTTP_USER_AGENT}/{version_info}" # type: ignore[has-type] +AGENT_FRAMEWORK_USER_AGENT = f"{HTTP_USER_AGENT}/{version_info}" # This environment variable is reserved by the Foundry hosting environment to # indicate that the agent is running in a hosted environment. @@ -79,11 +79,11 @@ def _detect_hosted_environment() -> None: except (ModuleNotFoundError, ValueError): return with contextlib.suppress(ImportError, AttributeError): - from azure.ai.agentserver.core import ( # pyright: ignore[reportMissingImports] - AgentConfig, # pyright: ignore[reportUnknownVariableType] + from azure.ai.agentserver.core import ( + AgentConfig, ) - if AgentConfig.from_env().is_hosted: # pyright: ignore[reportUnknownMemberType] + if AgentConfig.from_env().is_hosted: _add_user_agent_prefix(_HOSTED_USER_AGENT_PREFIX) _hosted_env_detected = True diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 7bb54ee2c93..d46655a9174 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -50,13 +50,13 @@ ) if sys.version_info >= (3, 13): - from typing import TypeVar # type: ignore # pragma: no cover + from typing import TypeVar # pragma: no cover else: - from typing_extensions import TypeVar # type: ignore[import] # pragma: no cover + from typing_extensions import TypeVar # pragma: no cover if sys.version_info >= (3, 12): - from typing import override # type: ignore # pragma: no cover + from typing import override # pragma: no cover else: - from typing_extensions import override # type: ignore[import] # pragma: no cover + from typing_extensions import override # pragma: no cover if TYPE_CHECKING: @@ -203,7 +203,7 @@ def _default_histogram() -> Histogram: """ from .observability import OBSERVABILITY_SETTINGS # local import to avoid circulars - if not OBSERVABILITY_SETTINGS.ENABLED: # type: ignore[name-defined] + if not OBSERVABILITY_SETTINGS.ENABLED: return NoOpHistogram( name=OtelAttr.MEASUREMENT_FUNCTION_INVOCATION_DURATION, unit=OtelAttr.DURATION_UNIT, @@ -688,7 +688,7 @@ async def invoke( if self._context_parameter_name is not None and effective_context is not None: call_kwargs[self._context_parameter_name] = effective_context - if not OBSERVABILITY_SETTINGS.ENABLED: # type: ignore[name-defined] + if not OBSERVABILITY_SETTINGS.ENABLED: logger.info(f"Function name: {self.name}") logger.debug(f"Function arguments: {observable_kwargs}") result = await self._invoke_function(call_kwargs) @@ -727,7 +727,7 @@ async def invoke( "response_format", } } - if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED: # type: ignore[name-defined] + if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED: attributes.update({ OtelAttr.TOOL_ARGUMENTS: ( json.dumps(serializable_kwargs, default=str, ensure_ascii=False) if serializable_kwargs else "None" @@ -736,7 +736,7 @@ async def invoke( with get_function_span(attributes=attributes) as span: attributes[OtelAttr.MEASUREMENT_FUNCTION_TAG_NAME] = self.name logger.info(f"Function name: {self.name}") - if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED: # type: ignore[name-defined] + if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED: logger.debug(f"Function arguments: {serializable_kwargs}") start_time_stamp = perf_counter() end_time_stamp: float | None = None @@ -752,7 +752,7 @@ async def invoke( else: if skip_parsing: logger.info(f"Function {self.name} succeeded.") - if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED: # type: ignore[name-defined] + if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED: result_str = str(result) span.set_attribute(OtelAttr.TOOL_RESULT, result_str) logger.debug(f"Function result: {result_str}") @@ -765,7 +765,7 @@ async def invoke( if isinstance(parsed, str): parsed = [Content.from_text(parsed)] logger.info(f"Function {self.name} succeeded.") - if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED: # type: ignore[name-defined] + if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED: result_str = "\n".join(c.text or "" for c in parsed if c.type == "text") or str(parsed) span.set_attribute(OtelAttr.TOOL_RESULT, result_str) logger.debug(f"Function result: {result_str}") @@ -854,8 +854,8 @@ def parse_result(result: Any) -> list[Content]: if isinstance(item, Content): parsed_items.append(item) else: - dumpable = FunctionTool._make_dumpable(item) # type: ignore[reportUnknownArgumentType] - text = dumpable if isinstance(dumpable, str) else json.dumps(dumpable, default=str) # type: ignore[reportUnknownArgumentType] + dumpable = FunctionTool._make_dumpable(item) + text = dumpable if isinstance(dumpable, str) else json.dumps(dumpable, default=str) parsed_items.append(Content.from_text(text)) return parsed_items dumpable = FunctionTool._make_dumpable(result) @@ -1115,13 +1115,13 @@ def _validate_arguments_against_schema( if not isinstance(properties.get(field_name), dict): continue - enum_values = properties.get(field_name, {}).get("enum") # type: ignore + enum_values = properties.get(field_name, {}).get("enum") if isinstance(enum_values, list) and enum_values and field_value not in enum_values: raise TypeError( f"Invalid value for '{field_name}' in '{tool_name}': {field_value!r} is not in {enum_values!r}" ) - schema_type = properties.get(field_name, {}).get("type") # type: ignore + schema_type = properties.get(field_name, {}).get("type") if isinstance(schema_type, str): if not _matches_json_schema_type(field_value, schema_type): raise TypeError( @@ -1315,7 +1315,7 @@ def get_weather(location: str, unit: str = "celsius") -> str: def decorator(func: Callable[..., Any]) -> FunctionTool: @wraps(func) def wrapper(f: Callable[..., Any]) -> FunctionTool: - tool_name: str = name or getattr(f, "__name__", "unknown_function") # type: ignore[assignment] + tool_name: str = name or getattr(f, "__name__", "unknown_function") tool_desc: str = description or (f.__doc__ or "") return FunctionTool( name=tool_name, @@ -1471,13 +1471,13 @@ async def _auto_invoke_function( return Content.from_function_result( call_id=function_call_content.call_id, # type: ignore[arg-type] result=f'Error: Requested function "{function_call_content.name}" not found.', - exception=str(exc), # type: ignore[arg-type] + exception=str(exc), additional_properties=function_call_content.additional_properties, ) else: # Note: Unapproved tools (approved=False) are handled in _replace_approval_contents_with_results # and never reach this function, so we only handle approved=True cases here. - approved_function_call = function_call_content.function_call # type: ignore[attr-defined] + approved_function_call = function_call_content.function_call if ( approved_function_call is None or approved_function_call.type != "function_call" @@ -1520,7 +1520,7 @@ async def _auto_invoke_function( return Content.from_function_result( call_id=function_call_content.call_id, # type: ignore[arg-type] result=message, - exception=str(exc), # type: ignore[arg-type] + exception=str(exc), additional_properties=function_call_content.additional_properties, ) @@ -1637,7 +1637,7 @@ async def final_function_handler(context_obj: Any) -> Any: return Content.from_function_result( call_id=function_call_content.call_id, # type: ignore[arg-type] result=message, - exception=str(exc), # type: ignore[arg-type] + exception=str(exc), additional_properties=function_call_content.additional_properties, ) @@ -1706,17 +1706,15 @@ async def _try_execute_function_calls( fcc_name, fcc_name in approval_tools, ) - if fcc.type == "function_call" and fcc.name in approval_tools: # type: ignore[attr-defined] + if fcc.type == "function_call" and fcc.name in approval_tools: logger.debug("Approval needed for function: %s", fcc.name) approval_needed = True break - if fcc.type == "function_call" and (fcc.name in declaration_only or fcc.name in additional_tool_names): # type: ignore[attr-defined] + if fcc.type == "function_call" and (fcc.name in declaration_only or fcc.name in additional_tool_names): declaration_only_flag = True break - if ( - config.get("terminate_on_unknown_calls", False) and fcc.type == "function_call" and fcc.name not in tool_map # type: ignore[attr-defined] - ): - raise KeyError(f'Error: Requested function "{fcc.name}" not found.') # type: ignore[attr-defined] + if config.get("terminate_on_unknown_calls", False) and fcc.type == "function_call" and fcc.name not in tool_map: + raise KeyError(f'Error: Requested function "{fcc.name}" not found.') if approval_needed: # approval can only be needed for Function Call Content, not Approval Responses. logger.debug("Returning function_approval_request contents") @@ -1751,7 +1749,7 @@ async def invoke_with_termination_handling( """Invoke function and catch MiddlewareTermination, returning (result, should_terminate).""" try: result = await _auto_invoke_function( - function_call_content=function_call, # type: ignore[arg-type] + function_call_content=function_call, custom_args=custom_args, tool_map=tool_map, invocation_session=invocation_session, @@ -1777,9 +1775,9 @@ async def invoke_with_termination_handling( propagated: list[Content] = [] for item in exc.contents: if isinstance(item, Content): - item.call_id = function_call.call_id # type: ignore[attr-defined] - if not item.id: # type: ignore[attr-defined] - item.id = function_call.call_id # type: ignore[attr-defined] + item.call_id = function_call.call_id + if not item.id: + item.id = function_call.call_id propagated.append(item) if propagated: extra_user_input_contents.extend(propagated[1:]) @@ -1822,7 +1820,7 @@ async def _execute_function_calls( custom_args=custom_args, attempt_idx=attempt_idx, function_calls=function_calls, - tools=tools, # type: ignore + tools=tools, invocation_session=invocation_session, middleware_pipeline=middleware_pipeline, config=config, @@ -1957,9 +1955,7 @@ def _replace_approval_contents_with_results( for msg in messages: # First pass - collect existing function call IDs to avoid duplicates existing_call_ids = { - content.call_id # type: ignore[union-attr, operator] - for content in msg.contents - if content.type == "function_call" and content.call_id # type: ignore[attr-defined] + content.call_id for content in msg.contents if content.type == "function_call" and content.call_id } # Track approval requests that should be removed (duplicates) @@ -2000,7 +1996,7 @@ def _replace_approval_contents_with_results( # Create a "not approved" result for rejected calls # Use function_call.call_id (the function's ID), not content.id (approval's ID) msg.contents[content_idx] = Content.from_function_result( - call_id=content.function_call.call_id, # type: ignore[union-attr, arg-type] + call_id=content.function_call.call_id, result="Error: Tool call invocation was rejected by user.", ) msg.role = "tool" @@ -2368,7 +2364,7 @@ def get_response( additional_function_arguments = ( dict(function_invocation_kwargs) if function_invocation_kwargs is not None else {} ) - if options and (additional_opts := options.get("additional_function_arguments")): # type: ignore[attr-defined] + if options and (additional_opts := options.get("additional_function_arguments")): additional_function_arguments.update(cast(Mapping[str, Any], additional_opts)) from ._sessions import AgentSession as _AgentSession @@ -2424,7 +2420,7 @@ async def _get_response() -> ChatResponse[Any]: approval_result = await _process_function_requests( response=None, prepped_messages=prepped_messages, - tool_options=mutable_options, # type: ignore[arg-type] + tool_options=mutable_options, attempt_idx=attempt_idx, fcc_messages=None, errors_in_a_row=errors_in_a_row, @@ -2462,7 +2458,7 @@ async def _get_response() -> ChatResponse[Any]: result = await _process_function_requests( response=response, prepped_messages=None, - tool_options=mutable_options, # type: ignore[arg-type] + tool_options=mutable_options, attempt_idx=attempt_idx, fcc_messages=fcc_messages, errors_in_a_row=errors_in_a_row, @@ -2561,7 +2557,7 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]: approval_result = await _process_function_requests( response=None, prepped_messages=prepped_messages, - tool_options=mutable_options, # type: ignore[arg-type] + tool_options=mutable_options, attempt_idx=attempt_idx, fcc_messages=None, errors_in_a_row=errors_in_a_row, @@ -2616,7 +2612,7 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]: result = await _process_function_requests( response=response, prepped_messages=None, - tool_options=mutable_options, # type: ignore[arg-type] + tool_options=mutable_options, attempt_idx=attempt_idx, fcc_messages=fcc_messages, errors_in_a_row=errors_in_a_row, diff --git a/python/packages/core/agent_framework/_types.py b/python/packages/core/agent_framework/_types.py index f30fc04789d..3dddb4e82eb 100644 --- a/python/packages/core/agent_framework/_types.py +++ b/python/packages/core/agent_framework/_types.py @@ -191,7 +191,7 @@ def _get_data_bytes_as_str(content: Content) -> str | None: raise ContentError("Data URI must use base64 encoding") _, data = uri.split(";base64,", 1) - return data # type: ignore[return-value, no-any-return] + return data def _get_data_bytes(content: Content) -> bytes | None: # pyright: ignore[reportUnusedFunction] @@ -390,7 +390,7 @@ class Annotation(TypedDict, total=False): # endregion -class UsageDetails(TypedDict, total=False, extra_items=int): # type: ignore[call-arg] +class UsageDetails(TypedDict, total=False, extra_items=int): """A dictionary representing usage details. This is a non-closed dictionary, so any specific provider fields can be added as needed. @@ -445,7 +445,7 @@ def add_usage_details(usage1: UsageDetails | None, usage2: UsageDetails | None) ): logger.warning("Non `int` value found in usage details, skipping.") continue - result[key] = (val1 or 0) + (val2 or 0) # type: ignore[literal-required] + result[key] = (val1 or 0) + (val2 or 0) return result @@ -1442,12 +1442,12 @@ def _add_text_reasoning_content(self, other: Content) -> Content: combined_id = self.id or other.id # Concatenate text, handling None values - self_text = self.text or "" # type: ignore[attr-defined] - other_text = other.text or "" # type: ignore[attr-defined] + self_text = self.text or "" + other_text = other.text or "" combined_text = self_text + other_text if (self_text or other_text) else None # Handle protected_data replacement - protected_data = other.protected_data if other.protected_data is not None else self.protected_data # type: ignore[attr-defined] + protected_data = other.protected_data if other.protected_data is not None else self.protected_data return Content( "text_reasoning", @@ -1901,14 +1901,14 @@ def _process_update(response: ChatResponse | AgentResponse, update: ChatResponse # mypy doesn't narrow type based on match/case, but we know these are FunctionCallContents case "function_call" if message.contents and message.contents[-1].type == "function_call": try: - message.contents[-1] += content # type: ignore[operator] + message.contents[-1] += content except (AdditionItemMismatch, ContentError): message.contents.append(content) case "usage": if response.usage_details is None: response.usage_details = UsageDetails() # mypy doesn't narrow type based on match/case, but we know this is UsageContent - response.usage_details = add_usage_details(response.usage_details, content.usage_details) # type: ignore[arg-type] + response.usage_details = add_usage_details(response.usage_details, content.usage_details) case _: message.contents.append(content) # Incorporate the update's properties into the response. @@ -3553,7 +3553,7 @@ def my_tool(x: int) -> int: # Expand MCP tools to their constituent functions if not tool_.is_connected: await tool_.connect() - final_tools.extend(tool_.functions) # type: ignore + final_tools.extend(tool_.functions) else: final_tools.append(tool_) diff --git a/python/packages/core/agent_framework/_workflows/_agent.py b/python/packages/core/agent_framework/_workflows/_agent.py index f1d1507702a..f0b51969c07 100644 --- a/python/packages/core/agent_framework/_workflows/_agent.py +++ b/python/packages/core/agent_framework/_workflows/_agent.py @@ -39,9 +39,9 @@ from ._typing_utils import is_instance_of, is_type_compatible if sys.version_info >= (3, 11): - from typing import TypedDict # type: ignore # pragma: no cover + from typing import TypedDict # pragma: no cover else: - from typing_extensions import TypedDict # type: ignore # pragma: no cover + from typing_extensions import TypedDict # pragma: no cover if TYPE_CHECKING: from ._workflow import Workflow @@ -276,7 +276,7 @@ async def _run_impl( if provider_session is None: raise RuntimeError("Provider session must be available when context providers are configured.") await provider.before_run( - agent=self, # type: ignore[arg-type] + agent=self, session=provider_session, context=session_context, state=provider_session.state.setdefault(provider.source_id, {}), @@ -356,7 +356,7 @@ async def _run_stream_impl( if provider_session is None: raise RuntimeError("Provider session must be available when context providers are configured.") await provider.before_run( - agent=self, # type: ignore[arg-type] + agent=self, session=provider_session, context=session_context, state=provider_session.state.setdefault(provider.source_id, {}), @@ -727,7 +727,7 @@ def _extract_function_responses(self, input_messages: Sequence[Message]) -> dict request_id: str = content.id # type: ignore[assignment] function_responses[request_id] = content elif content.type == "function_result": - response_data = content.result if hasattr(content, "result") else str(content) # type: ignore[attr-defined] + response_data = content.result if hasattr(content, "result") else str(content) function_responses[content.call_id] = response_data # type: ignore else: raise AgentInvalidResponseException( @@ -741,7 +741,7 @@ def _extract_contents(self, data: Any) -> list[Content]: if isinstance(data, list): return [c for item in data for c in self._extract_contents(item)] # type: ignore if isinstance(data, Content): - return [data] # type: ignore[redundant-cast] + return [data] if isinstance(data, str): return [Content.from_text(text=data)] return [Content.from_text(text=str(data))] @@ -835,7 +835,7 @@ def _add_raw(value: object) -> None: messages=(current.messages or []) + (incoming.messages or []), response_id=current.response_id or incoming.response_id, created_at=incoming.created_at or current.created_at, - usage_details=add_usage_details(current.usage_details, incoming.usage_details), # type: ignore[arg-type] + usage_details=add_usage_details(current.usage_details, incoming.usage_details), raw_representation=raw_list if raw_list else None, additional_properties=incoming.additional_properties or current.additional_properties, ) @@ -870,7 +870,7 @@ def _add_raw(value: object) -> None: if aggregated: final_messages.extend(aggregated.messages) if aggregated.usage_details: - merged_usage = add_usage_details(merged_usage, aggregated.usage_details) # type: ignore[arg-type] + merged_usage = add_usage_details(merged_usage, aggregated.usage_details) if aggregated.created_at and ( not latest_created_at or _parse_dt(aggregated.created_at) > _parse_dt(latest_created_at) ): @@ -894,7 +894,7 @@ def _add_raw(value: object) -> None: flattened = AgentResponse.from_updates(global_dangling) final_messages.extend(flattened.messages) if flattened.usage_details: - merged_usage = add_usage_details(merged_usage, flattened.usage_details) # type: ignore[arg-type] + merged_usage = add_usage_details(merged_usage, flattened.usage_details) if flattened.created_at and ( not latest_created_at or _parse_dt(flattened.created_at) > _parse_dt(latest_created_at) ): diff --git a/python/packages/core/agent_framework/_workflows/_agent_executor.py b/python/packages/core/agent_framework/_workflows/_agent_executor.py index d96929ce793..22e872be07c 100644 --- a/python/packages/core/agent_framework/_workflows/_agent_executor.py +++ b/python/packages/core/agent_framework/_workflows/_agent_executor.py @@ -22,9 +22,9 @@ from ._workflow_context import WorkflowContext if sys.version_info >= (3, 12): - from typing import override # type: ignore # pragma: no cover + from typing import override # pragma: no cover else: - from typing_extensions import override # type: ignore # pragma: no cover + from typing_extensions import override # pragma: no cover logger = logging.getLogger(__name__) diff --git a/python/packages/core/agent_framework/_workflows/_checkpoint.py b/python/packages/core/agent_framework/_workflows/_checkpoint.py index 22b4a1ea24d..7473f403e65 100644 --- a/python/packages/core/agent_framework/_workflows/_checkpoint.py +++ b/python/packages/core/agent_framework/_workflows/_checkpoint.py @@ -345,7 +345,7 @@ async def load(self, checkpoint_id: CheckpointID) -> WorkflowCheckpoint: def _read() -> dict[str, Any]: with open(file_path) as f: - return json.load(f) # type: ignore[no-any-return] + return json.load(f) encoded_checkpoint = await asyncio.to_thread(_read) diff --git a/python/packages/core/agent_framework/_workflows/_checkpoint_encoding.py b/python/packages/core/agent_framework/_workflows/_checkpoint_encoding.py index c66faae75e2..7937841ec52 100644 --- a/python/packages/core/agent_framework/_workflows/_checkpoint_encoding.py +++ b/python/packages/core/agent_framework/_workflows/_checkpoint_encoding.py @@ -134,7 +134,7 @@ def find_class(self, module: str, name: str) -> type: or module.startswith(_FRAMEWORK_MODULE_PREFIX) or module.startswith(_OPENAI_MODULE_PREFIX) ): - return super().find_class(module, name) # type: ignore[no-any-return] # nosec + return super().find_class(module, name) raise pickle.UnpicklingError( f"Checkpoint deserialization blocked for type '{type_key}'. " diff --git a/python/packages/core/agent_framework/_workflows/_edge.py b/python/packages/core/agent_framework/_workflows/_edge.py index b9dbd266ec5..79ca09d990e 100644 --- a/python/packages/core/agent_framework/_workflows/_edge.py +++ b/python/packages/core/agent_framework/_workflows/_edge.py @@ -438,16 +438,14 @@ def from_dict(cls, data: dict[str, Any]) -> EdgeGroup: target_cls = cls._TYPE_REGISTRY.get(group_type, EdgeGroup) edges = [Edge.from_dict(entry) for entry in data.get("edges", [])] - obj = target_cls.__new__(target_cls) # type: ignore[misc] + obj = target_cls.__new__(target_cls) EdgeGroup.__init__(obj, edges=edges, id=data.get("id"), type=group_type) # Handle FanOutEdgeGroup-specific attributes if isinstance(obj, FanOutEdgeGroup): - obj.selection_func_name = data.get("selection_func_name") # type: ignore[attr-defined] + obj.selection_func_name = data.get("selection_func_name") obj._selection_func = ( # type: ignore[attr-defined] - None - if obj.selection_func_name is None # type: ignore[attr-defined] - else _missing_callable(obj.selection_func_name) # type: ignore[attr-defined] + None if obj.selection_func_name is None else _missing_callable(obj.selection_func_name) ) obj._target_ids = [edge.target_id for edge in obj.edges] # type: ignore[attr-defined] @@ -461,7 +459,7 @@ def from_dict(cls, data: dict[str, Any]) -> EdgeGroup: restored_cases.append(SwitchCaseEdgeGroupDefault.from_dict(case_data)) else: restored_cases.append(SwitchCaseEdgeGroupCase.from_dict(case_data)) - obj.cases = restored_cases # type: ignore[attr-defined] + obj.cases = restored_cases obj._selection_func = _missing_callable("switch_case_selection") # type: ignore[attr-defined] return obj @@ -878,9 +876,9 @@ def selection_func(message: Any, targets: list[str]) -> list[str]: EdgeGroup.__init__(self, edges, id=id, type=self.__class__.__name__) # Initialize FanOutEdgeGroup-specific attributes - self._target_ids = list(target_ids) # type: ignore[attr-defined] - self._selection_func = selection_func # type: ignore[attr-defined] - self.selection_func_name = None # type: ignore[attr-defined] + self._target_ids = list(target_ids) + self._selection_func = selection_func + self.selection_func_name = None self.cases = list(cases) def to_dict(self) -> dict[str, Any]: diff --git a/python/packages/core/agent_framework/_workflows/_events.py b/python/packages/core/agent_framework/_workflows/_events.py index aa1a69954f1..a986decca10 100644 --- a/python/packages/core/agent_framework/_workflows/_events.py +++ b/python/packages/core/agent_framework/_workflows/_events.py @@ -16,9 +16,9 @@ from ._typing_utils import deserialize_type, serialize_type if sys.version_info >= (3, 13): - from typing import TypeVar # type: ignore # pragma: no cover + from typing import TypeVar # pragma: no cover else: - from typing_extensions import TypeVar # type: ignore[import] # pragma: no cover + from typing_extensions import TypeVar # pragma: no cover DataT = TypeVar("DataT", default=Any) diff --git a/python/packages/core/agent_framework/_workflows/_function_executor.py b/python/packages/core/agent_framework/_workflows/_function_executor.py index 0d46c0daa38..966670d8f31 100644 --- a/python/packages/core/agent_framework/_workflows/_function_executor.py +++ b/python/packages/core/agent_framework/_workflows/_function_executor.py @@ -136,19 +136,19 @@ def __init__( # Sync function with context - wrap to make async using thread pool async def wrapped_func(message: Any, ctx: WorkflowContext[Any]) -> Any: # Call the sync function with both parameters in a thread - return await asyncio.to_thread(func, message, ctx) # type: ignore + return await asyncio.to_thread(func, message, ctx) elif not self._has_context and self._is_async: # Async function without context - wrap to ignore context async def wrapped_func(message: Any, ctx: WorkflowContext[Any]) -> Any: # Call the async function with just the message - return await func(message) # type: ignore + return await func(message) else: # Sync function without context - wrap to make async and ignore context using thread pool async def wrapped_func(message: Any, ctx: WorkflowContext[Any]) -> Any: # Call the sync function with just the message in a thread - return await asyncio.to_thread(func, message) # type: ignore + return await asyncio.to_thread(func, message) # Now register our instance handler self._register_instance_handler( diff --git a/python/packages/core/agent_framework/_workflows/_functional.py b/python/packages/core/agent_framework/_workflows/_functional.py index 73c0815862e..098700bea55 100644 --- a/python/packages/core/agent_framework/_workflows/_functional.py +++ b/python/packages/core/agent_framework/_workflows/_functional.py @@ -55,7 +55,7 @@ WorkflowErrorDetails, WorkflowEvent, WorkflowRunState, - _framework_event_origin, # type: ignore[reportPrivateUsage] + _framework_event_origin, ) from ._workflow import WorkflowRunResult @@ -516,7 +516,7 @@ async def __call__(self, *args: Any, **kwargs: Any) -> R: # Dedicated bypass event so consumers can tell cache-hit replays # apart from fresh executions. await ctx.add_event(WorkflowEvent.executor_bypassed(self.name, cached)) - return cached # type: ignore[return-value, no-any-return] + return cached # Inject RunContext if the step function declares it call_args, call_kwargs = self._build_call_args_with_ctx(ctx, args, kwargs) diff --git a/python/packages/core/agent_framework/_workflows/_model_utils.py b/python/packages/core/agent_framework/_workflows/_model_utils.py index 0627d716a86..db493b7e224 100644 --- a/python/packages/core/agent_framework/_workflows/_model_utils.py +++ b/python/packages/core/agent_framework/_workflows/_model_utils.py @@ -20,10 +20,10 @@ def to_dict(self) -> dict[str, Any]: @classmethod def from_dict(cls: type[ModelT], data: dict[str, Any]) -> ModelT: - return cls(**data) # type: ignore[arg-type] + return cls(**data) def clone(self, *, deep: bool = True) -> Self: - return copy.deepcopy(self) if deep else copy.copy(self) # type: ignore[return-value] + return copy.deepcopy(self) if deep else copy.copy(self) def to_json(self) -> str: import json diff --git a/python/packages/core/agent_framework/_workflows/_runner_context.py b/python/packages/core/agent_framework/_workflows/_runner_context.py index 2e4901f4118..a6bbbe3c73d 100644 --- a/python/packages/core/agent_framework/_workflows/_runner_context.py +++ b/python/packages/core/agent_framework/_workflows/_runner_context.py @@ -326,7 +326,7 @@ async def drain_events(self) -> list[WorkflowEvent]: while True: try: events.append(self._event_queue.get_nowait()) - except asyncio.QueueEmpty: # type: ignore[attr-defined] + except asyncio.QueueEmpty: break return events diff --git a/python/packages/core/agent_framework/_workflows/_typing_utils.py b/python/packages/core/agent_framework/_workflows/_typing_utils.py index 07b6d15bca4..24e5bed96c8 100644 --- a/python/packages/core/agent_framework/_workflows/_typing_utils.py +++ b/python/packages/core/agent_framework/_workflows/_typing_utils.py @@ -140,7 +140,7 @@ def is_instance_of(data: Any, target_type: type | UnionType | Any) -> bool: if origin in [list, set]: return isinstance(data, origin) and ( not args or all(any(is_instance_of(item, arg) for arg in args) for item in data) # type: ignore[misc] - ) # type: ignore + ) # Case 4: target_type is a tuple if origin is tuple: diff --git a/python/packages/core/agent_framework/_workflows/_validation.py b/python/packages/core/agent_framework/_workflows/_validation.py index d8aa4c80f5f..779bb289dee 100644 --- a/python/packages/core/agent_framework/_workflows/_validation.py +++ b/python/packages/core/agent_framework/_workflows/_validation.py @@ -259,9 +259,9 @@ def _validate_edge_type_compatibility(self, edge: Edge, edge_group: EdgeGroup) - for target_type in target_input_types: if isinstance(edge_group, FanInEdgeGroup): # If the edge is part of an edge group, the target expects a list of data types - if is_type_compatible(list[source_type], target_type): # type: ignore[valid-type] + if is_type_compatible(list[source_type], target_type): compatible = True - compatible_pairs.append((list[source_type], target_type)) # type: ignore[valid-type] + compatible_pairs.append((list[source_type], target_type)) else: if is_type_compatible(source_type, target_type): compatible = True diff --git a/python/packages/core/agent_framework/_workflows/_viz.py b/python/packages/core/agent_framework/_workflows/_viz.py index 54015b066c9..e71d917f7aa 100644 --- a/python/packages/core/agent_framework/_workflows/_viz.py +++ b/python/packages/core/agent_framework/_workflows/_viz.py @@ -98,7 +98,7 @@ def export( return temp_file.name try: - import graphviz # type: ignore + import graphviz except ImportError as e: raise ImportError( "viz extra is required for export. Install it with: pip install graphviz>=0.20.0 " @@ -109,7 +109,7 @@ def export( # Create a temporary graphviz Source object dot_content = self.to_digraph(include_internal_executors=include_internal_executors) - source = graphviz.Source(dot_content) # type: ignore[reportUnknownVariableType] + source = graphviz.Source(dot_content) try: if filename: @@ -131,7 +131,7 @@ def export( source.render(base_name, format=format, cleanup=True) # type: ignore return f"{base_name}.{format}" - except graphviz.backend.execute.ExecutableNotFound as e: # type: ignore + except graphviz.backend.execute.ExecutableNotFound as e: raise ImportError( "The graphviz executables are not found. The graphviz Python package is installed, but the " "graphviz executables (dot, neato, etc.) are not available on your system's PATH. " @@ -308,7 +308,7 @@ def _emit_sub_workflows_digraph( """Emit DOT subgraphs for any WorkflowExecutor instances found in the workflow.""" # Lazy import to avoid any potential import cycles try: - from ._workflow_executor import WorkflowExecutor # type: ignore + from ._workflow_executor import WorkflowExecutor except ImportError: # pragma: no cover - best-effort; if unavailable, skip subgraphs return @@ -408,7 +408,7 @@ def _emit_sub_workflows_mermaid( include_internal_executors: bool = False, ) -> None: try: - from ._workflow_executor import WorkflowExecutor # type: ignore + from ._workflow_executor import WorkflowExecutor except ImportError: # pragma: no cover return diff --git a/python/packages/core/agent_framework/_workflows/_workflow_builder.py b/python/packages/core/agent_framework/_workflows/_workflow_builder.py index 942f3010f0e..83d53a988c7 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_builder.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_builder.py @@ -36,9 +36,9 @@ ) if sys.version_info >= (3, 11): - from typing import Self # type: ignore # pragma: no cover + from typing import Self # pragma: no cover else: - from typing_extensions import Self # type: ignore # pragma: no cover + from typing_extensions import Self # pragma: no cover logger = logging.getLogger(__name__) @@ -182,7 +182,7 @@ def _add_executor(self, executor: Executor) -> str: # New executor self._executors[executor.id] = executor # Add an internal edge group for each unique executor - self._edge_groups.append(InternalEdgeGroup(executor.id)) # type: ignore[call-arg] + self._edge_groups.append(InternalEdgeGroup(executor.id)) return executor.id @@ -199,13 +199,13 @@ def _maybe_wrap_agent(self, candidate: Executor | SupportsAgentRun) -> Executor: An Executor instance, wrapping the agent if necessary. """ try: # Local import to avoid hard dependency at import time - from agent_framework import SupportsAgentRun # type: ignore + from agent_framework import SupportsAgentRun except Exception: # pragma: no cover - defensive - SupportsAgentRun = object # type: ignore + SupportsAgentRun = object if isinstance(candidate, Executor): # Already an executor return candidate - if isinstance(candidate, SupportsAgentRun): # type: ignore[arg-type] + if isinstance(candidate, SupportsAgentRun): # Reuse existing wrapper for the same agent instance if present agent_instance_id = str(id(candidate)) existing = self._agent_wrappers.get(agent_instance_id) @@ -329,7 +329,7 @@ async def validate(self, data: str, ctx: WorkflowContext) -> None: target_execs = [self._maybe_wrap_agent(t) for t in targets] source_id = self._add_executor(source_exec) target_ids = [self._add_executor(t) for t in target_execs] - self._edge_groups.append(FanOutEdgeGroup(source_id, target_ids)) # type: ignore[call-arg] + self._edge_groups.append(FanOutEdgeGroup(source_id, target_ids)) return self @@ -410,13 +410,13 @@ async def handle(self, result: Result, ctx: WorkflowContext) -> None: internal_cases: list[SwitchCaseEdgeGroupCase | SwitchCaseEdgeGroupDefault] = [] for case in cases: # Allow case targets to be agents - case.target = self._maybe_wrap_agent(case.target) # type: ignore[arg-type] + case.target = self._maybe_wrap_agent(case.target) self._add_executor(case.target) if isinstance(case, Default): internal_cases.append(SwitchCaseEdgeGroupDefault(target_id=case.target.id)) else: internal_cases.append(SwitchCaseEdgeGroupCase(condition=case.condition, target_id=case.target.id)) - self._edge_groups.append(SwitchCaseEdgeGroup(source_id, internal_cases)) # type: ignore[call-arg] + self._edge_groups.append(SwitchCaseEdgeGroup(source_id, internal_cases)) return self @@ -502,7 +502,7 @@ def select_workers(task: Task, available: list[str]) -> list[str]: target_execs = [self._maybe_wrap_agent(t) for t in targets] source_id = self._add_executor(source_exec) target_ids = [self._add_executor(t) for t in target_execs] - self._edge_groups.append(FanOutEdgeGroup(source_id, target_ids, selection_func)) # type: ignore[call-arg] + self._edge_groups.append(FanOutEdgeGroup(source_id, target_ids, selection_func)) return self @@ -557,7 +557,7 @@ async def aggregate(self, results: list[str], ctx: WorkflowContext[Never, str]) target_exec = self._maybe_wrap_agent(target) source_ids = [self._add_executor(s) for s in source_execs] target_id = self._add_executor(target_exec) - self._edge_groups.append(FanInEdgeGroup(source_ids, target_id)) # type: ignore[call-arg] + self._edge_groups.append(FanInEdgeGroup(source_ids, target_id)) return self diff --git a/python/packages/core/agent_framework/_workflows/_workflow_context.py b/python/packages/core/agent_framework/_workflows/_workflow_context.py index bfc8601e5d6..74293d356b1 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_context.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_context.py @@ -328,7 +328,7 @@ async def send_message(self, message: OutT, target_id: str | None = None) -> Non self._sent_messages.append(message) # Inject current trace context if tracing enabled - if OBSERVABILITY_SETTINGS.ENABLED and span and span.is_recording(): # type: ignore[name-defined] + if OBSERVABILITY_SETTINGS.ENABLED and span and span.is_recording(): trace_context: dict[str, str] = {} inject(trace_context) # Inject current trace context for message propagation diff --git a/python/packages/core/agent_framework/_workflows/_workflow_executor.py b/python/packages/core/agent_framework/_workflows/_workflow_executor.py index e1315334293..a77887ba053 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_executor.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_executor.py @@ -26,9 +26,9 @@ from ._workflow_context import WorkflowContext if sys.version_info >= (3, 12): - from typing import override # type: ignore # pragma: no cover + from typing import override # pragma: no cover else: - from typing_extensions import override # type: ignore # pragma: no cover + from typing_extensions import override # pragma: no cover logger = logging.getLogger(__name__) diff --git a/python/packages/core/agent_framework/exceptions.py b/python/packages/core/agent_framework/exceptions.py index d5a54c21021..b96aab6b916 100644 --- a/python/packages/core/agent_framework/exceptions.py +++ b/python/packages/core/agent_framework/exceptions.py @@ -33,9 +33,9 @@ def __init__( if log_level is not None: logger.log(log_level, message, exc_info=inner_exception) if inner_exception: - super().__init__(message, inner_exception, *args) # type: ignore + super().__init__(message, inner_exception, *args) else: - super().__init__(message, *args) # type: ignore + super().__init__(message, *args) # region Agent Exceptions diff --git a/python/packages/core/agent_framework/foundry/__init__.pyi b/python/packages/core/agent_framework/foundry/__init__.pyi index 08a7fc1b88e..87972e69fcf 100644 --- a/python/packages/core/agent_framework/foundry/__init__.pyi +++ b/python/packages/core/agent_framework/foundry/__init__.pyi @@ -4,12 +4,12 @@ # Install the relevant packages for full type support. from agent_framework_anthropic import AnthropicFoundryClient, RawAnthropicFoundryClient -from agent_framework_azure_contentunderstanding import ( # pyright: ignore[reportMissingImports] - AnalysisSection, # pyright: ignore[reportUnknownVariableType] - ContentUnderstandingContextProvider, # pyright: ignore[reportUnknownVariableType] - DocumentStatus, # pyright: ignore[reportUnknownVariableType] - FileSearchBackend, # pyright: ignore[reportUnknownVariableType] - FileSearchConfig, # pyright: ignore[reportUnknownVariableType] +from agent_framework_azure_contentunderstanding import ( + AnalysisSection, + ContentUnderstandingContextProvider, + DocumentStatus, + FileSearchBackend, + FileSearchConfig, ) from agent_framework_foundry import ( FoundryAgent, diff --git a/python/packages/core/agent_framework/observability.py b/python/packages/core/agent_framework/observability.py index a36b1f6aae2..2def30902bb 100644 --- a/python/packages/core/agent_framework/observability.py +++ b/python/packages/core/agent_framework/observability.py @@ -35,9 +35,9 @@ from ._settings import load_settings if sys.version_info >= (3, 13): - from typing import TypeVar # type: ignore # pragma: no cover + from typing import TypeVar # pragma: no cover else: - from typing_extensions import TypeVar # type: ignore # pragma: no cover + from typing_extensions import TypeVar # pragma: no cover if TYPE_CHECKING: # pragma: no cover from opentelemetry.sdk._logs.export import LogRecordExporter @@ -391,14 +391,14 @@ def _create_otlp_exporters( if protocol == "grpc": # Import all gRPC exporters try: - from opentelemetry.exporter.otlp.proto.grpc._log_exporter import ( # type: ignore[reportMissingImports] - OTLPLogExporter as GRPCLogExporter, # type: ignore[reportUnknownVariableType] + from opentelemetry.exporter.otlp.proto.grpc._log_exporter import ( + OTLPLogExporter as GRPCLogExporter, ) - from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import ( # type: ignore[reportMissingImports] - OTLPMetricExporter as GRPCMetricExporter, # type: ignore[reportUnknownVariableType] + from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import ( + OTLPMetricExporter as GRPCMetricExporter, ) - from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import ( # type: ignore[reportMissingImports] - OTLPSpanExporter as GRPCSpanExporter, # type: ignore[reportUnknownVariableType] + from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import ( + OTLPSpanExporter as GRPCSpanExporter, ) except ImportError as exc: raise ImportError( @@ -408,21 +408,21 @@ def _create_otlp_exporters( if actual_logs_endpoint: exporters.append( - GRPCLogExporter( # type: ignore[reportUnknownArgumentType] + GRPCLogExporter( endpoint=actual_logs_endpoint, headers=actual_logs_headers if actual_logs_headers else None, ) ) if actual_traces_endpoint: exporters.append( - GRPCSpanExporter( # type: ignore[reportUnknownArgumentType] + GRPCSpanExporter( endpoint=actual_traces_endpoint, headers=actual_traces_headers if actual_traces_headers else None, ) ) if actual_metrics_endpoint: exporters.append( - GRPCMetricExporter( # type: ignore[reportUnknownArgumentType] + GRPCMetricExporter( endpoint=actual_metrics_endpoint, headers=actual_metrics_headers if actual_metrics_headers else None, ) @@ -1616,7 +1616,7 @@ async def _get_response() -> ChatResponse: finish_reason=finish_reason, output=True, ) - return response # type: ignore[return-value,no-any-return] + return response return _get_response() @@ -1688,7 +1688,7 @@ async def get_embeddings( operation_duration_histogram=self.duration_histogram, duration=duration, ) - return result # type: ignore[no-any-return] + return result class AgentTelemetryLayer: @@ -1774,7 +1774,7 @@ def _record_duration() -> None: if isinstance(run_result, ResponseStream): result_stream: ResponseStream[AgentResponseUpdate, AgentResponse[Any]] = run_result # pyright: ignore[reportUnknownVariableType] elif isinstance(run_result, Awaitable): - result_stream = ResponseStream.from_awaitable(run_result) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] + result_stream = ResponseStream.from_awaitable(run_result) # type: ignore[arg-type] else: raise RuntimeError("Streaming telemetry requires a ResponseStream result.") except Exception as exception: @@ -1873,7 +1873,7 @@ async def _run() -> AgentResponse[Any]: messages=response.messages, output=True, ) - return response # type: ignore[return-value,no-any-return] + return response finally: INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS.reset(inner_response_telemetry_captured_fields_token) INNER_ACCUMULATED_USAGE.reset(inner_accumulated_usage_token) diff --git a/python/packages/core/agent_framework/security.py b/python/packages/core/agent_framework/security.py index 6d3b1d0d599..f9e7a3b1e5c 100644 --- a/python/packages/core/agent_framework/security.py +++ b/python/packages/core/agent_framework/security.py @@ -536,7 +536,7 @@ def __init__( if isinstance(content, str): contents = [content] elif isinstance(content, list): - contents = cast(list[Any], content) # type: ignore[redundant-cast] + contents = cast(list[Any], content) else: contents = [str(content)] if content is not None else [] @@ -724,7 +724,7 @@ def parse_confidentiality_from_readers(conf_value: Any) -> ConfidentialityLabel: - ["user_id_1", "user_id_2", ...] means only those users → PRIVATE """ if isinstance(conf_value, list): - conf_candidates = cast(list[Any], conf_value) # type: ignore[redundant-cast] + conf_candidates = cast(list[Any], conf_value) conf_list: list[str] = [item for item in conf_candidates if isinstance(item, str)] if len(conf_list) == 1 and conf_list[0].lower() == "public": return ConfidentialityLabel.PUBLIC @@ -975,7 +975,7 @@ def _extract_labels_recursive(value: Any) -> None: for v in value_dict.values(): _extract_labels_recursive(v) elif isinstance(value, (list, tuple)): - value_items = cast(list[Any] | tuple[Any, ...], value) # type: ignore[redundant-cast] + value_items = cast(list[Any] | tuple[Any, ...], value) # Recurse into list/tuple items for item in value_items: _extract_labels_recursive(item) @@ -1034,7 +1034,7 @@ def _ensure_content_list(result: Any) -> list[Content]: import json as _json if isinstance(result, list): - result_list = cast(list[Any], result) # type: ignore[redundant-cast] + result_list = cast(list[Any], result) if all(isinstance(c, Content) for c in result_list): return cast(list[Content], result_list) if isinstance(result, Content): diff --git a/python/packages/core/pyproject.toml b/python/packages/core/pyproject.toml index bc03bb02dda..a20dc75f5b6 100644 --- a/python/packages/core/pyproject.toml +++ b/python/packages/core/pyproject.toml @@ -90,7 +90,7 @@ extend = "../../pyproject.toml" [tool.pyright] extends = "../../pyproject.toml" -include = ["agent_framework", "tests/workflow"] +include = ["agent_framework"] [tool.mypy] plugins = ['pydantic.mypy'] diff --git a/python/packages/core/tests/__init__.py b/python/packages/core/tests/__init__.py index e69de29bb2d..2a50eae8941 100644 --- a/python/packages/core/tests/__init__.py +++ b/python/packages/core/tests/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Microsoft. All rights reserved. diff --git a/python/packages/core/tests/conftest.py b/python/packages/core/tests/conftest.py index b4900808f42..1627da85cd2 100644 --- a/python/packages/core/tests/conftest.py +++ b/python/packages/core/tests/conftest.py @@ -78,7 +78,7 @@ def span_exporter(monkeypatch, enable_instrumentation: bool, enable_sensitive_da ): exporter = InMemorySpanExporter() if enable_instrumentation or enable_sensitive_data: - tracer_provider = trace.get_tracer_provider() + tracer_provider = trace.get_tracer_provider() # type: ignore[assignment] if not hasattr(tracer_provider, "add_span_processor"): raise RuntimeError("Tracer provider does not support adding span processors.") diff --git a/python/packages/core/tests/core/__init__.py b/python/packages/core/tests/core/__init__.py index e69de29bb2d..2a50eae8941 100644 --- a/python/packages/core/tests/core/__init__.py +++ b/python/packages/core/tests/core/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Microsoft. All rights reserved. diff --git a/python/packages/core/tests/core/conftest.py b/python/packages/core/tests/core/conftest.py index 3841db32c31..d67b9ef223c 100644 --- a/python/packages/core/tests/core/conftest.py +++ b/python/packages/core/tests/core/conftest.py @@ -6,6 +6,7 @@ import warnings from collections.abc import AsyncIterable, Awaitable, MutableSequence, Sequence from typing import Any, Generic +from typing import TypedDict as TypedDict # noqa: F401 # pydantic mypy plugin needs TypedDict in module scope from unittest.mock import patch from uuid import uuid4 @@ -148,12 +149,12 @@ def __init__(self, **kwargs: Any): self.call_count: int = 0 @override - def _inner_get_response( + def _inner_get_response( # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] self, *, - messages: MutableSequence[Message], + messages: MutableSequence[Message], # type: ignore[override] stream: bool, - options: dict[str, Any], + options: dict[str, Any], # type: ignore[override] **kwargs: Any, ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]: """Send a chat request to the AI service. @@ -271,19 +272,19 @@ class MockAgentSession(AgentSession): # Mock Agent implementation for testing class MockAgent(SupportsAgentRun): @property - def id(self) -> str: + def id(self) -> str: # type: ignore[override] # pyrefly: ignore[bad-override] return str(uuid4()) @property - def name(self) -> str | None: + def name(self) -> str | None: # type: ignore[override] # pyrefly: ignore[bad-override] """Returns the name of the agent.""" return "Name" @property - def description(self) -> str | None: + def description(self) -> str | None: # type: ignore[override] # pyrefly: ignore[bad-override] return "Description" - def run( + def run( # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] self, messages: str | Message | list[str] | list[Message] | None = None, *, @@ -315,7 +316,7 @@ async def _run_stream_impl( logger.debug(f"Running mock agent stream, with: {messages=}, {session=}, {kwargs=}") yield AgentResponseUpdate(contents=[Content.from_text("Response")]) - def create_session(self) -> AgentSession: + def create_session(self) -> AgentSession: # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] return MockAgentSession() @@ -326,4 +327,4 @@ def agent_session() -> AgentSession: @fixture def agent() -> SupportsAgentRun: - return MockAgent() + return MockAgent() # type: ignore[abstract] # pyrefly: ignore[bad-instantiation] diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py index 9cbd3916376..4de14fa8077 100644 --- a/python/packages/core/tests/core/test_agents.py +++ b/python/packages/core/tests/core/test_agents.py @@ -74,7 +74,7 @@ def __init__(self, name: str, function_names: list[str], *, tool_name_prefix: st ) ) - def get_mcp_client(self) -> contextlib.AbstractAsyncContextManager[Any]: + def get_mcp_client(self) -> contextlib.AbstractAsyncContextManager[Any]: # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] raise NotImplementedError @@ -169,7 +169,8 @@ def test_chat_client_agent_uses_client_model_attribute(chat_client_base) -> None def test_chat_client_agent_prefers_default_model_over_client_model(chat_client_base) -> None: chat_client_base.model = "legacy-model" # type: ignore[attr-defined] - agent = Agent(client=chat_client_base, default_options={"model": "claude-model"}) + default_options: ChatOptions = {"model": "claude-model"} + agent = Agent(client=chat_client_base, default_options=default_options) assert agent.default_options["model"] == "claude-model" assert "model_id" not in agent.default_options @@ -221,7 +222,7 @@ async def test_chat_client_agent_init_with_name( def test_agent_init_rejects_direct_additional_properties(client: SupportsChatGetResponse) -> None: with pytest.raises(TypeError): - Agent(client=client, legacy_key="legacy-value") + Agent(client=client, legacy_key="legacy-value") # type: ignore[call-arg] # pyrefly: ignore[unexpected-keyword] # ty: ignore[unknown-argument] async def test_chat_client_agent_run(client: SupportsChatGetResponse) -> None: @@ -250,7 +251,7 @@ class Greeting(BaseModel): greeting: str json_text = '{"greeting": "Hello"}' - client.streaming_responses.append( # type: ignore[attr-defined] + client.streaming_responses.append( # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] [ ChatResponseUpdate( contents=[Content.from_text(json_text)], @@ -260,7 +261,7 @@ class Greeting(BaseModel): ] ) - agent = Agent(client=client, default_options={"response_format": Greeting}) + agent = Agent(client=client, default_options={"response_format": Greeting}) # type: ignore[arg-type, typeddict-item] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] stream = agent.run("Hello", stream=True) async for _ in stream: pass @@ -282,7 +283,7 @@ class Greeting(BaseModel): greeting: str json_text = '{"greeting": "Hi"}' - client.streaming_responses.append( # type: ignore[attr-defined] + client.streaming_responses.append( # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] [ ChatResponseUpdate( contents=[Content.from_text(json_text)], @@ -309,11 +310,11 @@ async def test_chat_client_agent_response_format_dict_from_default_options( ) -> None: """AgentResponse.value should parse JSON dicts from default_options response_format.""" json_text = json.dumps({"greeting": "Hello"}) - client.responses.append(ChatResponse(messages=Message(role="assistant", contents=[json_text]))) # type: ignore[attr-defined] + client.responses.append(ChatResponse(messages=Message(role="assistant", contents=[json_text]))) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] agent = Agent( - client=client, - default_options={"response_format": {"type": "object", "properties": {"greeting": {"type": "string"}}}}, + client=client, # ty: ignore[invalid-argument-type] + default_options={"response_format": {"type": "object", "properties": {"greeting": {"type": "string"}}}}, # pyrefly: ignore[bad-argument-type] ) result = await agent.run("Hello") @@ -328,7 +329,7 @@ async def test_chat_client_agent_streaming_response_format_dict_from_run_options ) -> None: """Agent streaming should preserve mapping response_format and parse the final value as a dict.""" json_text = json.dumps({"greeting": "Hi"}) - client.streaming_responses.append( # type: ignore[attr-defined] + client.streaming_responses.append( # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] [ ChatResponseUpdate( contents=[Content.from_text(json_text)], @@ -373,7 +374,7 @@ async def test_chat_client_agent_prepare_session_and_messages( session = AgentSession() session.state[InMemoryHistoryProvider.DEFAULT_SOURCE_ID] = {"messages": [message]} - session_context, _ = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + session_context, _ = await agent._prepare_session_and_messages( # pyright: ignore[reportPrivateUsage] session=session, input_messages=[Message(role="user", contents=["Test"])], ) @@ -394,7 +395,7 @@ async def test_prepare_session_does_not_mutate_agent_chat_options( base_tools = agent.default_options["tools"] session = agent.create_session() - _, prepared_chat_options = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + _, prepared_chat_options = await agent._prepare_session_and_messages( # pyright: ignore[reportPrivateUsage] session=session, input_messages=[Message(role="user", contents=["Test"])], ) @@ -412,7 +413,7 @@ async def test_prepare_run_context_handles_function_kwargs( agent = Agent(client=chat_client_base) session = agent.create_session() - ctx = await agent._prepare_run_context( # type: ignore[reportPrivateUsage] + ctx = await agent._prepare_run_context( # pyright: ignore[reportPrivateUsage] messages="Hello", session=session, tools=None, @@ -451,7 +452,7 @@ def lookup_weather(location: str) -> str: Message(role="assistant", contents=["Earlier answer"]), ] } - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -484,7 +485,7 @@ def lookup_weather(location: str) -> str: assert result.text == "It is sunny in Seattle." assert result.response_id is None - assert chat_client_base.call_count == 2 + assert chat_client_base.call_count == 2 # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert provider_state["get_call_count"] == 2 assert provider_state["save_call_count"] == 2 assert stored_messages[-1].text == "It is sunny in Seattle." @@ -507,7 +508,7 @@ def lookup_weather(location: str) -> str: Message(role="assistant", contents=["Earlier answer"]), ] } - chat_client_base.streaming_responses = [ + chat_client_base.streaming_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] [ ChatResponseUpdate( contents=[ @@ -549,7 +550,7 @@ def lookup_weather(location: str) -> str: assert result.text == "It is sunny in Seattle." assert result.response_id is None - assert chat_client_base.call_count == 2 + assert chat_client_base.call_count == 2 # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert provider_state["get_call_count"] == 2 assert provider_state["save_call_count"] == 2 assert stored_messages[-1].text == "It is sunny in Seattle." @@ -567,7 +568,7 @@ def lookup_weather(location: str) -> str: session = AgentSession() session.state[provider.source_id] = {"messages": []} - chat_client_base.streaming_responses = [ + chat_client_base.streaming_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] [ ChatResponseUpdate( contents=[ @@ -620,11 +621,11 @@ async def test_per_service_call_persistence_uses_real_service_storage_when_clien def lookup_weather(location: str) -> str: return f"Weather in {location}: sunny" - chat_client_base.STORES_BY_DEFAULT = True # type: ignore[attr-defined] + chat_client_base.STORES_BY_DEFAULT = True # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] session = AgentSession() session.state[provider.source_id] = {"messages": []} - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -660,7 +661,7 @@ def lookup_weather(location: str) -> str: assert result.text == "It is sunny in Seattle." assert result.response_id == "resp_call_2" - assert chat_client_base.call_count == 2 + assert chat_client_base.call_count == 2 # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] # The service owns the conversation, so the provider never loads (issue #5798). assert "get_call_count" not in provider_state # Persistence is owned by the per-service-call middleware: it persists once per service call @@ -681,7 +682,7 @@ async def test_service_storage_updates_session_handle_per_service_call_before_no def lookup_weather(location: str) -> str: return f"Weather in {location}: sunny" - chat_client_base.STORES_BY_DEFAULT = True # type: ignore[attr-defined] + chat_client_base.STORES_BY_DEFAULT = True # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] session = AgentSession() session.state[provider.source_id] = {"messages": []} @@ -729,7 +730,7 @@ async def test_service_storage_updates_session_handle_per_service_call_before_st def lookup_weather(location: str) -> str: return f"Weather in {location}: sunny" - chat_client_base.STORES_BY_DEFAULT = True # type: ignore[attr-defined] + chat_client_base.STORES_BY_DEFAULT = True # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] session = AgentSession() session.state[provider.source_id] = {"messages": []} @@ -788,7 +789,7 @@ def _finalize_first_stream(_updates: Sequence[ChatResponseUpdate]) -> ChatRespon async def test_chat_agent_without_per_service_call_persistence_preserves_response_id( chat_client_base: SupportsChatGetResponse, ) -> None: - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message(role="assistant", contents=["Hello"]), response_id="resp_call_1", @@ -809,10 +810,10 @@ async def test_per_service_call_persistence_rejects_real_service_conversation_id chat_client_base: SupportsChatGetResponse, ) -> None: provider = _RecordingHistoryProvider() - chat_client_base.STORES_BY_DEFAULT = True # type: ignore[attr-defined] + chat_client_base.STORES_BY_DEFAULT = True # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] session = AgentSession() session.state[provider.source_id] = {"messages": []} - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message(role="assistant", contents=["Hello"]), conversation_id="resp_service_managed", @@ -857,7 +858,7 @@ async def test_chat_client_agent_run_with_session(chat_client_base: SupportsChat messages=[Message(role="assistant", contents=[Content.from_text("test response")])], conversation_id="123", ) - chat_client_base.run_responses = [mock_response] + chat_client_base.run_responses = [mock_response] # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] agent = Agent( client=chat_client_base, tools={"type": "code_interpreter"}, @@ -873,7 +874,7 @@ async def test_chat_client_agent_run_with_session(chat_client_base: SupportsChat async def test_chat_client_agent_updates_existing_session_id_non_streaming( chat_client_base: SupportsChatGetResponse, ) -> None: - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=[Message(role="assistant", contents=[Content.from_text("test response")])], conversation_id="resp_new_123", @@ -890,7 +891,7 @@ async def test_chat_client_agent_updates_existing_session_id_non_streaming( async def test_chat_client_agent_update_session_id_streaming_uses_conversation_id( chat_client_base: SupportsChatGetResponse, ) -> None: - chat_client_base.streaming_responses = [ + chat_client_base.streaming_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] [ ChatResponseUpdate( contents=[Content.from_text("stream part 1")], @@ -922,7 +923,7 @@ async def test_chat_client_agent_update_session_id_streaming_uses_conversation_i async def test_chat_client_agent_updates_existing_session_id_streaming( chat_client_base: SupportsChatGetResponse, ) -> None: - chat_client_base.streaming_responses = [ + chat_client_base.streaming_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] [ ChatResponseUpdate( contents=[Content.from_text("stream part 1")], @@ -953,7 +954,7 @@ async def test_chat_client_agent_updates_existing_session_id_streaming( async def test_chat_client_agent_update_session_id_streaming_does_not_use_response_id( chat_client_base: SupportsChatGetResponse, ) -> None: - chat_client_base.streaming_responses = [ + chat_client_base.streaming_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] [ ChatResponseUpdate( contents=[Content.from_text("stream response without conversation id")], @@ -985,7 +986,7 @@ async def test_chat_client_agent_streaming_session_id_set_without_get_final_resp user iterates the stream and then makes a follow-up call without calling get_final_response(). """ - chat_client_base.streaming_responses = [ + chat_client_base.streaming_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] [ ChatResponseUpdate( contents=[Content.from_text("part 1")], @@ -1024,7 +1025,7 @@ async def test_chat_client_agent_streaming_session_history_saved_without_get_fin """ from agent_framework._sessions import InMemoryHistoryProvider - chat_client_base.streaming_responses = [ + chat_client_base.streaming_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] [ ChatResponseUpdate( contents=[Content.from_text("Hello Alice!")], @@ -1104,7 +1105,7 @@ async def test_chat_client_agent_author_name_as_agent_name( async def test_chat_client_agent_author_name_is_used_from_response( chat_client_base: SupportsChatGetResponse, ) -> None: - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=[ Message( @@ -1164,7 +1165,7 @@ async def test_chat_agent_context_providers_after_run( ) -> None: """Test that context providers' after_run is called during agent run.""" mock_provider = MockContextProvider() - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=[Message(role="assistant", contents=[Content.from_text("test response")])], conversation_id="test-thread-id", @@ -1206,7 +1207,7 @@ async def test_chat_agent_context_instructions_in_messages( ) # We need to test the _prepare_session_and_messages method directly - session_context, _ = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + session_context, _ = await agent._prepare_session_and_messages( # pyright: ignore[reportPrivateUsage] session=None, input_messages=[Message(role="user", contents=["Hello"])] ) messages = session_context.get_messages(include_input=True) @@ -1231,7 +1232,7 @@ async def test_chat_agent_no_context_instructions( context_providers=[mock_provider], ) - session_context, _ = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + session_context, _ = await agent._prepare_session_and_messages( # pyright: ignore[reportPrivateUsage] session=None, input_messages=[Message(role="user", contents=["Hello"])] ) messages = session_context.get_messages(include_input=True) @@ -1267,7 +1268,7 @@ async def test_chat_agent_context_providers_with_service_session_id( ) -> None: """Test context providers with service-managed session.""" mock_provider = MockContextProvider() - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=[Message(role="assistant", contents=[Content.from_text("test response")])], conversation_id="service-thread-123", @@ -1469,7 +1470,7 @@ def capturing_run(*args: Any, **kwargs: Any) -> Any: captured_session = kwargs.get("session") return original_run(*args, **kwargs) - agent.run = capturing_run # type: ignore[assignment, method-assign] + agent.run = capturing_run # type: ignore[assignment, method-assign] # ty: ignore[invalid-assignment] await tool.invoke( context=FunctionInvocationContext( @@ -1480,6 +1481,7 @@ def capturing_run(*args: Any, **kwargs: Any) -> Any: ) assert captured_session is parent_session + assert captured_session is not None assert captured_session.session_id == "parent-session-123" assert captured_session.state["shared_key"] == "shared_value" @@ -1499,7 +1501,7 @@ def capturing_run(*args: Any, **kwargs: Any) -> Any: captured_session = kwargs.get("session") return original_run(*args, **kwargs) - agent.run = capturing_run # type: ignore[assignment, method-assign] + agent.run = capturing_run # type: ignore[assignment, method-assign] # ty: ignore[invalid-assignment] await tool.invoke( context=FunctionInvocationContext( @@ -1530,7 +1532,7 @@ def capturing_run(*args: Any, **kwargs: Any) -> Any: captured_session.state["counter"] += 1 return original_run(*args, **kwargs) - agent.run = capturing_run # type: ignore[assignment, method-assign] + agent.run = capturing_run # type: ignore[assignment, method-assign] # ty: ignore[invalid-assignment] await tool.invoke( context=FunctionInvocationContext( @@ -1834,7 +1836,7 @@ async def test_chat_agent_tool_choice_agent_level_used_when_run_level_not_specif async def capturing_inner( *, messages: MutableSequence[Message], options: dict[str, Any], **kwargs: Any ) -> ChatResponse: - captured_options.append(options) + captured_options.append(options) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] return await original_inner(messages=messages, options=options, **kwargs) chat_client_base._inner_get_response = capturing_inner @@ -1865,7 +1867,7 @@ async def test_chat_agent_tool_choice_none_at_run_preserves_agent_level(chat_cli async def capturing_inner( *, messages: MutableSequence[Message], options: dict[str, Any], **kwargs: Any ) -> ChatResponse: - captured_options.append(options) + captured_options.append(options) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] return await original_inner(messages=messages, options=options, **kwargs) chat_client_base._inner_get_response = capturing_inner @@ -1878,7 +1880,7 @@ async def capturing_inner( ) # Run with explicitly passing None (same as not specifying) - await agent.run("Hello", options={"tool_choice": None}) + await agent.run("Hello", options={"tool_choice": None}) # ty: ignore[no-matching-overload] # type: ignore[typeddict-item] # Verify the client received tool_choice="auto" from agent-level assert len(captured_options) >= 1 @@ -2261,8 +2263,8 @@ def test_sanitize_agent_name_replaces_invalid_chars(): """Test _sanitize_agent_name replaces invalid characters.""" result = _sanitize_agent_name("Agent Name!") # Should replace spaces and special chars with underscores - assert " " not in result - assert "!" not in result + assert " " not in result # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] + assert "!" not in result # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] # endregion @@ -2317,7 +2319,7 @@ async def test_agent_get_session_with_service_session_id( def test_agent_session_from_dict(chat_client_base: SupportsChatGetResponse, tool_tool: FunctionTool): """Test AgentSession.from_dict restores a session from serialized state.""" # Create serialized session state - serialized_state = { + serialized_state = { # type: ignore[var-annotated] "type": "session", "session_id": "test-session", "service_session_id": None, @@ -2377,7 +2379,7 @@ async def before_run(self, *, agent, session, context, state): assert agent.default_options.get("tools") == [] # Run the agent and verify context tools are added - _, options = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + _, options = await agent._prepare_session_and_messages( # pyright: ignore[reportPrivateUsage] session=None, input_messages=[Message(role="user", contents=["Hello"])] ) @@ -2406,7 +2408,7 @@ async def before_run(self, *, agent, session, context, state): assert agent.default_options.get("instructions") is None # Run the agent and verify context instructions are available - _, options = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + _, options = await agent._prepare_session_and_messages( # pyright: ignore[reportPrivateUsage] session=None, input_messages=[Message(role="user", contents=["Hello"])] ) @@ -2432,7 +2434,7 @@ async def before_run(self, *, agent, session, context, state) -> None: agent = Agent(client=chat_client_base, context_providers=[MiddlewareContextProvider()]) - session_context, _ = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + session_context, _ = await agent._prepare_session_and_messages( # pyright: ignore[reportPrivateUsage] session=None, input_messages=[Message(role="user", contents=["Hello"])], ) @@ -2451,7 +2453,7 @@ async def test_stores_by_default_skips_inmemory_injection( from agent_framework._sessions import InMemoryHistoryProvider # Simulate a client that stores by default - client.STORES_BY_DEFAULT = True # type: ignore[attr-defined] + client.STORES_BY_DEFAULT = True # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] agent = Agent(client=client) session = agent.create_session() @@ -2483,7 +2485,7 @@ async def test_stores_by_default_with_store_false_injects_inmemory( """Client with STORES_BY_DEFAULT=True but store=False should still inject InMemoryHistoryProvider.""" from agent_framework._sessions import InMemoryHistoryProvider - client.STORES_BY_DEFAULT = True # type: ignore[attr-defined] + client.STORES_BY_DEFAULT = True # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] agent = Agent(client=client) session = agent.create_session() @@ -2519,10 +2521,10 @@ async def test_stores_by_default_with_store_false_in_default_options_injects_inm """ from agent_framework._sessions import InMemoryHistoryProvider - client.STORES_BY_DEFAULT = True # type: ignore[attr-defined] + client.STORES_BY_DEFAULT = True # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] # Set store=False at agent initialization via default_options, not at run-time - agent = Agent(client=client, default_options={"store": False}) + agent = Agent(client=client, default_options={"store": False}) # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] session = agent.create_session() # Run without any per-run options override @@ -2546,9 +2548,9 @@ def search_hotels(city: str) -> str: responses_client = OpenAIChatClient(model="test-model", api_key="test-key") responses_agent = Agent( - client=responses_client, + client=responses_client, # ty: ignore[invalid-argument-type] tools=[search_hotels], - default_options={"store": False}, + default_options={"store": False}, # pyrefly: ignore[bad-argument-type] ) session = responses_agent.create_session() @@ -2669,7 +2671,7 @@ async def test_as_tool_raises_on_user_input_request(client: SupportsChatGetRespo consent_content = Content.from_oauth_consent_request( consent_link="https://login.microsoftonline.com/consent", ) - client.streaming_responses = [ # type: ignore[attr-defined] + client.streaming_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] [ChatResponseUpdate(contents=[consent_content], role="assistant")], ] @@ -2772,7 +2774,7 @@ def __init__( **kwargs: Any, ) -> None: super().__init__(**kwargs) - self.STORES_BY_DEFAULT = stores_by_default # type: ignore[attr-defined] + self.STORES_BY_DEFAULT = stores_by_default # type: ignore[attr-defined, misc] # ty: ignore[invalid-attribute-access] self._provider = provider self._script = list(script) if script is not None else [("text", "ok")] self._echo_conversation_id = echo_conversation_id @@ -2787,8 +2789,9 @@ def _effective_store(self, options: dict[str, Any]) -> bool: return bool(store) def _next_contents(self) -> list[Content]: - turn = self._script.pop(0) if self._script else ("text", "ok") + turn: tuple[str, ...] = self._script.pop(0) if self._script else ("text", "ok") if turn[0] == "call": + assert len(turn) == 4 _, call_id, name, args = turn return [Content.from_function_call(call_id=call_id, name=name, arguments=args)] return [Content.from_text(turn[1])] diff --git a/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py b/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py index d8acc8ac1a0..178528510a4 100644 --- a/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py +++ b/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py @@ -45,7 +45,7 @@ async def capture_middleware(context: AgentContext, call_next: Callable[[], Awai # Create sub-agent with middleware sub_agent = Agent( - client=client, + client=client, # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] name="sub_agent", middleware=[capture_middleware], ) @@ -86,7 +86,7 @@ async def capture_middleware(context: AgentContext, call_next: Callable[[], Awai ] sub_agent = Agent( - client=client, + client=client, # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] name="sub_agent", middleware=[capture_middleware], ) @@ -139,14 +139,14 @@ async def capture_middleware(context: AgentContext, call_next: Callable[[], Awai # Create agent C (bottom level) agent_c = Agent( - client=client, + client=client, # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] name="agent_c", middleware=[capture_middleware], ) # Create agent B (middle level) - delegates to C agent_b = Agent( - client=client, + client=client, # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] name="agent_b", tools=[agent_c.as_tool(name="call_c")], middleware=[capture_middleware], @@ -190,7 +190,7 @@ async def capture_middleware(context: AgentContext, call_next: Callable[[], Awai ] sub_agent = Agent( - client=client, + client=client, # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] name="sub_agent", middleware=[capture_middleware], ) @@ -223,7 +223,7 @@ async def test_as_tool_empty_kwargs_still_works(self, client: MockChatClient) -> ] sub_agent = Agent( - client=client, + client=client, # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] name="sub_agent", ) @@ -252,7 +252,7 @@ async def capture_middleware(context: AgentContext, call_next: Callable[[], Awai ] sub_agent = Agent( - client=client, + client=client, # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] name="sub_agent", middleware=[capture_middleware], ) @@ -300,7 +300,7 @@ async def capture_middleware(context: AgentContext, call_next: Callable[[], Awai ] sub_agent = Agent( - client=client, + client=client, # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] name="sub_agent", middleware=[capture_middleware], ) @@ -346,7 +346,7 @@ async def capture_middleware(context: AgentContext, call_next: Callable[[], Awai ] sub_agent = Agent( - client=client, + client=client, # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] name="sub_agent", middleware=[capture_middleware], ) diff --git a/python/packages/core/tests/core/test_clients.py b/python/packages/core/tests/core/test_clients.py index ac110a4a173..c5be46c450f 100644 --- a/python/packages/core/tests/core/test_clients.py +++ b/python/packages/core/tests/core/test_clients.py @@ -53,11 +53,11 @@ def test_base_client(chat_client_base: SupportsChatGetResponse): def test_base_client_rejects_direct_additional_properties(chat_client_base: SupportsChatGetResponse) -> None: with pytest.raises(TypeError): - type(chat_client_base)(legacy_key="legacy-value") + type(chat_client_base)(legacy_key="legacy-value") # type: ignore[call-arg] # pyrefly: ignore[bad-instantiation, unexpected-keyword] # ty: ignore[unknown-argument] def test_base_client_as_agent_uses_explicit_additional_properties(chat_client_base: SupportsChatGetResponse) -> None: - agent = chat_client_base.as_agent(additional_properties={"team": "core"}) + agent = chat_client_base.as_agent(additional_properties={"team": "core"}) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert agent.additional_properties == {"team": "core"} @@ -95,10 +95,10 @@ async def test_base_client_get_response_streaming(chat_client_base: SupportsChat async def test_base_client_applies_compaction_before_non_streaming_inner_call( chat_client_base: SupportsChatGetResponse, ): - chat_client_base.function_invocation_configuration["enabled"] = False # type: ignore[attr-defined] - chat_client_base.compaction_strategy = TruncationStrategy(max_n=1, compact_to=1) # type: ignore[attr-defined] + chat_client_base.function_invocation_configuration["enabled"] = False # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + chat_client_base.compaction_strategy = TruncationStrategy(max_n=1, compact_to=1) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] captured_roles: list[list[str]] = [] - original = chat_client_base._get_non_streaming_response # type: ignore[attr-defined] + original = chat_client_base._get_non_streaming_response # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] async def _capture( *, @@ -109,7 +109,7 @@ async def _capture( captured_roles.append([message.role for message in messages]) return await original(messages=messages, options=options, **kwargs) - chat_client_base._get_non_streaming_response = _capture # type: ignore[attr-defined,method-assign] + chat_client_base._get_non_streaming_response = _capture # type: ignore[attr-defined, method-assign] # ty: ignore[unresolved-attribute] await chat_client_base.get_response([ Message(role="user", contents=["Hello"]), Message(role="assistant", contents=["Previous response"]), @@ -120,10 +120,10 @@ async def _capture( async def test_base_client_applies_compaction_before_streaming_inner_call( chat_client_base: SupportsChatGetResponse, ): - chat_client_base.function_invocation_configuration["enabled"] = False # type: ignore[attr-defined] - chat_client_base.compaction_strategy = TruncationStrategy(max_n=1, compact_to=1) # type: ignore[attr-defined] + chat_client_base.function_invocation_configuration["enabled"] = False # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + chat_client_base.compaction_strategy = TruncationStrategy(max_n=1, compact_to=1) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] captured_roles: list[list[str]] = [] - original = chat_client_base._get_streaming_response # type: ignore[attr-defined] + original = chat_client_base._get_streaming_response # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] def _capture( *, @@ -134,7 +134,7 @@ def _capture( captured_roles.append([message.role for message in messages]) return original(messages=messages, options=options, **kwargs) - chat_client_base._get_streaming_response = _capture # type: ignore[attr-defined,method-assign] + chat_client_base._get_streaming_response = _capture # type: ignore[attr-defined, method-assign] # ty: ignore[unresolved-attribute] async for _ in chat_client_base.get_response( [ Message(role="user", contents=["Hello"]), @@ -149,9 +149,9 @@ def _capture( async def test_base_client_per_call_compaction_override_applies_before_inner_call( chat_client_base: SupportsChatGetResponse, ) -> None: - chat_client_base.function_invocation_configuration["enabled"] = False # type: ignore[attr-defined] + chat_client_base.function_invocation_configuration["enabled"] = False # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] captured_roles: list[list[str]] = [] - original = chat_client_base._get_non_streaming_response # type: ignore[attr-defined] + original = chat_client_base._get_non_streaming_response # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] async def _capture( *, @@ -162,7 +162,7 @@ async def _capture( captured_roles.append([message.role for message in messages]) return await original(messages=messages, options=options, **kwargs) - chat_client_base._get_non_streaming_response = _capture # type: ignore[attr-defined,method-assign] + chat_client_base._get_non_streaming_response = _capture # type: ignore[attr-defined, method-assign] # ty: ignore[unresolved-attribute] await chat_client_base.get_response( [ Message(role="user", contents=["Hello"]), @@ -176,9 +176,9 @@ async def _capture( async def test_base_client_per_call_tokenizer_override_annotates_messages( chat_client_base: SupportsChatGetResponse, ) -> None: - chat_client_base.function_invocation_configuration["enabled"] = False # type: ignore[attr-defined] + chat_client_base.function_invocation_configuration["enabled"] = False # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] captured_token_counts: list[list[int | None]] = [] - original = chat_client_base._get_non_streaming_response # type: ignore[attr-defined] + original = chat_client_base._get_non_streaming_response # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] async def _capture( *, @@ -192,7 +192,7 @@ async def _capture( ]) return await original(messages=messages, options=options, **kwargs) - chat_client_base._get_non_streaming_response = _capture # type: ignore[attr-defined,method-assign] + chat_client_base._get_non_streaming_response = _capture # type: ignore[attr-defined, method-assign] # ty: ignore[unresolved-attribute] await chat_client_base.get_response( [ Message(role="user", contents=["Hello"]), @@ -207,9 +207,9 @@ async def _capture( async def test_base_client_per_call_tokenizer_override_without_strategy_annotates_messages( chat_client_base: SupportsChatGetResponse, ) -> None: - chat_client_base.function_invocation_configuration["enabled"] = False # type: ignore[attr-defined] + chat_client_base.function_invocation_configuration["enabled"] = False # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] captured_token_counts: list[list[int | None]] = [] - original = chat_client_base._get_non_streaming_response # type: ignore[attr-defined] + original = chat_client_base._get_non_streaming_response # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] async def _capture( *, @@ -223,7 +223,7 @@ async def _capture( ]) return await original(messages=messages, options=options, **kwargs) - chat_client_base._get_non_streaming_response = _capture # type: ignore[attr-defined,method-assign] + chat_client_base._get_non_streaming_response = _capture # type: ignore[attr-defined, method-assign] # ty: ignore[unresolved-attribute] await chat_client_base.get_response( [ Message(role="user", contents=["Hello"]), @@ -237,10 +237,10 @@ async def _capture( async def test_base_client_default_tokenizer_without_strategy_annotates_messages( chat_client_base: SupportsChatGetResponse, ) -> None: - chat_client_base.function_invocation_configuration["enabled"] = False # type: ignore[attr-defined] - chat_client_base.tokenizer = _FixedTokenizer(19) # type: ignore[attr-defined] + chat_client_base.function_invocation_configuration["enabled"] = False # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + chat_client_base.tokenizer = _FixedTokenizer(19) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] captured_token_counts: list[list[int | None]] = [] - original = chat_client_base._get_non_streaming_response # type: ignore[attr-defined] + original = chat_client_base._get_non_streaming_response # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] async def _capture( *, @@ -254,7 +254,7 @@ async def _capture( ]) return await original(messages=messages, options=options, **kwargs) - chat_client_base._get_non_streaming_response = _capture # type: ignore[attr-defined,method-assign] + chat_client_base._get_non_streaming_response = _capture # type: ignore[attr-defined, method-assign] # ty: ignore[unresolved-attribute] await chat_client_base.get_response([ Message(role="user", contents=["Hello"]), Message(role="assistant", contents=["Previous response"]), @@ -290,22 +290,22 @@ async def test_function_loop_persists_inserted_summaries_across_iterations( # originals. Across tool-loop iterations the exclusion flags persisted (shared Message # objects) but the inserted summaries were dropped (they only lived on a throwaway copy), # so older tool groups were silently lost with no summary representing them. - chat_client_base.function_invocation_configuration["enabled"] = True # type: ignore[attr-defined] - chat_client_base.function_invocation_configuration["max_iterations"] = 3 # type: ignore[attr-defined] - chat_client_base.compaction_strategy = ToolResultCompactionStrategy(keep_last_tool_call_groups=1) # type: ignore[attr-defined] + chat_client_base.function_invocation_configuration["enabled"] = True # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + chat_client_base.function_invocation_configuration["max_iterations"] = 3 # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + chat_client_base.compaction_strategy = ToolResultCompactionStrategy(keep_last_tool_call_groups=1) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] @tool(name="lookup_weather", approval_mode="never_require") def lookup_weather(location: str) -> str: return f"Weather in {location}: sunny" - chat_client_base.run_responses = [ # type: ignore[attr-defined] + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] _tool_call_response("call_1", "London"), _tool_call_response("call_2", "Paris"), _tool_call_response("call_3", "Tokyo"), ] captured_inputs: list[list[Message]] = [] - original = chat_client_base._get_non_streaming_response # type: ignore[attr-defined] + original = chat_client_base._get_non_streaming_response # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] async def _capture( *, @@ -316,7 +316,7 @@ async def _capture( captured_inputs.append(list(messages)) return await original(messages=messages, options=options, **kwargs) - chat_client_base._get_non_streaming_response = _capture # type: ignore[attr-defined,method-assign] + chat_client_base._get_non_streaming_response = _capture # type: ignore[attr-defined, method-assign] # ty: ignore[unresolved-attribute] await chat_client_base.get_response( [Message(role="user", contents=["What is the weather in London?"])], @@ -356,22 +356,22 @@ async def test_function_loop_persists_inserted_summaries_across_iterations_strea ) -> None: # Streaming counterpart of the #4991 regression test: the summary persistence fix in # ``_prepare_messages_for_model_call`` must cover the streaming tool loop too. - chat_client_base.function_invocation_configuration["enabled"] = True # type: ignore[attr-defined] - chat_client_base.function_invocation_configuration["max_iterations"] = 3 # type: ignore[attr-defined] - chat_client_base.compaction_strategy = ToolResultCompactionStrategy(keep_last_tool_call_groups=1) # type: ignore[attr-defined] + chat_client_base.function_invocation_configuration["enabled"] = True # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + chat_client_base.function_invocation_configuration["max_iterations"] = 3 # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + chat_client_base.compaction_strategy = ToolResultCompactionStrategy(keep_last_tool_call_groups=1) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] @tool(name="lookup_weather", approval_mode="never_require") def lookup_weather(location: str) -> str: return f"Weather in {location}: sunny" - chat_client_base.streaming_responses = [ # type: ignore[attr-defined] + chat_client_base.streaming_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] _tool_call_update("call_1", "London"), _tool_call_update("call_2", "Paris"), _tool_call_update("call_3", "Tokyo"), ] captured_inputs: list[list[Message]] = [] - original = chat_client_base._get_streaming_response # type: ignore[attr-defined] + original = chat_client_base._get_streaming_response # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] def _capture( *, @@ -382,7 +382,7 @@ def _capture( captured_inputs.append(list(messages)) return original(messages=messages, options=options, **kwargs) - chat_client_base._get_streaming_response = _capture # type: ignore[attr-defined,method-assign] + chat_client_base._get_streaming_response = _capture # type: ignore[attr-defined, method-assign] # ty: ignore[unresolved-attribute] stream = chat_client_base.get_response( [Message(role="user", contents=["What is the weather in London?"])], @@ -407,9 +407,9 @@ async def test_function_loop_compaction_conversation_id_mode_does_not_resend_his # In conversation-id mode the server owns prior context, so the tool loop clears # ``prepped_messages`` and only sends the latest message. Compaction must not fight that # by re-inserting summaries or re-sending earlier turns. - chat_client_base.function_invocation_configuration["enabled"] = True # type: ignore[attr-defined] - chat_client_base.function_invocation_configuration["max_iterations"] = 3 # type: ignore[attr-defined] - chat_client_base.compaction_strategy = ToolResultCompactionStrategy(keep_last_tool_call_groups=1) # type: ignore[attr-defined] + chat_client_base.function_invocation_configuration["enabled"] = True # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + chat_client_base.function_invocation_configuration["max_iterations"] = 3 # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + chat_client_base.compaction_strategy = ToolResultCompactionStrategy(keep_last_tool_call_groups=1) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] @tool(name="lookup_weather", approval_mode="never_require") def lookup_weather(location: str) -> str: @@ -420,14 +420,14 @@ def _conversation_tool_call(call_id: str, location: str) -> ChatResponse: response.conversation_id = "conv_1" return response - chat_client_base.run_responses = [ # type: ignore[attr-defined] + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] _conversation_tool_call("call_1", "London"), _conversation_tool_call("call_2", "Paris"), _conversation_tool_call("call_3", "Tokyo"), ] captured_inputs: list[list[Message]] = [] - original = chat_client_base._get_non_streaming_response # type: ignore[attr-defined] + original = chat_client_base._get_non_streaming_response # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] async def _capture( *, @@ -438,7 +438,7 @@ async def _capture( captured_inputs.append(list(messages)) return await original(messages=messages, options=options, **kwargs) - chat_client_base._get_non_streaming_response = _capture # type: ignore[attr-defined,method-assign] + chat_client_base._get_non_streaming_response = _capture # type: ignore[attr-defined, method-assign] # ty: ignore[unresolved-attribute] await chat_client_base.get_response( [Message(role="user", contents=["What is the weather in London?"])], @@ -457,10 +457,10 @@ def test_base_client_as_agent_does_not_copy_client_compaction_defaults( ) -> None: strategy = TruncationStrategy(max_n=1, compact_to=1) tokenizer = _FixedTokenizer(11) - chat_client_base.compaction_strategy = strategy # type: ignore[attr-defined] - chat_client_base.tokenizer = tokenizer # type: ignore[attr-defined] + chat_client_base.compaction_strategy = strategy # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + chat_client_base.tokenizer = tokenizer # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] - agent = chat_client_base.as_agent(name="shared-client-agent") + agent = chat_client_base.as_agent(name="shared-client-agent") # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert agent.compaction_strategy is None # type: ignore[attr-defined] assert agent.tokenizer is None # type: ignore[attr-defined] diff --git a/python/packages/core/tests/core/test_compaction.py b/python/packages/core/tests/core/test_compaction.py index 99a90c6c0d6..aa5fc79a454 100644 --- a/python/packages/core/tests/core/test_compaction.py +++ b/python/packages/core/tests/core/test_compaction.py @@ -399,7 +399,7 @@ async def test_summarization_strategy_adds_bidirectional_trace_links() -> None: Message(role="user", contents=["u3"]), Message(role="assistant", contents=["a3"]), ] - strategy = SummarizationStrategy(client=_FakeSummarizer(), target_count=2, threshold=0) + strategy = SummarizationStrategy(client=_FakeSummarizer(), target_count=2, threshold=0) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] annotate_message_groups(messages) changed = await strategy(messages) @@ -432,7 +432,7 @@ async def test_summarization_strategy_returns_false_when_summary_generation_fail Message(role="user", contents=["u3"]), Message(role="assistant", contents=["a3"]), ] - strategy = SummarizationStrategy(client=_FailingSummarizer(), target_count=2, threshold=0) + strategy = SummarizationStrategy(client=_FailingSummarizer(), target_count=2, threshold=0) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] annotate_message_groups(messages) with caplog.at_level(logging.WARNING, logger="agent_framework"): @@ -454,7 +454,7 @@ async def test_summarization_strategy_returns_false_when_summary_is_empty( Message(role="user", contents=["u3"]), Message(role="assistant", contents=["a3"]), ] - strategy = SummarizationStrategy(client=_EmptySummarizer(), target_count=2, threshold=0) + strategy = SummarizationStrategy(client=_EmptySummarizer(), target_count=2, threshold=0) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] annotate_message_groups(messages) with caplog.at_level(logging.WARNING, logger="agent_framework"): @@ -721,7 +721,7 @@ async def test_summarization_strategy_summary_has_full_annotations() -> None: Message(role="user", contents=["u3"]), Message(role="assistant", contents=["a3"]), ] - strategy = SummarizationStrategy(client=_FakeSummarizer(), target_count=2, threshold=0) + strategy = SummarizationStrategy(client=_FakeSummarizer(), target_count=2, threshold=0) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] annotate_message_groups(messages) changed = await strategy(messages) diff --git a/python/packages/core/tests/core/test_embedding_client.py b/python/packages/core/tests/core/test_embedding_client.py index 38a7b74bdc9..b2e6aead214 100644 --- a/python/packages/core/tests/core/test_embedding_client.py +++ b/python/packages/core/tests/core/test_embedding_client.py @@ -24,9 +24,9 @@ async def get_embeddings( *, options: EmbeddingGenerationOptions | None = None, ) -> GeneratedEmbeddings[list[float]]: - return GeneratedEmbeddings( + return GeneratedEmbeddings( # ty: ignore[invalid-return-type] [Embedding(vector=[0.1, 0.2, 0.3], model="mock-model") for _ in values], - usage={"prompt_tokens": len(values), "total_tokens": len(values)}, + usage={"prompt_tokens": len(values), "total_tokens": len(values)}, # type: ignore[arg-type] # ty: ignore[invalid-argument-type, invalid-key] ) @@ -52,7 +52,7 @@ async def test_base_get_embeddings_usage() -> None: client = MockEmbeddingClient() result = await client.get_embeddings(["a", "b", "c"]) assert result.usage is not None - assert result.usage["prompt_tokens"] == 3 + assert result.usage["prompt_tokens"] == 3 # type: ignore[typeddict-item] # ty: ignore[invalid-key] def test_base_additional_properties_default() -> None: @@ -67,7 +67,7 @@ def test_base_additional_properties_custom() -> None: def test_base_embedding_client_rejects_unknown_kwargs() -> None: with pytest.raises(TypeError): - MockEmbeddingClient(legacy_key="value") # type: ignore[call-arg] + MockEmbeddingClient(legacy_key="value") # type: ignore[call-arg] # ty: ignore[unknown-argument] # --- SupportsGetEmbeddings protocol tests --- diff --git a/python/packages/core/tests/core/test_embedding_types.py b/python/packages/core/tests/core/test_embedding_types.py index 95c5ede6122..2e2a23f68de 100644 --- a/python/packages/core/tests/core/test_embedding_types.py +++ b/python/packages/core/tests/core/test_embedding_types.py @@ -61,7 +61,7 @@ def test_embedding_dimensions_explicit_with_unknown_type() -> None: def test_embedding_empty_vector() -> None: - embedding = Embedding(vector=[]) + embedding = Embedding(vector=[]) # type: ignore[var-annotated] assert embedding.dimensions == 0 @@ -99,10 +99,10 @@ def test_generated_construction_with_usage() -> None: model="test-model", ) ], - usage=usage, + usage=usage, # type: ignore[arg-type] # ty: ignore[invalid-argument-type] ) assert embeddings.usage == usage - assert embeddings.usage["prompt_tokens"] == 10 + assert embeddings.usage["prompt_tokens"] == 10 # type: ignore[index, typeddict-item] # pyrefly: ignore[unsupported-operation] # ty: ignore[invalid-key, not-subscriptable] def test_generated_construction_with_additional_properties() -> None: diff --git a/python/packages/core/tests/core/test_feature_stage.py b/python/packages/core/tests/core/test_feature_stage.py index 040b32cbc1d..a9fdc70f5bc 100644 --- a/python/packages/core/tests/core/test_feature_stage.py +++ b/python/packages/core/tests/core/test_feature_stage.py @@ -4,6 +4,7 @@ import inspect import warnings +from collections.abc import Generator from enum import Enum from typing import Protocol, runtime_checkable @@ -45,7 +46,7 @@ class HelperReleaseCandidateFeature(str, Enum): @pytest.fixture(autouse=True) -def clear_feature_warning_state() -> None: +def clear_feature_warning_state() -> Generator[None]: # type: ignore[misc] # pyrefly: ignore[bad-return] # ty: ignore[invalid-return-type] _WARNED_FEATURES.clear() yield _WARNED_FEATURES.clear() @@ -60,7 +61,7 @@ def test_experimental_decorator_accepts_feature_enum() -> None: with warnings.catch_warnings(record=True) as caught: warnings.simplefilter("always") - @experimental(feature_id=AlternateExperimentalFeature.EXPERIMENTAL_FEATURE) # type: ignore[arg-type] + @experimental(feature_id=AlternateExperimentalFeature.EXPERIMENTAL_FEATURE) # type: ignore[arg-type] # ty: ignore[invalid-argument-type] def skill_function() -> None: pass @@ -72,14 +73,14 @@ def skill_function() -> None: assert len(caught) == 1 assert f"[{AlternateExperimentalFeature.EXPERIMENTAL_FEATURE.value}]" in str(caught[0].message) assert "skill_function" in str(caught[0].message) - assert skill_function.__feature_id__ == AlternateExperimentalFeature.EXPERIMENTAL_FEATURE.value + assert skill_function.__feature_id__ == AlternateExperimentalFeature.EXPERIMENTAL_FEATURE.value # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] def test_experimental_function_warns_on_call_and_not_on_definition() -> None: with warnings.catch_warnings(record=True) as caught: warnings.simplefilter("always") - @experimental(feature_id=AlternateExperimentalFeature.EXPERIMENTAL_FEATURE) # type: ignore[arg-type] + @experimental(feature_id=AlternateExperimentalFeature.EXPERIMENTAL_FEATURE) # type: ignore[arg-type] # ty: ignore[invalid-argument-type] def my_function(value: int) -> int: """Double the input. @@ -100,8 +101,8 @@ def my_function(value: int) -> int: assert len(caught) == 1 assert f"[{AlternateExperimentalFeature.EXPERIMENTAL_FEATURE.value}]" in str(caught[0].message) assert "my_function" in str(caught[0].message) - assert my_function.__feature_stage__ == "experimental" - assert my_function.__feature_id__ == AlternateExperimentalFeature.EXPERIMENTAL_FEATURE.value + assert my_function.__feature_stage__ == "experimental" # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert my_function.__feature_id__ == AlternateExperimentalFeature.EXPERIMENTAL_FEATURE.value # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert my_function.__doc__ is not None lines = my_function.__doc__.splitlines() warning_index = next(i for i, line in enumerate(lines) if line == ".. warning:: Experimental") @@ -113,7 +114,7 @@ def test_experimental_class_warns_on_instantiation_and_not_on_definition() -> No with warnings.catch_warnings(record=True) as caught: warnings.simplefilter("always") - @experimental(feature_id=AlternateExperimentalFeature.EXPERIMENTAL_FEATURE) # type: ignore[arg-type] + @experimental(feature_id=AlternateExperimentalFeature.EXPERIMENTAL_FEATURE) # type: ignore[arg-type] # ty: ignore[invalid-argument-type] class ExperimentalClass: """An experimental class. @@ -127,7 +128,7 @@ def __init__(self, value: int) -> None: assert not caught with warnings.catch_warnings(record=True) as caught: - instantiation_line = inspect.currentframe().f_lineno + 1 + instantiation_line = inspect.currentframe().f_lineno + 1 # type: ignore[union-attr] # ty: ignore[unresolved-attribute] instance = ExperimentalClass(4) second_instance = ExperimentalClass(5) @@ -138,8 +139,8 @@ def __init__(self, value: int) -> None: assert caught[0].lineno == instantiation_line assert instance.value == 4 assert second_instance.value == 5 - assert ExperimentalClass.__feature_stage__ == "experimental" - assert ExperimentalClass.__feature_id__ == AlternateExperimentalFeature.EXPERIMENTAL_FEATURE.value + assert ExperimentalClass.__feature_stage__ == "experimental" # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert ExperimentalClass.__feature_id__ == AlternateExperimentalFeature.EXPERIMENTAL_FEATURE.value # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] def test_experimental_abc_subclass_warning_points_at_user_file() -> None: @@ -151,14 +152,14 @@ def test_experimental_abc_subclass_warning_points_at_user_file() -> None: """ from abc import ABC, abstractmethod - @experimental(feature_id=AlternateExperimentalFeature.EXPERIMENTAL_FEATURE) # type: ignore[arg-type] + @experimental(feature_id=AlternateExperimentalFeature.EXPERIMENTAL_FEATURE) # type: ignore[arg-type] # ty: ignore[invalid-argument-type] class ExperimentalABC(ABC): @abstractmethod def do(self) -> int: ... with warnings.catch_warnings(record=True) as caught: warnings.simplefilter("always") - subclass_line = inspect.currentframe().f_lineno + 1 + subclass_line = inspect.currentframe().f_lineno + 1 # type: ignore[union-attr] # ty: ignore[unresolved-attribute] class Concrete(ExperimentalABC): def do(self) -> int: @@ -180,7 +181,7 @@ def test_experimental_runtime_checkable_protocol_keeps_protocol_runtime_checks() warnings.simplefilter("always") @runtime_checkable - @experimental(feature_id=AlternateExperimentalFeature.EXPERIMENTAL_FEATURE) # type: ignore[arg-type] + @experimental(feature_id=AlternateExperimentalFeature.EXPERIMENTAL_FEATURE) # type: ignore[arg-type] # ty: ignore[invalid-argument-type] class ExampleProtocol(Protocol): """A protocol used for runtime checks. @@ -206,11 +207,11 @@ def test_experimental_warning_is_emitted_once_per_feature() -> None: with warnings.catch_warnings(record=True) as caught: warnings.simplefilter("always") - @experimental(feature_id=AlternateExperimentalFeature.SHARED_FEATURE) # type: ignore[arg-type] + @experimental(feature_id=AlternateExperimentalFeature.SHARED_FEATURE) # type: ignore[arg-type] # ty: ignore[invalid-argument-type] def first() -> None: pass - @experimental(feature_id=AlternateExperimentalFeature.SHARED_FEATURE) # type: ignore[arg-type] + @experimental(feature_id=AlternateExperimentalFeature.SHARED_FEATURE) # type: ignore[arg-type] # ty: ignore[invalid-argument-type] class Second: pass @@ -260,8 +261,8 @@ def __init__(self, value: int) -> None: assert instance.value == 5 assert not caught - assert ReleaseCandidateClass.__feature_stage__ == "release_candidate" - assert ReleaseCandidateClass.__feature_id__ == HelperReleaseCandidateFeature.RC_FEATURE.value + assert ReleaseCandidateClass.__feature_stage__ == "release_candidate" # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert ReleaseCandidateClass.__feature_id__ == HelperReleaseCandidateFeature.RC_FEATURE.value # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert ReleaseCandidateClass.__doc__ is not None assert ".. note:: Release candidate" in ReleaseCandidateClass.__doc__ @@ -272,7 +273,7 @@ def test_experimental_property_warns_on_access_and_not_on_definition() -> None: class Example: @property - @experimental(feature_id=AlternateExperimentalFeature.EXPERIMENTAL_FEATURE) # type: ignore[arg-type] + @experimental(feature_id=AlternateExperimentalFeature.EXPERIMENTAL_FEATURE) # type: ignore[arg-type] # ty: ignore[invalid-argument-type] def value(self) -> int: """Return the value. @@ -302,7 +303,7 @@ def test_experimental_staticmethod_warns_when_decorator_wraps_descriptor() -> No warnings.simplefilter("always") class Example: - @experimental(feature_id=AlternateExperimentalFeature.EXPERIMENTAL_FEATURE) # type: ignore[arg-type] + @experimental(feature_id=AlternateExperimentalFeature.EXPERIMENTAL_FEATURE) # type: ignore[arg-type] # ty: ignore[invalid-argument-type] @staticmethod def value() -> int: """Return the value. @@ -321,7 +322,7 @@ def value() -> int: assert len(caught) == 1 assert f"[{AlternateExperimentalFeature.EXPERIMENTAL_FEATURE.value}]" in str(caught[0].message) assert "Example.value" in str(caught[0].message) - assert Example.value.__feature_id__ == AlternateExperimentalFeature.EXPERIMENTAL_FEATURE.value + assert Example.value.__feature_id__ == AlternateExperimentalFeature.EXPERIMENTAL_FEATURE.value # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert Example.value.__doc__ is not None lines = Example.value.__doc__.splitlines() warning_index = next(i for i, line in enumerate(lines) if line == ".. warning:: Experimental") @@ -334,7 +335,7 @@ def test_experimental_classmethod_warns_when_decorator_wraps_descriptor() -> Non warnings.simplefilter("always") class Example: - @experimental(feature_id=AlternateExperimentalFeature.EXPERIMENTAL_FEATURE) # type: ignore[arg-type] + @experimental(feature_id=AlternateExperimentalFeature.EXPERIMENTAL_FEATURE) # type: ignore[arg-type] # ty: ignore[invalid-argument-type] @classmethod def value(cls) -> int: """Return the value. @@ -353,7 +354,7 @@ def value(cls) -> int: assert len(caught) == 1 assert f"[{AlternateExperimentalFeature.EXPERIMENTAL_FEATURE.value}]" in str(caught[0].message) assert "Example.value" in str(caught[0].message) - assert Example.value.__func__.__feature_id__ == AlternateExperimentalFeature.EXPERIMENTAL_FEATURE.value + assert Example.value.__func__.__feature_id__ == AlternateExperimentalFeature.EXPERIMENTAL_FEATURE.value # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert Example.value.__doc__ is not None lines = Example.value.__doc__.splitlines() warning_index = next(i for i, line in enumerate(lines) if line == ".. warning:: Experimental") @@ -382,14 +383,14 @@ def lowercase_feature() -> None: assert len(caught) == 1 assert "[skills]" in str(caught[0].message) assert "lowercase_feature" in str(caught[0].message) - assert lowercase_feature.__feature_id__ == "skills" + assert lowercase_feature.__feature_id__ == "skills" # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] def test_experimental_decorator_allows_string_feature_id_at_runtime() -> None: with warnings.catch_warnings(record=True) as caught: warnings.simplefilter("always") - @experimental(feature_id="STRING_FEATURE") # type: ignore[arg-type] + @experimental(feature_id="STRING_FEATURE") # type: ignore[arg-type] # ty: ignore[invalid-argument-type] def skill_function() -> None: pass @@ -401,14 +402,14 @@ def skill_function() -> None: assert len(caught) == 1 assert "[STRING_FEATURE]" in str(caught[0].message) assert "skill_function" in str(caught[0].message) - assert skill_function.__feature_id__ == "STRING_FEATURE" + assert skill_function.__feature_id__ == "STRING_FEATURE" # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] def test_experimental_decorator_allows_other_enum_values_at_runtime() -> None: with warnings.catch_warnings(record=True) as caught: warnings.simplefilter("always") - @experimental(feature_id=AlternateExperimentalFeature.ALTERNATE_FEATURE) # type: ignore[arg-type] + @experimental(feature_id=AlternateExperimentalFeature.ALTERNATE_FEATURE) # type: ignore[arg-type] # ty: ignore[invalid-argument-type] def my_function() -> None: pass @@ -420,20 +421,20 @@ def my_function() -> None: assert len(caught) == 1 assert f"[{AlternateExperimentalFeature.ALTERNATE_FEATURE.value}]" in str(caught[0].message) assert "my_function" in str(caught[0].message) - assert my_function.__feature_id__ == AlternateExperimentalFeature.ALTERNATE_FEATURE.value + assert my_function.__feature_id__ == AlternateExperimentalFeature.ALTERNATE_FEATURE.value # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] def test_release_candidate_decorator_allows_string_feature_id_at_runtime() -> None: with warnings.catch_warnings(record=True) as caught: warnings.simplefilter("always") - @release_candidate(feature_id="RC_FEATURE") # type: ignore[arg-type] + @release_candidate(feature_id="RC_FEATURE") # type: ignore[arg-type] # ty: ignore[invalid-argument-type] class ReleaseCandidateClass: """A release-candidate class.""" assert not caught - assert ReleaseCandidateClass.__feature_stage__ == "release_candidate" - assert ReleaseCandidateClass.__feature_id__ == "RC_FEATURE" + assert ReleaseCandidateClass.__feature_stage__ == "release_candidate" # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert ReleaseCandidateClass.__feature_id__ == "RC_FEATURE" # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] def test_feature_id_stringifies_non_string_enum_values() -> None: @@ -457,4 +458,4 @@ def numeric_feature() -> None: assert len(caught) == 1 assert "[1]" in str(caught[0].message) assert "numeric_feature" in str(caught[0].message) - assert numeric_feature.__feature_id__ == "1" + assert numeric_feature.__feature_id__ == "1" # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] diff --git a/python/packages/core/tests/core/test_foundry_namespace.py b/python/packages/core/tests/core/test_foundry_namespace.py index 91953b82335..70db21d4066 100644 --- a/python/packages/core/tests/core/test_foundry_namespace.py +++ b/python/packages/core/tests/core/test_foundry_namespace.py @@ -21,4 +21,4 @@ def test_azure_namespace_no_longer_exposes_foundry_symbols() -> None: assert "FoundryLocalClient" not in dir(azure) with pytest.raises(AttributeError, match="Module `azure` has no attribute FoundryChatClient\\."): - _ = azure.FoundryChatClient + _ = azure.FoundryChatClient # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] diff --git a/python/packages/core/tests/core/test_function_invocation_logic.py b/python/packages/core/tests/core/test_function_invocation_logic.py index f96ca99ad53..26ea5e4230e 100644 --- a/python/packages/core/tests/core/test_function_invocation_logic.py +++ b/python/packages/core/tests/core/test_function_invocation_logic.py @@ -8,6 +8,7 @@ from agent_framework import ( Agent, + ChatOptions, ChatResponse, ChatResponseUpdate, Content, @@ -58,7 +59,7 @@ def ai_func(arg1: str) -> str: exec_counter += 1 return f"Processed {arg1}" - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -96,7 +97,7 @@ def ai_func(arg1: str) -> str: exec_counter += 1 return f"Processed {arg1}" - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -108,7 +109,7 @@ def ai_func(arg1: str) -> str: ChatResponse(messages=Message(role="assistant", contents=["done"])), ] - response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [ai_func]}) + response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [ai_func]}) # type: ignore[arg-type, call-overload, var-annotated] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] assert exec_counter == 1 assert len(response.messages) == 3 @@ -127,7 +128,7 @@ def ai_func(arg1: str) -> str: exec_counter += 1 return f"Processed {arg1}" - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -182,7 +183,7 @@ async def __call__(self, messages: list[Message]) -> bool: return changed captured_roles: list[list[str]] = [] - original = chat_client_base._get_non_streaming_response # type: ignore[attr-defined] + original = chat_client_base._get_non_streaming_response # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] async def _capture( *, @@ -193,10 +194,10 @@ async def _capture( captured_roles.append([message.role for message in messages]) return await original(messages=messages, options=options, **kwargs) - chat_client_base._get_non_streaming_response = _capture # type: ignore[attr-defined,method-assign] - chat_client_base.compaction_strategy = _ExcludeOldestGroupAfterFirstTurn() # type: ignore[attr-defined] + chat_client_base._get_non_streaming_response = _capture # type: ignore[attr-defined, method-assign] # ty: ignore[unresolved-attribute] + chat_client_base.compaction_strategy = _ExcludeOldestGroupAfterFirstTurn() # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -231,7 +232,7 @@ def ai_func(arg1: str) -> str: return f"Processed {arg1}. " + ("result " * 120) captured_token_counts: list[int] = [] - original = chat_client_base._get_non_streaming_response # type: ignore[attr-defined] + original = chat_client_base._get_non_streaming_response # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] async def _capture( *, @@ -243,15 +244,15 @@ async def _capture( captured_token_counts.append(included_token_count(messages)) return await original(messages=messages, options=options, **kwargs) - chat_client_base._get_non_streaming_response = _capture # type: ignore[attr-defined,method-assign] - chat_client_base.tokenizer = tokenizer # type: ignore[attr-defined] - chat_client_base.function_invocation_configuration["max_iterations"] = 3 # type: ignore[attr-defined] - chat_client_base.compaction_strategy = TokenBudgetComposedStrategy( # type: ignore[attr-defined] + chat_client_base._get_non_streaming_response = _capture # type: ignore[attr-defined, method-assign] # ty: ignore[unresolved-attribute] + chat_client_base.tokenizer = tokenizer # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + chat_client_base.function_invocation_configuration["max_iterations"] = 3 # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + chat_client_base.compaction_strategy = TokenBudgetComposedStrategy( # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] token_budget=token_budget, tokenizer=tokenizer, strategies=[SlidingWindowStrategy(keep_last_groups=2)], ) - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -292,7 +293,7 @@ def ai_func(arg1: str) -> str: exec_counter += 1 return f"Processed {arg1}" - chat_client_base.streaming_responses = [ + chat_client_base.streaming_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] [ ChatResponseUpdate( contents=[Content.from_function_call(call_id="1", name="test_function", arguments='{"arg1":')], @@ -311,8 +312,10 @@ def ai_func(arg1: str) -> str: ], ] updates = [] - async for update in chat_client_base.get_response( - "hello", options={"tool_choice": "auto", "tools": [ai_func]}, stream=True + async for update in chat_client_base.get_response( # type: ignore[call-overload] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] + "hello", # type: ignore[arg-type] + options={"tool_choice": "auto", "tools": [ai_func]}, + stream=True, # type: ignore[arg-type] ): updates.append(update) assert len(updates) == 4 # two updates with the function call, the function result and the final text @@ -334,7 +337,7 @@ def ai_func(arg1: str) -> str: exec_counter += 1 return f"Processed {arg1}" - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=[ Message( @@ -391,7 +394,7 @@ def ai_func(user_query: str) -> str: exec_counter += 1 return f"Investigated {user_query}" - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -422,7 +425,7 @@ async def handler(request: web.Request) -> web.Response: site = web.TCPSite(runner, "127.0.0.1", 0) await site.start() try: - port = site._server.sockets[0].getsockname()[1] + port = site._server.sockets[0].getsockname()[1] # type: ignore[union-attr] # ty: ignore[unresolved-attribute] async with aiohttp.ClientSession() as session, session.post(f"http://127.0.0.1:{port}/run") as response: assert response.status == 200 await response.text() @@ -448,7 +451,7 @@ def ai_func(user_query: str) -> str: exec_counter += 1 return f"Threaded {user_query}" - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -492,7 +495,7 @@ async def runner_main() -> None: await site.start() shutdown_event = asyncio.Event() shutdown_queue.put((loop, shutdown_event)) - port = site._server.sockets[0].getsockname()[1] + port = site._server.sockets[0].getsockname()[1] # type: ignore[union-attr] # ty: ignore[unresolved-attribute] port_queue.put(port) ready_event.set() try: @@ -583,11 +586,11 @@ def func_with_approval(arg1: str) -> str: func_call = Content.from_function_call(call_id="1", name=function_name, arguments='{"arg1": "value1"}') completion = Message(role="assistant", contents=["done"]) - chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=[func_call]))] + ( + chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=[func_call]))] + ( # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] [] if approval_required else [ChatResponse(messages=completion)] ) - chat_client_base.streaming_responses = [ + chat_client_base.streaming_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] [ ChatResponseUpdate( contents=[Content.from_function_call(call_id="1", name=function_name, arguments='{"arg1":')], @@ -613,9 +616,9 @@ def func_with_approval(arg1: str) -> str: Content.from_function_call(call_id="2", name="approval_func", arguments='{"arg1": "value2"}'), ] - chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=func_calls))] + chat_client_base.run_responses = [ChatResponse(messages=Message(role="assistant", contents=func_calls))] # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] - chat_client_base.streaming_responses = [ + chat_client_base.streaming_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] [ ChatResponseUpdate(contents=[func_calls[0]], role="assistant"), ChatResponseUpdate(contents=[func_calls[1]], role="assistant"), @@ -623,18 +626,20 @@ def func_with_approval(arg1: str) -> str: ] # Execute the test - options: dict[str, Any] = {"tool_choice": "auto", "tools": tools} + options: ChatOptions = {"tool_choice": "auto", "tools": tools} if thread_type == "service": # For service threads, we need to pass conversation_id via options + assert conversation_id is not None options["store"] = True options["conversation_id"] = conversation_id + messages: list[Any] if not streaming: - response = await chat_client_base.get_response([Message(role="user", contents=["hello"])], options=options) + response = await chat_client_base.get_response([Message(role="user", contents=["hello"])], options=options) # type: ignore[call-overload] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] messages = response.messages else: updates = [] - async for update in chat_client_base.get_response( + async for update in chat_client_base.get_response( # type: ignore[call-overload] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] [Message(role="user", contents=["hello"])], options=options, stream=True ): updates.append(update) @@ -659,7 +664,9 @@ def func_with_approval(arg1: str) -> str: assert len(messages[0].contents) == 2 assert messages[0].contents[0].type == "function_call" assert messages[0].contents[1].type == "function_approval_request" - assert messages[0].contents[1].function_call.name == "approval_func" + function_call = messages[0].contents[1].function_call + assert function_call is not None + assert function_call.name == "approval_func" assert exec_counter == 0 # Function not executed yet else: # Streaming: 2 function call chunks + 1 approval request update (same assistant message) @@ -667,7 +674,9 @@ def func_with_approval(arg1: str) -> str: assert messages[0].contents[0].type == "function_call" assert messages[1].contents[0].type == "function_call" assert messages[2].contents[0].type == "function_approval_request" - assert messages[2].contents[0].function_call.name == "approval_func" + function_call = messages[2].contents[0].function_call + assert function_call is not None + assert function_call.name == "approval_func" assert exec_counter == 0 # Function not executed yet else: # Single function without approval: call + result + final @@ -731,7 +740,7 @@ def func_rejected(arg1: str) -> str: return f"Rejected {arg1}" # Setup: two function calls that require approval - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -745,8 +754,9 @@ def func_rejected(arg1: str) -> str: ] # Get the response with approval requests - response = await chat_client_base.get_response( - "hello", options={"tool_choice": "auto", "tools": [func_approved, func_rejected]} + response = await chat_client_base.get_response( # type: ignore[call-overload, var-annotated] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] + "hello", # type: ignore[arg-type] + options={"tool_choice": "auto", "tools": [func_approved, func_rejected]}, # type: ignore[arg-type] ) # Approval requests are now added to the assistant message, not a separate message assert len(response.messages) == 1 @@ -760,13 +770,13 @@ def func_rejected(arg1: str) -> str: approval_req_2 = approval_requests[1] approved_response = Content.from_function_approval_response( - id=approval_req_1.id, - function_call=approval_req_1.function_call, + id=approval_req_1.id, # type: ignore[arg-type] + function_call=approval_req_1.function_call, # type: ignore[arg-type] approved=True, ) rejected_response = Content.from_function_approval_response( - id=approval_req_2.id, - function_call=approval_req_2.function_call, + id=approval_req_2.id, # type: ignore[arg-type] + function_call=approval_req_2.function_call, # type: ignore[arg-type] approved=False, ) @@ -818,7 +828,7 @@ def func_with_approval(arg1: str) -> str: exec_counter += 1 return f"Result {arg1}" - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -829,8 +839,9 @@ def func_with_approval(arg1: str) -> str: ), ] - response = await chat_client_base.get_response( - "hello", options={"tool_choice": "auto", "tools": [func_with_approval]} + response = await chat_client_base.get_response( # type: ignore[call-overload, var-annotated] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] + "hello", # type: ignore[arg-type] + options={"tool_choice": "auto", "tools": [func_with_approval]}, # type: ignore[arg-type] ) # Should have one assistant message containing both the call and approval request @@ -853,7 +864,7 @@ def func_with_approval(arg1: str) -> str: exec_counter += 1 return f"Result {arg1}" - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -866,8 +877,9 @@ def func_with_approval(arg1: str) -> str: ] # Get approval request - response1 = await chat_client_base.get_response( - "hello", options={"tool_choice": "auto", "tools": [func_with_approval]} + response1 = await chat_client_base.get_response( # type: ignore[call-overload, var-annotated] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] + "hello", # type: ignore[arg-type] + options={"tool_choice": "auto", "tools": [func_with_approval]}, # type: ignore[arg-type] ) # Store messages (like a thread would) @@ -879,8 +891,8 @@ def func_with_approval(arg1: str) -> str: # Send approval approval_req = [c for c in response1.messages[0].contents if c.type == "function_approval_request"][0] approval_response = Content.from_function_approval_response( - id=approval_req.id, - function_call=approval_req.function_call, + id=approval_req.id, # type: ignore[arg-type] + function_call=approval_req.function_call, # type: ignore[arg-type] approved=True, ) persisted_messages.append(Message(role="user", contents=[approval_response])) @@ -902,7 +914,7 @@ async def test_no_duplicate_function_calls_after_approval_processing(chat_client def func_with_approval(arg1: str) -> str: return f"Result {arg1}" - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -914,14 +926,15 @@ def func_with_approval(arg1: str) -> str: ChatResponse(messages=Message(role="assistant", contents=["done"])), ] - response1 = await chat_client_base.get_response( - "hello", options={"tool_choice": "auto", "tools": [func_with_approval]} + response1 = await chat_client_base.get_response( # type: ignore[call-overload, var-annotated] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] + "hello", # type: ignore[arg-type] + options={"tool_choice": "auto", "tools": [func_with_approval]}, # type: ignore[arg-type] ) approval_req = [c for c in response1.messages[0].contents if c.type == "function_approval_request"][0] approval_response = Content.from_function_approval_response( - id=approval_req.id, - function_call=approval_req.function_call, + id=approval_req.id, # type: ignore[arg-type] + function_call=approval_req.function_call, # type: ignore[arg-type] approved=True, ) @@ -946,7 +959,7 @@ async def test_rejection_result_uses_function_call_id(chat_client_base: Supports def func_with_approval(arg1: str) -> str: return f"Result {arg1}" - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -958,14 +971,15 @@ def func_with_approval(arg1: str) -> str: ChatResponse(messages=Message(role="assistant", contents=["done"])), ] - response1 = await chat_client_base.get_response( - "hello", options={"tool_choice": "auto", "tools": [func_with_approval]} + response1 = await chat_client_base.get_response( # type: ignore[call-overload, var-annotated] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] + "hello", # type: ignore[arg-type] + options={"tool_choice": "auto", "tools": [func_with_approval]}, # type: ignore[arg-type] ) approval_req = [c for c in response1.messages[0].contents if c.type == "function_approval_request"][0] rejection_response = Content.from_function_approval_response( - id=approval_req.id, - function_call=approval_req.function_call, + id=approval_req.id, # type: ignore[arg-type] + function_call=approval_req.function_call, # type: ignore[arg-type] approved=False, ) @@ -994,7 +1008,7 @@ def ai_func(arg1: str) -> str: return f"Processed {arg1}" # Set up multiple function call responses to create a loop - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -1016,7 +1030,7 @@ def ai_func(arg1: str) -> str: ] # Set max_iterations to 1 in additional_properties - chat_client_base.function_invocation_configuration["max_iterations"] = 1 + chat_client_base.function_invocation_configuration["max_iterations"] = 1 # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] response = await chat_client_base.get_response( [Message(role="user", contents=["hello"])], options={"tool_choice": "auto", "tools": [ai_func]} @@ -1043,7 +1057,7 @@ def ai_func(arg1: str) -> str: return f"Processed {arg1}" # Model keeps requesting tool calls on every iteration - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -1070,7 +1084,7 @@ def ai_func(arg1: str) -> str: ), ] - chat_client_base.function_invocation_configuration["max_iterations"] = 2 + chat_client_base.function_invocation_configuration["max_iterations"] = 2 # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] response = await chat_client_base.get_response( [Message(role="user", contents=["hello"])], @@ -1105,7 +1119,7 @@ def ai_func(arg1: str) -> str: exec_counter += 1 return f"Processed {arg1}" - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -1126,7 +1140,7 @@ def ai_func(arg1: str) -> str: ChatResponse(messages=Message(role="assistant", contents=["Final answer after giving up on tools."])), ] - chat_client_base.function_invocation_configuration["max_iterations"] = 1 + chat_client_base.function_invocation_configuration["max_iterations"] = 1 # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] response = await chat_client_base.get_response( [Message(role="user", contents=["hello"])], @@ -1165,7 +1179,7 @@ def ai_func(arg1: str) -> str: return f"Result {exec_counter}" # Two iterations of function calls, then failsafe - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -1185,7 +1199,7 @@ def ai_func(arg1: str) -> str: ChatResponse(messages=Message(role="assistant", contents=["Done"])), ] - chat_client_base.function_invocation_configuration["max_iterations"] = 2 + chat_client_base.function_invocation_configuration["max_iterations"] = 2 # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] response = await chat_client_base.get_response( [Message(role="user", contents=["hello"])], @@ -1226,7 +1240,7 @@ def browser_snapshot(url: str) -> str: # The failsafe call (with tool_choice="none") after the loop is handled # automatically by the mock client, which returns a hardcoded text response # when tool_choice="none" (see conftest.py ChatClientBase.get_response). - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -1249,7 +1263,7 @@ def browser_snapshot(url: str) -> str: ), ] - chat_client_base.function_invocation_configuration["max_iterations"] = 2 + chat_client_base.function_invocation_configuration["max_iterations"] = 2 # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] agent = Agent( client=chat_client_base, @@ -1290,7 +1304,7 @@ def search_func(query: str) -> str: return f"Result for {query}" # Each iteration returns 3 parallel tool calls - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -1317,7 +1331,7 @@ def search_func(query: str) -> str: ] # Allow many iterations but cap total function calls at 5 - chat_client_base.function_invocation_configuration["max_function_calls"] = 5 + chat_client_base.function_invocation_configuration["max_function_calls"] = 5 # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] response = await chat_client_base.get_response( [Message(role="user", contents=["search"])], options={"tool_choice": "auto", "tools": [search_func]} @@ -1341,7 +1355,7 @@ def lookup_func(key: str) -> str: exec_counter += 1 return f"Value for {key}" - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -1370,7 +1384,7 @@ def lookup_func(key: str) -> str: ChatResponse(messages=Message(role="assistant", contents=["all done"])), ] - chat_client_base.function_invocation_configuration["max_function_calls"] = 2 + chat_client_base.function_invocation_configuration["max_function_calls"] = 2 # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] response = await chat_client_base.get_response( [Message(role="user", contents=["look up keys"])], options={"tool_choice": "auto", "tools": [lookup_func]} @@ -1392,7 +1406,7 @@ def do_thing_func(arg: str) -> str: exec_counter += 1 return f"Done {arg}" - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -1405,7 +1419,7 @@ def do_thing_func(arg: str) -> str: ] + [ChatResponse(messages=Message(role="assistant", contents=["finished"]))] # Explicitly set to None (default) — should not limit - chat_client_base.function_invocation_configuration["max_function_calls"] = None + chat_client_base.function_invocation_configuration["max_function_calls"] = None # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] response = await chat_client_base.get_response( [Message(role="user", contents=["do things"])], options={"tool_choice": "auto", "tools": [do_thing_func]} @@ -1425,12 +1439,12 @@ def ai_func(arg1: str) -> str: exec_counter += 1 return f"Processed {arg1}" - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse(messages=Message(role="assistant", contents=["response without function calling"])), ] # Disable function invocation - chat_client_base.function_invocation_configuration["enabled"] = False + chat_client_base.function_invocation_configuration["enabled"] = False # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] response = await chat_client_base.get_response( [Message(role="user", contents=["hello"])], options={"tool_choice": "auto", "tools": [ai_func]} @@ -1457,11 +1471,11 @@ async def capture_middleware(context, call_next): captured_kwargs.update(context.function_invocation_kwargs or {}) await call_next() - chat_client_base.chat_middleware = [capture_middleware] - chat_client_base.run_responses = [ + chat_client_base.chat_middleware = [capture_middleware] # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse(messages=Message(role="assistant", contents=["response without function calling"])), ] - chat_client_base.function_invocation_configuration["enabled"] = False + chat_client_base.function_invocation_configuration["enabled"] = False # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] await chat_client_base.get_response( [Message(role="user", contents=["hello"])], @@ -1481,7 +1495,7 @@ def error_func(arg1: str) -> str: raise ValueError("Function error") # Set up multiple function call responses that will all error - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -1518,7 +1532,7 @@ def error_func(arg1: str) -> str: ] # Set max_consecutive_errors to 2 - chat_client_base.function_invocation_configuration["max_consecutive_errors_per_request"] = 2 + chat_client_base.function_invocation_configuration["max_consecutive_errors_per_request"] = 2 # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] response = await chat_client_base.get_response( [Message(role="user", contents=["hello"])], options={"tool_choice": "auto", "tools": [error_func]} @@ -1549,7 +1563,7 @@ async def test_function_invocation_stop_clears_conversation_id_non_stream(chat_c def error_func(arg1: str) -> str: raise ValueError("Function error") - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -1560,7 +1574,7 @@ def error_func(arg1: str) -> str: conversation_id="resp_1", ) ] - chat_client_base.function_invocation_configuration["max_consecutive_errors_per_request"] = 1 + chat_client_base.function_invocation_configuration["max_consecutive_errors_per_request"] = 1 # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] session_stub = type("SessionStub", (), {"service_session_id": "resp_seed"})() response = await chat_client_base.get_response( @@ -1582,7 +1596,7 @@ def known_func(arg1: str) -> str: exec_counter += 1 return f"Processed {arg1}" - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -1595,7 +1609,7 @@ def known_func(arg1: str) -> str: ] # Set terminate_on_unknown_calls to False (default) - chat_client_base.function_invocation_configuration["terminate_on_unknown_calls"] = False + chat_client_base.function_invocation_configuration["terminate_on_unknown_calls"] = False # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] response = await chat_client_base.get_response( [Message(role="user", contents=["hello"])], options={"tool_choice": "auto", "tools": [known_func]} @@ -1619,7 +1633,7 @@ def known_func(arg1: str) -> str: exec_counter += 1 return f"Processed {arg1}" - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -1631,7 +1645,7 @@ def known_func(arg1: str) -> str: ] # Set terminate_on_unknown_calls to True - chat_client_base.function_invocation_configuration["terminate_on_unknown_calls"] = True + chat_client_base.function_invocation_configuration["terminate_on_unknown_calls"] = True # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] # Should raise an exception when encountering an unknown function with pytest.raises(KeyError, match='Error: Requested function "unknown_function" not found'): @@ -1659,7 +1673,7 @@ def hidden_func(arg1: str) -> str: exec_counter_hidden += 1 return f"Hidden {arg1}" - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -1672,7 +1686,7 @@ def hidden_func(arg1: str) -> str: ] # Add hidden_func to additional_tools - chat_client_base.function_invocation_configuration["additional_tools"] = [hidden_func] + chat_client_base.function_invocation_configuration["additional_tools"] = [hidden_func] # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] # Only pass visible_func in the tools parameter response = await chat_client_base.get_response( @@ -1700,7 +1714,7 @@ async def test_function_invocation_config_include_detailed_errors_false(chat_cli def error_func(arg1: str) -> str: raise ValueError("Specific error message that should not appear") - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -1713,7 +1727,7 @@ def error_func(arg1: str) -> str: ] # Set include_detailed_errors to False (default) - chat_client_base.function_invocation_configuration["include_detailed_errors"] = False + chat_client_base.function_invocation_configuration["include_detailed_errors"] = False # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] response = await chat_client_base.get_response( [Message(role="user", contents=["hello"])], options={"tool_choice": "auto", "tools": [error_func]} @@ -1736,7 +1750,7 @@ async def test_function_invocation_config_include_detailed_errors_true(chat_clie def error_func(arg1: str) -> str: raise ValueError("Specific error message that should appear") - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -1749,7 +1763,7 @@ def error_func(arg1: str) -> str: ] # Set include_detailed_errors to True - chat_client_base.function_invocation_configuration["include_detailed_errors"] = True + chat_client_base.function_invocation_configuration["include_detailed_errors"] = True # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] response = await chat_client_base.get_response( [Message(role="user", contents=["hello"])], options={"tool_choice": "auto", "tools": [error_func]} @@ -1835,7 +1849,7 @@ async def test_argument_validation_error_with_detailed_errors(chat_client_base: def typed_func(arg1: int) -> str: # Expects int, not str return f"Got {arg1}" - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -1848,7 +1862,7 @@ def typed_func(arg1: int) -> str: # Expects int, not str ] # Set include_detailed_errors to True - chat_client_base.function_invocation_configuration["include_detailed_errors"] = True + chat_client_base.function_invocation_configuration["include_detailed_errors"] = True # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] response = await chat_client_base.get_response( [Message(role="user", contents=["hello"])], options={"tool_choice": "auto", "tools": [typed_func]} @@ -1871,7 +1885,7 @@ async def test_argument_validation_error_without_detailed_errors(chat_client_bas def typed_func(arg1: int) -> str: # Expects int, not str return f"Got {arg1}" - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -1884,7 +1898,7 @@ def typed_func(arg1: int) -> str: # Expects int, not str ] # Set include_detailed_errors to False (default) - chat_client_base.function_invocation_configuration["include_detailed_errors"] = False + chat_client_base.function_invocation_configuration["include_detailed_errors"] = False # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] response = await chat_client_base.get_response( [Message(role="user", contents=["hello"])], options={"tool_choice": "auto", "tools": [typed_func]} @@ -1917,7 +1931,7 @@ def local_func(arg1: str) -> str: approved=True, ) - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse(messages=Message(role="assistant", contents=["done"])), ] @@ -1958,7 +1972,7 @@ def local_func(arg1: str) -> str: mcp_approval_response = mcp_approval_request.to_function_approval_response(approved=True) # The second call (after approval) should return a final response - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse(messages=Message(role="assistant", contents=["Here are the docs results."])), ] @@ -2112,7 +2126,7 @@ def test_replace_approval_contents_with_results_skips_results_without_call_id() messages, _collect_approval_responses(messages), [ - Content.from_function_result(call_id=None, result="ignored result"), + Content.from_function_result(call_id=None, result="ignored result"), # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] Content.from_function_result(call_id="call_1", result="first result"), ], ) @@ -2199,7 +2213,7 @@ def local_func(arg1: str) -> str: mcp_approval_request = Content.from_function_approval_request(id="mcpr_hosted", function_call=mcp_fc) # First response: LLM returns a local function call that needs approval - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse(messages=Message(role="assistant", contents=[local_fc])), # After local approval + hosted approval, the final response ChatResponse(messages=Message(role="assistant", contents=["Done with both tools."])), @@ -2239,7 +2253,7 @@ async def test_unapproved_tool_execution_raises_exception(chat_client_base: Supp def test_func(arg1: str) -> str: return f"Result {arg1}" - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -2260,8 +2274,8 @@ def test_func(arg1: str) -> str: # Create a rejection response (approved=False) rejection_response = Content.from_function_approval_response( - id=approval_req.id, - function_call=approval_req.function_call, + id=approval_req.id, # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] + function_call=approval_req.function_call, # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] approved=False, ) @@ -2298,7 +2312,7 @@ def error_func(arg1: str) -> str: exec_counter += 1 raise ValueError("Specific error from approved function") - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -2309,7 +2323,7 @@ def error_func(arg1: str) -> str: ] # Set include_detailed_errors to False (default) - chat_client_base.function_invocation_configuration["include_detailed_errors"] = False + chat_client_base.function_invocation_configuration["include_detailed_errors"] = False # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] # Get approval request response1 = await chat_client_base.get_response( @@ -2320,8 +2334,8 @@ def error_func(arg1: str) -> str: # Approve the function approval_response = Content.from_function_approval_response( - id=approval_req.id, - function_call=approval_req.function_call, + id=approval_req.id, # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] + function_call=approval_req.function_call, # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] approved=True, ) @@ -2363,7 +2377,7 @@ def error_func(arg1: str) -> str: exec_counter += 1 raise ValueError("Specific error from approved function") - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -2374,7 +2388,7 @@ def error_func(arg1: str) -> str: ] # Set include_detailed_errors to True - chat_client_base.function_invocation_configuration["include_detailed_errors"] = True + chat_client_base.function_invocation_configuration["include_detailed_errors"] = True # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] # Get approval request response1 = await chat_client_base.get_response( @@ -2385,8 +2399,8 @@ def error_func(arg1: str) -> str: # Approve the function approval_response = Content.from_function_approval_response( - id=approval_req.id, - function_call=approval_req.function_call, + id=approval_req.id, # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] + function_call=approval_req.function_call, # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] approved=True, ) @@ -2426,7 +2440,7 @@ def typed_func(arg1: int) -> str: # Expects int, not str exec_counter += 1 return f"Got {arg1}" - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -2439,7 +2453,7 @@ def typed_func(arg1: int) -> str: # Expects int, not str ] # Set include_detailed_errors to True to see validation details - chat_client_base.function_invocation_configuration["include_detailed_errors"] = True + chat_client_base.function_invocation_configuration["include_detailed_errors"] = True # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] # Get approval request response1 = await chat_client_base.get_response( @@ -2450,8 +2464,8 @@ def typed_func(arg1: int) -> str: # Expects int, not str # Approve the function (even though it will fail validation) approval_response = Content.from_function_approval_response( - id=approval_req.id, - function_call=approval_req.function_call, + id=approval_req.id, # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] + function_call=approval_req.function_call, # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] approved=True, ) @@ -2489,7 +2503,7 @@ def success_func(arg1: str) -> str: exec_counter += 1 return f"Success {arg1}" - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -2508,8 +2522,8 @@ def success_func(arg1: str) -> str: # Approve the function approval_response = Content.from_function_approval_response( - id=approval_req.id, - function_call=approval_req.function_call, + id=approval_req.id, # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] + function_call=approval_req.function_call, # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] approved=True, ) @@ -2550,7 +2564,7 @@ async def test_declaration_only_tool(chat_client_base: SupportsChatGetResponse): # Verify it's marked as declaration_only assert declaration_func.declaration_only is True - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -2562,8 +2576,9 @@ async def test_declaration_only_tool(chat_client_base: SupportsChatGetResponse): ChatResponse(messages=Message(role="assistant", contents=["done"])), ] - response = await chat_client_base.get_response( - "hello", options={"tool_choice": "auto", "tools": [declaration_func]} + response = await chat_client_base.get_response( # type: ignore[call-overload, var-annotated] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] + "hello", # type: ignore[arg-type] + options={"tool_choice": "auto", "tools": [declaration_func]}, # type: ignore[arg-type] ) # Should have the function call in messages but not a result @@ -2589,7 +2604,7 @@ async def test_multiple_function_calls_parallel_execution(chat_client_base: Supp """Test that multiple function calls are executed in parallel.""" import asyncio - exec_order = [] + exec_order = [] # type: ignore[var-annotated] @tool(name="func1", approval_mode="never_require") async def func1(arg1: str) -> str: @@ -2605,7 +2620,7 @@ async def func2(arg1: str) -> str: exec_order.append("func2_end") return f"Result2 {arg1}" - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -2644,7 +2659,7 @@ def plain_function(arg1: str) -> str: exec_counter += 1 return f"Plain {arg1}" - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -2675,7 +2690,7 @@ def test_func(arg1: str) -> str: return f"Result {arg1}" # Return a response with a conversation_id - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -2708,7 +2723,7 @@ async def test_function_result_appended_to_existing_assistant_message(chat_clien def test_func(arg1: str) -> str: return f"Result {arg1}" - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -2747,7 +2762,7 @@ def sometimes_fails(arg1: str) -> str: raise ValueError("First call fails") return f"Success {arg1}" - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -2804,7 +2819,7 @@ def func_with_approval(arg1: str) -> str: return f"Result {arg1}" # Setup: function call that requires approval, streamed - chat_client_base.streaming_responses = [ + chat_client_base.streaming_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] [ ChatResponseUpdate( contents=[Content.from_function_call(call_id="1", name="test_func", arguments='{"arg1": "value1"}')], @@ -2815,8 +2830,10 @@ def func_with_approval(arg1: str) -> str: # Get the streaming response with approval request updates = [] - async for update in chat_client_base.get_response( - "hello", options={"tool_choice": "auto", "tools": [func_with_approval]}, stream=True + async for update in chat_client_base.get_response( # type: ignore[call-overload] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] + "hello", # type: ignore[arg-type] + options={"tool_choice": "auto", "tools": [func_with_approval]}, + stream=True, # type: ignore[arg-type] ): updates.append(update) @@ -2825,7 +2842,7 @@ def func_with_approval(arg1: str) -> str: content for update in updates for content in update.contents if content.type == "function_approval_request" ] assert len(approval_requests) == 1 - assert approval_requests[0].function_call.name == "test_func" + assert approval_requests[0].function_call.name == "test_func" # type: ignore[union-attr] assert exec_counter == 0 # Function not executed yet due to approval requirement @@ -2840,7 +2857,7 @@ def ai_func(arg1: str) -> str: return f"Processed {arg1}" # Set up multiple function call responses to create a loop - chat_client_base.streaming_responses = [ + chat_client_base.streaming_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] [ ChatResponseUpdate( contents=[Content.from_function_call(call_id="1", name="test_function", arguments='{"arg1":')], @@ -2866,11 +2883,13 @@ def ai_func(arg1: str) -> str: ] # Set max_iterations to 1 in additional_properties - chat_client_base.function_invocation_configuration["max_iterations"] = 1 + chat_client_base.function_invocation_configuration["max_iterations"] = 1 # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] updates = [] - async for update in chat_client_base.get_response( - "hello", options={"tool_choice": "auto", "tools": [ai_func]}, stream=True + async for update in chat_client_base.get_response( # type: ignore[call-overload] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] + "hello", # type: ignore[arg-type] + options={"tool_choice": "auto", "tools": [ai_func]}, + stream=True, # type: ignore[arg-type] ): updates.append(update) @@ -2891,16 +2910,18 @@ def ai_func(arg1: str) -> str: exec_counter += 1 return f"Processed {arg1}" - chat_client_base.streaming_responses = [ + chat_client_base.streaming_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] [ChatResponseUpdate(contents=[Content.from_text(text="response without function calling")], role="assistant")], ] # Disable function invocation - chat_client_base.function_invocation_configuration["enabled"] = False + chat_client_base.function_invocation_configuration["enabled"] = False # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] updates = [] - async for update in chat_client_base.get_response( - "hello", options={"tool_choice": "auto", "tools": [ai_func]}, stream=True + async for update in chat_client_base.get_response( # type: ignore[call-overload] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] + "hello", # type: ignore[arg-type] + options={"tool_choice": "auto", "tools": [ai_func]}, + stream=True, # type: ignore[arg-type] ): updates.append(update) @@ -2918,7 +2939,7 @@ def error_func(arg1: str) -> str: raise ValueError("Function error") # Set up multiple function call responses that will all error - chat_client_base.streaming_responses = [ + chat_client_base.streaming_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] [ ChatResponseUpdate( contents=[ @@ -2947,11 +2968,13 @@ def error_func(arg1: str) -> str: ] # Set max_consecutive_errors to 2 - chat_client_base.function_invocation_configuration["max_consecutive_errors_per_request"] = 2 + chat_client_base.function_invocation_configuration["max_consecutive_errors_per_request"] = 2 # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] updates = [] - async for update in chat_client_base.get_response( - "hello", options={"tool_choice": "auto", "tools": [error_func]}, stream=True + async for update in chat_client_base.get_response( # type: ignore[call-overload] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] + "hello", # type: ignore[arg-type] + options={"tool_choice": "auto", "tools": [error_func]}, + stream=True, # type: ignore[arg-type] ): updates.append(update) @@ -2977,7 +3000,7 @@ async def test_streaming_function_invocation_stop_clears_conversation_id(chat_cl def error_func(arg1: str) -> str: raise ValueError("Function error") - chat_client_base.streaming_responses = [ + chat_client_base.streaming_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] [ ChatResponseUpdate( contents=[ @@ -2988,11 +3011,11 @@ def error_func(arg1: str) -> str: ) ] ] - chat_client_base.function_invocation_configuration["max_consecutive_errors_per_request"] = 1 + chat_client_base.function_invocation_configuration["max_consecutive_errors_per_request"] = 1 # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] session_stub = type("SessionStub", (), {"service_session_id": "resp_seed"})() - stream = chat_client_base.get_response( - "hello", + stream = chat_client_base.get_response( # type: ignore[call-overload] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] + "hello", # type: ignore[arg-type] options={"tool_choice": "auto", "tools": [error_func]}, stream=True, client_kwargs={"session": session_stub}, @@ -3019,7 +3042,7 @@ def known_func(arg1: str) -> str: exec_counter += 1 return f"Processed {arg1}" - chat_client_base.streaming_responses = [ + chat_client_base.streaming_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] [ ChatResponseUpdate( contents=[ @@ -3032,11 +3055,13 @@ def known_func(arg1: str) -> str: ] # Set terminate_on_unknown_calls to False (default) - chat_client_base.function_invocation_configuration["terminate_on_unknown_calls"] = False + chat_client_base.function_invocation_configuration["terminate_on_unknown_calls"] = False # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] updates = [] - async for update in chat_client_base.get_response( - "hello", options={"tool_choice": "auto", "tools": [known_func]}, stream=True + async for update in chat_client_base.get_response( # type: ignore[call-overload] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] + "hello", # type: ignore[arg-type] + options={"tool_choice": "auto", "tools": [known_func]}, + stream=True, # type: ignore[arg-type] ): updates.append(update) @@ -3063,7 +3088,7 @@ def known_func(arg1: str) -> str: exec_counter += 1 return f"Processed {arg1}" - chat_client_base.streaming_responses = [ + chat_client_base.streaming_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] [ ChatResponseUpdate( contents=[ @@ -3075,11 +3100,11 @@ def known_func(arg1: str) -> str: ] # Set terminate_on_unknown_calls to True - chat_client_base.function_invocation_configuration["terminate_on_unknown_calls"] = True + chat_client_base.function_invocation_configuration["terminate_on_unknown_calls"] = True # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] # Should raise an exception when encountering an unknown function with pytest.raises(KeyError, match='Error: Requested function "unknown_function" not found'): - async for _ in chat_client_base.get_response( + async for _ in chat_client_base.get_response( # type: ignore[attr-defined] # pyrefly: ignore[not-iterable] # ty: ignore[not-iterable] [Message(role="user", contents=["hello"])], options={"tool_choice": "auto", "tools": [known_func]} ): pass @@ -3096,7 +3121,7 @@ async def test_streaming_function_invocation_config_include_detailed_errors_true def error_func(arg1: str) -> str: raise ValueError("Specific error message that should appear") - chat_client_base.streaming_responses = [ + chat_client_base.streaming_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] [ ChatResponseUpdate( contents=[ @@ -3109,11 +3134,13 @@ def error_func(arg1: str) -> str: ] # Set include_detailed_errors to True - chat_client_base.function_invocation_configuration["include_detailed_errors"] = True + chat_client_base.function_invocation_configuration["include_detailed_errors"] = True # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] updates = [] - async for update in chat_client_base.get_response( - "hello", options={"tool_choice": "auto", "tools": [error_func]}, stream=True + async for update in chat_client_base.get_response( # type: ignore[call-overload] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] + "hello", # type: ignore[arg-type] + options={"tool_choice": "auto", "tools": [error_func]}, + stream=True, # type: ignore[arg-type] ): updates.append(update) @@ -3136,7 +3163,7 @@ async def test_streaming_function_invocation_config_include_detailed_errors_fals def error_func(arg1: str) -> str: raise ValueError("Specific error message that should not appear") - chat_client_base.streaming_responses = [ + chat_client_base.streaming_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] [ ChatResponseUpdate( contents=[ @@ -3149,11 +3176,13 @@ def error_func(arg1: str) -> str: ] # Set include_detailed_errors to False (default) - chat_client_base.function_invocation_configuration["include_detailed_errors"] = False + chat_client_base.function_invocation_configuration["include_detailed_errors"] = False # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] updates = [] - async for update in chat_client_base.get_response( - "hello", options={"tool_choice": "auto", "tools": [error_func]}, stream=True + async for update in chat_client_base.get_response( # type: ignore[call-overload] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] + "hello", # type: ignore[arg-type] + options={"tool_choice": "auto", "tools": [error_func]}, + stream=True, # type: ignore[arg-type] ): updates.append(update) @@ -3174,7 +3203,7 @@ async def test_streaming_argument_validation_error_with_detailed_errors(chat_cli def typed_func(arg1: int) -> str: # Expects int, not str return f"Got {arg1}" - chat_client_base.streaming_responses = [ + chat_client_base.streaming_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] [ ChatResponseUpdate( contents=[ @@ -3187,11 +3216,13 @@ def typed_func(arg1: int) -> str: # Expects int, not str ] # Set include_detailed_errors to True - chat_client_base.function_invocation_configuration["include_detailed_errors"] = True + chat_client_base.function_invocation_configuration["include_detailed_errors"] = True # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] updates = [] - async for update in chat_client_base.get_response( - "hello", options={"tool_choice": "auto", "tools": [typed_func]}, stream=True + async for update in chat_client_base.get_response( # type: ignore[call-overload] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] + "hello", # type: ignore[arg-type] + options={"tool_choice": "auto", "tools": [typed_func]}, + stream=True, # type: ignore[arg-type] ): updates.append(update) @@ -3212,7 +3243,7 @@ async def test_streaming_argument_validation_error_without_detailed_errors(chat_ def typed_func(arg1: int) -> str: # Expects int, not str return f"Got {arg1}" - chat_client_base.streaming_responses = [ + chat_client_base.streaming_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] [ ChatResponseUpdate( contents=[ @@ -3225,11 +3256,13 @@ def typed_func(arg1: int) -> str: # Expects int, not str ] # Set include_detailed_errors to False (default) - chat_client_base.function_invocation_configuration["include_detailed_errors"] = False + chat_client_base.function_invocation_configuration["include_detailed_errors"] = False # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] updates = [] - async for update in chat_client_base.get_response( - "hello", options={"tool_choice": "auto", "tools": [typed_func]}, stream=True + async for update in chat_client_base.get_response( # type: ignore[call-overload] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] + "hello", # type: ignore[arg-type] + options={"tool_choice": "auto", "tools": [typed_func]}, + stream=True, # type: ignore[arg-type] ): updates.append(update) @@ -3246,7 +3279,7 @@ def typed_func(arg1: int) -> str: # Expects int, not str async def test_streaming_multiple_function_calls_parallel_execution(chat_client_base: SupportsChatGetResponse): """Test that multiple function calls are executed in parallel in streaming mode.""" - exec_order = [] + exec_order = [] # type: ignore[var-annotated] @tool(name="func1", approval_mode="never_require") async def func1(arg1: str) -> str: @@ -3262,7 +3295,7 @@ async def func2(arg1: str) -> str: exec_order.append("func2_end") return f"Result2 {arg1}" - chat_client_base.streaming_responses = [ + chat_client_base.streaming_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] [ ChatResponseUpdate( contents=[Content.from_function_call(call_id="1", name="func1", arguments='{"arg1": "value1"}')], @@ -3277,8 +3310,10 @@ async def func2(arg1: str) -> str: ] updates = [] - async for update in chat_client_base.get_response( - "hello", options={"tool_choice": "auto", "tools": [func1, func2]}, stream=True + async for update in chat_client_base.get_response( # type: ignore[call-overload] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] + "hello", # type: ignore[arg-type] + options={"tool_choice": "auto", "tools": [func1, func2]}, + stream=True, # type: ignore[arg-type] ): updates.append(update) @@ -3303,7 +3338,7 @@ def func_with_approval(arg1: str) -> str: exec_counter += 1 return f"Result {arg1}" - chat_client_base.streaming_responses = [ + chat_client_base.streaming_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] [ ChatResponseUpdate( contents=[ @@ -3315,8 +3350,10 @@ def func_with_approval(arg1: str) -> str: ] updates = [] - async for update in chat_client_base.get_response( - "hello", options={"tool_choice": "auto", "tools": [func_with_approval]}, stream=True + async for update in chat_client_base.get_response( # type: ignore[call-overload] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] + "hello", # type: ignore[arg-type] + options={"tool_choice": "auto", "tools": [func_with_approval]}, + stream=True, # type: ignore[arg-type] ): updates.append(update) @@ -3341,7 +3378,7 @@ def sometimes_fails(arg1: str) -> str: raise ValueError("First call fails") return f"Success {arg1}" - chat_client_base.streaming_responses = [ + chat_client_base.streaming_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] [ ChatResponseUpdate( contents=[ @@ -3362,8 +3399,10 @@ def sometimes_fails(arg1: str) -> str: ] updates = [] - async for update in chat_client_base.get_response( - "hello", options={"tool_choice": "auto", "tools": [sometimes_fails]}, stream=True + async for update in chat_client_base.get_response( # type: ignore[call-overload] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] + "hello", # type: ignore[arg-type] + options={"tool_choice": "auto", "tools": [sometimes_fails]}, + stream=True, # type: ignore[arg-type] ): updates.append(update) @@ -3389,7 +3428,7 @@ def sometimes_fails(arg1: str) -> str: class TerminateLoopMiddleware(FunctionMiddleware): """Middleware that raises MiddlewareTermination to exit the function calling loop.""" - async def process(self, context: FunctionInvocationContext, next_handler: Callable[[], Awaitable[None]]) -> None: + async def process(self, context: FunctionInvocationContext, next_handler: Callable[[], Awaitable[None]]) -> None: # pyrefly: ignore[bad-override-param-name] # ty: ignore[invalid-method-override] # Set result to a simple value - the framework will wrap it in FunctionResultContent context.result = "terminated by middleware" raise MiddlewareTermination @@ -3407,7 +3446,7 @@ def ai_func(arg1: str) -> str: # Queue up two responses: function call, then final text # If terminate_loop works, only the first response should be consumed - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -3419,8 +3458,8 @@ def ai_func(arg1: str) -> str: ChatResponse(messages=Message(role="assistant", contents=["done"])), ] - response = await chat_client_base.get_response( - "hello", + response = await chat_client_base.get_response( # type: ignore[call-overload, var-annotated] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] + "hello", # type: ignore[arg-type] options={"tool_choice": "auto", "tools": [ai_func]}, client_kwargs={"middleware": [TerminateLoopMiddleware()]}, ) @@ -3438,13 +3477,13 @@ def ai_func(arg1: str) -> str: assert response.messages[1].contents[0].result == "terminated by middleware" # Verify the second response is still in the queue (wasn't consumed) - assert len(chat_client_base.run_responses) == 1 + assert len(chat_client_base.run_responses) == 1 # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] class SelectiveTerminateMiddleware(FunctionMiddleware): """Only terminates for terminating_function.""" - async def process(self, context: FunctionInvocationContext, next_handler: Callable[[], Awaitable[None]]) -> None: + async def process(self, context: FunctionInvocationContext, next_handler: Callable[[], Awaitable[None]]) -> None: # pyrefly: ignore[bad-override-param-name] # ty: ignore[invalid-method-override] if context.function.name == "terminating_function": # Set result to a simple value - the framework will wrap it in FunctionResultContent context.result = "terminated by middleware" @@ -3470,7 +3509,7 @@ def terminating_func(arg1: str) -> str: return f"Terminating {arg1}" # Queue up two responses: parallel function calls, then final text - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -3485,8 +3524,8 @@ def terminating_func(arg1: str) -> str: ChatResponse(messages=Message(role="assistant", contents=["done"])), ] - response = await chat_client_base.get_response( - "hello", + response = await chat_client_base.get_response( # type: ignore[call-overload, var-annotated] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] + "hello", # type: ignore[arg-type] options={"tool_choice": "auto", "tools": [normal_func, terminating_func]}, client_kwargs={"middleware": [SelectiveTerminateMiddleware()]}, ) @@ -3506,7 +3545,7 @@ def terminating_func(arg1: str) -> str: assert len(response.messages[1].contents) == 2 # Verify the second response is still in the queue (wasn't consumed) - assert len(chat_client_base.run_responses) == 1 + assert len(chat_client_base.run_responses) == 1 # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] async def test_terminate_loop_streaming_single_function_call(chat_client_base: SupportsChatGetResponse): @@ -3520,7 +3559,7 @@ def ai_func(arg1: str) -> str: return f"Processed {arg1}" # Queue up two streaming responses - chat_client_base.streaming_responses = [ + chat_client_base.streaming_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] [ ChatResponseUpdate( contents=[ @@ -3538,8 +3577,8 @@ def ai_func(arg1: str) -> str: ] updates = [] - async for update in chat_client_base.get_response( - "hello", + async for update in chat_client_base.get_response( # type: ignore[call-overload] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] + "hello", # type: ignore[arg-type] options={"tool_choice": "auto", "tools": [ai_func]}, client_kwargs={"middleware": [TerminateLoopMiddleware()]}, stream=True, @@ -3554,7 +3593,7 @@ def ai_func(arg1: str) -> str: assert len(updates) == 2 # Verify the second streaming response is still in the queue (wasn't consumed) - assert len(chat_client_base.streaming_responses) == 1 + assert len(chat_client_base.streaming_responses) == 1 # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] async def test_conversation_id_updated_in_options_between_tool_iterations(): @@ -3595,12 +3634,12 @@ def __init__(self) -> None: self.streaming_responses: list[list[ChatResponseUpdate]] = [] self.call_count: int = 0 - def _inner_get_response( + def _inner_get_response( # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] self, *, - messages: MutableSequence[Message], + messages: MutableSequence[Message], # type: ignore[override] stream: bool, - options: dict[str, Any], + options: dict[str, Any], # type: ignore[override] **kwargs: Any, ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]: # Track what conversation_id was passed @@ -3665,8 +3704,8 @@ def test_func(arg1: str) -> str: ] # Start with initial conversation_id - await client.get_response( - "hello", + await client.get_response( # type: ignore[call-overload] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] + "hello", # type: ignore[arg-type] options={"tool_choice": "auto", "tools": [test_func], "conversation_id": "conv_initial"}, ) @@ -3697,8 +3736,8 @@ def test_func(arg1: str) -> str: ], ] - response_stream = streaming_client.get_response( - "hello", + response_stream = streaming_client.get_response( # type: ignore[call-overload] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] + "hello", # type: ignore[arg-type] stream=True, options={"tool_choice": "auto", "tools": [test_func], "conversation_id": "stream_conv_initial"}, ) @@ -3730,7 +3769,7 @@ async def test_streaming_function_calling_response_includes_reasoning_and_tool_r def search_func(query: str) -> str: return f"Found results for {query}" - chat_client_base.streaming_responses = [ + chat_client_base.streaming_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] [ # First response: reasoning + function_call ChatResponseUpdate( @@ -3764,8 +3803,10 @@ def search_func(query: str) -> str: ], ] - stream = chat_client_base.get_response( - "search for test", options={"tool_choice": "auto", "tools": [search_func]}, stream=True + stream = chat_client_base.get_response( # type: ignore[call-overload] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] + "search for test", # type: ignore[arg-type] + options={"tool_choice": "auto", "tools": [search_func]}, + stream=True, # type: ignore[arg-type] ) updates = [] @@ -3862,7 +3903,7 @@ def delegate_tool(task: str) -> str: ] ) - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -3911,7 +3952,7 @@ def multi_request(task: str) -> str: ] ) - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -3951,7 +3992,7 @@ def empty_request(task: str) -> str: del task raise UserInputRequiredException(contents=[]) - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ChatResponse( messages=Message( role="assistant", @@ -4013,7 +4054,7 @@ def inspect_tools(ctx: FunctionInvocationContext) -> str: seen_names.extend(t.name for t in ctx.tools if isinstance(t, FunctionTool)) return "inspected" - chat_client_base.run_responses = [ + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] _pte_function_call_response("1", "inspect_tools"), _pte_text_response(), ] @@ -4038,8 +4079,8 @@ def load_math(ctx: FunctionInvocationContext) -> str: ctx.add_tools(factorial) return "math tools loaded" - chat_client_base.function_invocation_configuration["max_iterations"] = 3 # type: ignore[attr-defined] - chat_client_base.run_responses = [ + chat_client_base.function_invocation_configuration["max_iterations"] = 3 # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] _pte_function_call_response("1", "load_math"), _pte_function_call_response("2", "factorial", '{"n": 5}'), _pte_text_response(), @@ -4057,10 +4098,10 @@ async def test_add_tools_model_sees_added_tools_in_options(chat_client_base: Sup recorded: list[list[str]] = [] client_cls = type(chat_client_base) - original = client_cls._get_non_streaming_response + original = client_cls._get_non_streaming_response # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] async def recording(self: Any, *, messages: Any, options: dict[str, Any], **kwargs: Any) -> ChatResponse: - tools = options.get("tools") or [] + tools = options.get("tools") or [] # type: ignore[var-annotated] recorded.append([t.name for t in tools if isinstance(t, FunctionTool)]) return await original(self, messages=messages, options=options, **kwargs) @@ -4069,8 +4110,8 @@ def load_math(ctx: FunctionInvocationContext) -> str: ctx.add_tools(_pte_factorial) return "loaded" - chat_client_base.function_invocation_configuration["max_iterations"] = 3 # type: ignore[attr-defined] - chat_client_base.run_responses = [ + chat_client_base.function_invocation_configuration["max_iterations"] = 3 # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] _pte_function_call_response("1", "load_math"), _pte_function_call_response("2", "factorial", '{"n": 5}'), _pte_text_response(), @@ -4094,10 +4135,10 @@ async def test_remove_tools_next_iteration(chat_client_base: SupportsChatGetResp recorded: list[list[str]] = [] client_cls = type(chat_client_base) - original = client_cls._get_non_streaming_response + original = client_cls._get_non_streaming_response # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] async def recording(self: Any, *, messages: Any, options: dict[str, Any], **kwargs: Any) -> ChatResponse: - tools = options.get("tools") or [] + tools = options.get("tools") or [] # type: ignore[var-annotated] recorded.append([t.name for t in tools if isinstance(t, FunctionTool)]) return await original(self, messages=messages, options=options, **kwargs) @@ -4110,8 +4151,8 @@ def drop_weather(ctx: FunctionInvocationContext) -> str: ctx.remove_tools("get_weather") return "removed" - chat_client_base.function_invocation_configuration["max_iterations"] = 3 # type: ignore[attr-defined] - chat_client_base.run_responses = [ + chat_client_base.function_invocation_configuration["max_iterations"] = 3 # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] _pte_function_call_response("1", "drop_weather"), _pte_text_response(), ] @@ -4136,8 +4177,8 @@ def load_math(ctx: FunctionInvocationContext) -> str: return "loaded" original_tools: list[Any] = [load_math] - chat_client_base.function_invocation_configuration["max_iterations"] = 3 # type: ignore[attr-defined] - chat_client_base.run_responses = [ + chat_client_base.function_invocation_configuration["max_iterations"] = 3 # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] _pte_function_call_response("1", "load_math"), _pte_text_response(), ] @@ -4153,10 +4194,10 @@ async def test_add_tools_persists_across_iterations(chat_client_base: SupportsCh recorded: list[list[str]] = [] client_cls = type(chat_client_base) - original = client_cls._get_non_streaming_response + original = client_cls._get_non_streaming_response # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] async def recording(self: Any, *, messages: Any, options: dict[str, Any], **kwargs: Any) -> ChatResponse: - tools = options.get("tools") or [] + tools = options.get("tools") or [] # type: ignore[var-annotated] recorded.append([t.name for t in tools if isinstance(t, FunctionTool)]) return await original(self, messages=messages, options=options, **kwargs) @@ -4165,8 +4206,8 @@ def load_math(ctx: FunctionInvocationContext) -> str: ctx.add_tools(_pte_factorial) return "loaded" - chat_client_base.function_invocation_configuration["max_iterations"] = 4 # type: ignore[attr-defined] - chat_client_base.run_responses = [ + chat_client_base.function_invocation_configuration["max_iterations"] = 4 # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] _pte_function_call_response("1", "load_math"), _pte_function_call_response("2", "factorial", '{"n": 5}'), _pte_function_call_response("3", "factorial", '{"n": 3}'), @@ -4204,13 +4245,13 @@ def load_math(ctx: FunctionInvocationContext) -> str: ctx.add_tools(factorial) return "loaded" - chat_client_base.function_invocation_configuration["max_iterations"] = 3 # type: ignore[attr-defined] - chat_client_base.run_responses = [ + chat_client_base.function_invocation_configuration["max_iterations"] = 3 # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] _pte_function_call_response("1", "load_math"), _pte_function_call_response("2", "factorial", '{"n": 5}'), _pte_text_response(), ] - await chat_client_base.get_response( + await chat_client_base.get_response( # type: ignore[call-overload] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] [Message(role="user", contents=["hi"])], options={"tool_choice": "auto", "tools": [load_math]}, middleware=[PassthroughMiddleware()], @@ -4228,8 +4269,8 @@ def load_secure(ctx: FunctionInvocationContext) -> str: ctx.add_tools(secure_tool) return "loaded" - chat_client_base.function_invocation_configuration["max_iterations"] = 3 # type: ignore[attr-defined] - chat_client_base.run_responses = [ + chat_client_base.function_invocation_configuration["max_iterations"] = 3 # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] _pte_function_call_response("1", "load_secure"), _pte_function_call_response("2", "secure_tool", '{"value": "x"}'), _pte_text_response(), @@ -4255,8 +4296,8 @@ def load_math(ctx: FunctionInvocationContext) -> str: ctx.add_tools(plain_factorial) return "loaded" - chat_client_base.function_invocation_configuration["max_iterations"] = 3 # type: ignore[attr-defined] - chat_client_base.run_responses = [ + chat_client_base.function_invocation_configuration["max_iterations"] = 3 # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + chat_client_base.run_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] _pte_function_call_response("1", "load_math"), _pte_function_call_response("2", "plain_factorial", '{"n": 5}'), _pte_text_response(), @@ -4282,8 +4323,8 @@ def load_math(ctx: FunctionInvocationContext) -> str: ctx.add_tools(factorial) return "loaded" - chat_client_base.function_invocation_configuration["max_iterations"] = 3 # type: ignore[attr-defined] - chat_client_base.streaming_responses = [ + chat_client_base.function_invocation_configuration["max_iterations"] = 3 # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + chat_client_base.streaming_responses = [ # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] [ ChatResponseUpdate( contents=[Content.from_function_call(call_id="1", name="load_math", arguments="{}")], @@ -4367,7 +4408,7 @@ def b(x: int) -> int: ctx = FunctionInvocationContext(function=a, arguments={}, tools=[a, b]) ctx.remove_tools("a") assert ctx.tools is not None - assert [t.name for t in ctx.tools] == ["b"] + assert [t.name for t in ctx.tools] == ["b"] # type: ignore[union-attr] # ty: ignore[unresolved-attribute] ctx.remove_tools(b) assert ctx.tools == [] @@ -4380,7 +4421,7 @@ def a(x: int) -> int: ctx = FunctionInvocationContext(function=a, arguments={}, tools=[a]) ctx.remove_tools("nonexistent") assert ctx.tools is not None - assert [t.name for t in ctx.tools] == ["a"] + assert [t.name for t in ctx.tools] == ["a"] # type: ignore[union-attr] # ty: ignore[unresolved-attribute] def test_progressive_tools_helpers_raise_without_live_tools(): diff --git a/python/packages/core/tests/core/test_harness_agent.py b/python/packages/core/tests/core/test_harness_agent.py index 58ef3f5f2d0..ea8a252c028 100644 --- a/python/packages/core/tests/core/test_harness_agent.py +++ b/python/packages/core/tests/core/test_harness_agent.py @@ -52,7 +52,7 @@ async def get_streaming_response( def test_create_harness_agent_with_defaults() -> None: """create_harness_agent should assemble successfully with default options.""" agent = create_harness_agent( - client=_FakeChatClient(), # type: ignore[arg-type] + client=_FakeChatClient(), # type: ignore[arg-type] # ty: ignore[invalid-argument-type] max_context_window_tokens=128_000, max_output_tokens=16_384, ) @@ -62,7 +62,7 @@ def test_create_harness_agent_with_defaults() -> None: def test_create_harness_agent_includes_all_default_providers() -> None: """Default assembly should include history, compaction, todo, mode (no skills by default).""" agent = create_harness_agent( - client=_FakeChatClient(), # type: ignore[arg-type] + client=_FakeChatClient(), # type: ignore[arg-type] # ty: ignore[invalid-argument-type] max_context_window_tokens=128_000, max_output_tokens=16_384, ) @@ -79,7 +79,7 @@ def test_create_harness_agent_includes_all_default_providers() -> None: def test_create_harness_agent_disable_todo() -> None: """disable_todo=True should exclude TodoProvider.""" agent = create_harness_agent( - client=_FakeChatClient(), # type: ignore[arg-type] + client=_FakeChatClient(), # type: ignore[arg-type] # ty: ignore[invalid-argument-type] max_context_window_tokens=128_000, max_output_tokens=16_384, disable_todo=True, @@ -91,7 +91,7 @@ def test_create_harness_agent_disable_todo() -> None: def test_create_harness_agent_disable_mode() -> None: """disable_mode=True should exclude AgentModeProvider.""" agent = create_harness_agent( - client=_FakeChatClient(), # type: ignore[arg-type] + client=_FakeChatClient(), # type: ignore[arg-type] # ty: ignore[invalid-argument-type] max_context_window_tokens=128_000, max_output_tokens=16_384, disable_mode=True, @@ -118,19 +118,19 @@ def write_topic(self, session, record, *, source_id): def delete_topic(self, session, *, source_id, topic): pass - def get_index_text(self, session, *, source_id): + def get_index_text(self, session, *, source_id): # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] return "" - def get_transcripts_directory(self, session, *, source_id): + def get_transcripts_directory(self, session, *, source_id): # pyrefly: ignore[bad-override] return "" def read_state(self, session, *, source_id): return {} - def rebuild_index(self, session, *, source_id): + def rebuild_index(self, session, *, source_id): # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] pass - def search_transcripts(self, session, *, source_id, query): + def search_transcripts(self, session, *, source_id, query): # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] return [] def write_state(self, session, state, *, source_id): @@ -138,7 +138,7 @@ def write_state(self, session, state, *, source_id): # With memory_store provided and disable_memory=False, MemoryContextProvider should be present. agent_with_memory = create_harness_agent( - client=_FakeChatClient(), # type: ignore[arg-type] + client=_FakeChatClient(), # type: ignore[arg-type] # ty: ignore[invalid-argument-type] max_context_window_tokens=128_000, max_output_tokens=16_384, memory_store=_FakeMemoryStore(), @@ -148,7 +148,7 @@ def write_state(self, session, state, *, source_id): # With memory_store provided and disable_memory=True, MemoryContextProvider should be absent. agent_disabled = create_harness_agent( - client=_FakeChatClient(), # type: ignore[arg-type] + client=_FakeChatClient(), # type: ignore[arg-type] # ty: ignore[invalid-argument-type] max_context_window_tokens=128_000, max_output_tokens=16_384, memory_store=_FakeMemoryStore(), @@ -161,7 +161,7 @@ def write_state(self, session, state, *, source_id): def test_create_harness_agent_skills_paths_adds_provider() -> None: """skills_paths should add a SkillsProvider.""" agent = create_harness_agent( - client=_FakeChatClient(), # type: ignore[arg-type] + client=_FakeChatClient(), # type: ignore[arg-type] # ty: ignore[invalid-argument-type] max_context_window_tokens=128_000, max_output_tokens=16_384, skills_paths=["./test-skills"], @@ -173,7 +173,7 @@ def test_create_harness_agent_skills_paths_adds_provider() -> None: def test_create_harness_agent_disable_compaction() -> None: """disable_compaction=True should exclude CompactionProvider.""" agent = create_harness_agent( - client=_FakeChatClient(), # type: ignore[arg-type] + client=_FakeChatClient(), # type: ignore[arg-type] # ty: ignore[invalid-argument-type] max_context_window_tokens=128_000, max_output_tokens=16_384, disable_compaction=True, @@ -187,7 +187,7 @@ def test_create_harness_agent_returns_full_agent() -> None: from agent_framework._agents import Agent as FullAgent agent = create_harness_agent( - client=_FakeChatClient(), # type: ignore[arg-type] + client=_FakeChatClient(), # type: ignore[arg-type] # ty: ignore[invalid-argument-type] max_context_window_tokens=128_000, max_output_tokens=16_384, ) @@ -201,7 +201,7 @@ def test_create_harness_agent_rejects_invalid_context_tokens() -> None: """max_context_window_tokens must be positive.""" with pytest.raises(ValueError, match="max_context_window_tokens must be positive"): create_harness_agent( - client=_FakeChatClient(), # type: ignore[arg-type] + client=_FakeChatClient(), # type: ignore[arg-type] # ty: ignore[invalid-argument-type] max_context_window_tokens=0, max_output_tokens=100, ) @@ -211,7 +211,7 @@ def test_create_harness_agent_rejects_negative_output_tokens() -> None: """max_output_tokens must be non-negative.""" with pytest.raises(ValueError, match="max_output_tokens must be non-negative"): create_harness_agent( - client=_FakeChatClient(), # type: ignore[arg-type] + client=_FakeChatClient(), # type: ignore[arg-type] # ty: ignore[invalid-argument-type] max_context_window_tokens=1000, max_output_tokens=-1, ) @@ -221,7 +221,7 @@ def test_create_harness_agent_rejects_output_gte_context() -> None: """max_output_tokens must be less than max_context_window_tokens.""" with pytest.raises(ValueError, match="max_output_tokens must be less than"): create_harness_agent( - client=_FakeChatClient(), # type: ignore[arg-type] + client=_FakeChatClient(), # type: ignore[arg-type] # ty: ignore[invalid-argument-type] max_context_window_tokens=1000, max_output_tokens=1000, ) @@ -239,8 +239,8 @@ def test_default_instructions() -> None: def test_custom_agent_instructions_appended() -> None: """Agent instructions should be appended after harness instructions.""" result = _assemble_instructions(None, "Focus on code review.") - assert DEFAULT_HARNESS_INSTRUCTIONS in result # type: ignore[operator] - assert "Focus on code review." in result # type: ignore[operator] + assert DEFAULT_HARNESS_INSTRUCTIONS in result # type: ignore[operator] # ty: ignore[unsupported-operator] + assert "Focus on code review." in result # type: ignore[operator] # ty: ignore[unsupported-operator] def test_empty_harness_instructions_uses_agent_only() -> None: @@ -255,7 +255,7 @@ def test_empty_harness_instructions_uses_agent_only() -> None: def test_create_harness_agent_custom_identity() -> None: """Custom id, name, description should propagate.""" agent = create_harness_agent( - client=_FakeChatClient(), # type: ignore[arg-type] + client=_FakeChatClient(), # type: ignore[arg-type] # ty: ignore[invalid-argument-type] max_context_window_tokens=128_000, max_output_tokens=16_384, id="my-agent-id", @@ -273,7 +273,7 @@ def test_create_harness_agent_custom_identity() -> None: def test_create_harness_agent_create_session() -> None: """create_session should return an AgentSession.""" agent = create_harness_agent( - client=_FakeChatClient(), # type: ignore[arg-type] + client=_FakeChatClient(), # type: ignore[arg-type] # ty: ignore[invalid-argument-type] max_context_window_tokens=128_000, max_output_tokens=16_384, ) @@ -284,7 +284,7 @@ def test_create_harness_agent_create_session() -> None: def test_create_harness_agent_create_session_with_id() -> None: """create_session should accept a custom session_id.""" agent = create_harness_agent( - client=_FakeChatClient(), # type: ignore[arg-type] + client=_FakeChatClient(), # type: ignore[arg-type] # ty: ignore[invalid-argument-type] max_context_window_tokens=128_000, max_output_tokens=16_384, ) @@ -295,7 +295,7 @@ def test_create_harness_agent_create_session_with_id() -> None: async def test_create_harness_agent_run_returns_response() -> None: """agent.run() should return a response.""" agent = create_harness_agent( - client=_FakeChatClient(), # type: ignore[arg-type] + client=_FakeChatClient(), # type: ignore[arg-type] # ty: ignore[invalid-argument-type] max_context_window_tokens=128_000, max_output_tokens=16_384, ) @@ -313,7 +313,7 @@ def test_create_harness_agent_satisfies_protocol() -> None: from agent_framework import SupportsAgentRun agent = create_harness_agent( - client=_FakeChatClient(), # type: ignore[arg-type] + client=_FakeChatClient(), # type: ignore[arg-type] # ty: ignore[invalid-argument-type] max_context_window_tokens=128_000, max_output_tokens=16_384, ) @@ -331,7 +331,7 @@ class _CustomProvider(ContextProvider): custom = _CustomProvider("custom") agent = create_harness_agent( - client=_FakeChatClient(), # type: ignore[arg-type] + client=_FakeChatClient(), # type: ignore[arg-type] # ty: ignore[invalid-argument-type] max_context_window_tokens=128_000, max_output_tokens=16_384, context_providers=[custom], @@ -352,7 +352,7 @@ def get_web_search_tool(self, **kwargs: Any) -> str: def test_create_harness_agent_auto_adds_web_search_tool() -> None: """Web search tool should be auto-added when client supports it.""" agent = create_harness_agent( - client=_FakeWebSearchClient(), # type: ignore[arg-type] + client=_FakeWebSearchClient(), # type: ignore[arg-type] # ty: ignore[invalid-argument-type] max_context_window_tokens=128_000, max_output_tokens=16_384, ) @@ -363,7 +363,7 @@ def test_create_harness_agent_auto_adds_web_search_tool() -> None: def test_create_harness_agent_disable_web_search() -> None: """disable_web_search=True should skip auto-adding the web search tool.""" agent = create_harness_agent( - client=_FakeWebSearchClient(), # type: ignore[arg-type] + client=_FakeWebSearchClient(), # type: ignore[arg-type] # ty: ignore[invalid-argument-type] max_context_window_tokens=128_000, max_output_tokens=16_384, disable_web_search=True, @@ -375,7 +375,7 @@ def test_create_harness_agent_disable_web_search() -> None: def test_create_harness_agent_no_web_search_when_unsupported() -> None: """Web search tool should NOT be added when client does not support it.""" agent = create_harness_agent( - client=_FakeChatClient(), # type: ignore[arg-type] + client=_FakeChatClient(), # type: ignore[arg-type] # ty: ignore[invalid-argument-type] max_context_window_tokens=128_000, max_output_tokens=16_384, ) @@ -389,7 +389,7 @@ def test_create_harness_agent_logs_warning_when_no_web_search(caplog: pytest.Log with caplog.at_level(logging.WARNING, logger="agent_framework._harness._agent"): create_harness_agent( - client=_FakeChatClient(), # type: ignore[arg-type] + client=_FakeChatClient(), # type: ignore[arg-type] # ty: ignore[invalid-argument-type] max_context_window_tokens=128_000, max_output_tokens=16_384, ) @@ -424,7 +424,7 @@ def test_create_harness_agent_no_background_agents_by_default() -> None: from agent_framework._harness._background_agents import BackgroundAgentsProvider agent = create_harness_agent( - client=_FakeChatClient(), # type: ignore[arg-type] + client=_FakeChatClient(), # type: ignore[arg-type] # ty: ignore[invalid-argument-type] max_context_window_tokens=128_000, max_output_tokens=16_384, disable_web_search=True, @@ -439,11 +439,11 @@ def test_create_harness_agent_adds_background_agents_provider() -> None: bg_agent = _FakeBackgroundAgent("WebSearcher", "Searches the web") agent = create_harness_agent( - client=_FakeChatClient(), # type: ignore[arg-type] + client=_FakeChatClient(), # type: ignore[arg-type] # ty: ignore[invalid-argument-type] max_context_window_tokens=128_000, max_output_tokens=16_384, disable_web_search=True, - background_agents=[bg_agent], + background_agents=[bg_agent], # type: ignore[list-item] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] ) providers = agent.context_providers or [] bg_providers = [p for p in providers if isinstance(p, BackgroundAgentsProvider)] @@ -457,11 +457,11 @@ def test_create_harness_agent_background_agents_custom_instructions() -> None: custom_instructions = "## Custom\n\nUse agents wisely.\n\n{background_agents}" bg_agent = _FakeBackgroundAgent("Helper", "A helper agent") agent = create_harness_agent( - client=_FakeChatClient(), # type: ignore[arg-type] + client=_FakeChatClient(), # type: ignore[arg-type] # ty: ignore[invalid-argument-type] max_context_window_tokens=128_000, max_output_tokens=16_384, disable_web_search=True, - background_agents=[bg_agent], + background_agents=[bg_agent], # type: ignore[list-item] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] background_agents_instructions=custom_instructions, ) providers = agent.context_providers or [] @@ -477,7 +477,7 @@ def test_create_harness_agent_empty_background_agents_list() -> None: from agent_framework._harness._background_agents import BackgroundAgentsProvider agent = create_harness_agent( - client=_FakeChatClient(), # type: ignore[arg-type] + client=_FakeChatClient(), # type: ignore[arg-type] # ty: ignore[invalid-argument-type] max_context_window_tokens=128_000, max_output_tokens=16_384, disable_web_search=True, diff --git a/python/packages/core/tests/core/test_harness_background_agents.py b/python/packages/core/tests/core/test_harness_background_agents.py index 34f0893df92..0c2b97f3f67 100644 --- a/python/packages/core/tests/core/test_harness_background_agents.py +++ b/python/packages/core/tests/core/test_harness_background_agents.py @@ -61,7 +61,7 @@ async def run( def _make_provider(*agents: _FakeAgent) -> BackgroundAgentsProvider: """Create a provider with given agents.""" - return BackgroundAgentsProvider(agents) + return BackgroundAgentsProvider(agents) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] def _make_session() -> AgentSession: @@ -97,7 +97,7 @@ def test_constructor_requires_agent_names() -> None: """Should reject agents with no name.""" agent = _FakeAgent("") with pytest.raises(ValueError, match="non-empty name"): - BackgroundAgentsProvider([agent]) + BackgroundAgentsProvider([agent]) # type: ignore[list-item] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] def test_constructor_rejects_duplicate_names() -> None: @@ -105,18 +105,18 @@ def test_constructor_rejects_duplicate_names() -> None: agent1 = _FakeAgent("Research") agent2 = _FakeAgent("research") with pytest.raises(ValueError, match="Duplicate background agent name"): - BackgroundAgentsProvider([agent1, agent2]) + BackgroundAgentsProvider([agent1, agent2]) # type: ignore[list-item] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] def test_constructor_valid_agents() -> None: """Should succeed with valid unique agents.""" - provider = BackgroundAgentsProvider([_FakeAgent("Alpha"), _FakeAgent("Beta")]) + provider = BackgroundAgentsProvider([_FakeAgent("Alpha"), _FakeAgent("Beta")]) # type: ignore[list-item] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] assert provider.source_id == "background_agents" def test_constructor_custom_source_id() -> None: """Should accept custom source_id.""" - provider = BackgroundAgentsProvider([_FakeAgent("Agent1")], source_id="custom_bg") + provider = BackgroundAgentsProvider([_FakeAgent("Agent1")], source_id="custom_bg") # type: ignore[list-item] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] assert provider.source_id == "custom_bg" diff --git a/python/packages/core/tests/core/test_harness_file_access.py b/python/packages/core/tests/core/test_harness_file_access.py index c9599ccbf3c..3aa6abf577f 100644 --- a/python/packages/core/tests/core/test_harness_file_access.py +++ b/python/packages/core/tests/core/test_harness_file_access.py @@ -309,7 +309,7 @@ async def test_file_access_provider_registers_tools_and_instructions( provider = FileAccessProvider(store=store) agent = Agent(client=chat_client_base, context_providers=[provider]) - _, options = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + _, options = await agent._prepare_session_and_messages( # pyright: ignore[reportPrivateUsage] session=session, input_messages=[Message(role="user", contents=["work with files"])], ) @@ -340,7 +340,7 @@ async def test_file_access_provider_delete_approval_defaults_to_always_require( provider = FileAccessProvider(store=InMemoryAgentFileStore()) agent = Agent(client=chat_client_base, context_providers=[provider]) - _, options = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + _, options = await agent._prepare_session_and_messages( # pyright: ignore[reportPrivateUsage] session=session, input_messages=[Message(role="user", contents=["work with files"])], ) @@ -348,7 +348,7 @@ async def test_file_access_provider_delete_approval_defaults_to_always_require( tools = options["tools"] assert isinstance(tools, list) delete_file = _tool_by_name(tools, "file_access_delete_file") - assert delete_file.approval_mode == "always_require" + assert delete_file.approval_mode == "always_require" # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] # The non-destructive tools should remain autonomous. for name in ( "file_access_save_file", @@ -356,7 +356,7 @@ async def test_file_access_provider_delete_approval_defaults_to_always_require( "file_access_list_files", "file_access_search_files", ): - assert _tool_by_name(tools, name).approval_mode == "never_require" + assert _tool_by_name(tools, name).approval_mode == "never_require" # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] async def test_file_access_provider_delete_approval_opt_out( @@ -367,13 +367,13 @@ async def test_file_access_provider_delete_approval_opt_out( provider = FileAccessProvider(store=InMemoryAgentFileStore(), require_delete_approval=False) agent = Agent(client=chat_client_base, context_providers=[provider]) - _, options = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + _, options = await agent._prepare_session_and_messages( # pyright: ignore[reportPrivateUsage] session=session, input_messages=[Message(role="user", contents=["work with files"])], ) delete_file = _tool_by_name(options["tools"], "file_access_delete_file") # type: ignore[arg-type] - assert delete_file.approval_mode == "never_require" + assert delete_file.approval_mode == "never_require" # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] async def test_file_access_provider_tools_round_trip_files( @@ -385,7 +385,7 @@ async def test_file_access_provider_tools_round_trip_files( provider = FileAccessProvider(store=store) agent = Agent(client=chat_client_base, context_providers=[provider]) - _, options = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + _, options = await agent._prepare_session_and_messages( # pyright: ignore[reportPrivateUsage] session=session, input_messages=[Message(role="user", contents=["work with files"])], ) @@ -398,55 +398,55 @@ async def test_file_access_provider_tools_round_trip_files( list_files = _tool_by_name(tools, "file_access_list_files") search_files = _tool_by_name(tools, "file_access_search_files") - saved = await save_file.invoke(arguments={"file_name": "plan.md", "content": "step 1\nERROR step 2"}) + saved = await save_file.invoke(arguments={"file_name": "plan.md", "content": "step 1\nERROR step 2"}) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert "plan.md" in saved[0].text and "saved" in saved[0].text # Default overwrite=False should refuse the second save. - refused = await save_file.invoke(arguments={"file_name": "plan.md", "content": "stomp"}) + refused = await save_file.invoke(arguments={"file_name": "plan.md", "content": "stomp"}) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert "already exists" in refused[0].text # overwrite=True should succeed. - overwritten = await save_file.invoke( + overwritten = await save_file.invoke( # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] arguments={"file_name": "plan.md", "content": "stomp\nERROR replaced", "overwrite": True} ) assert "saved" in overwritten[0].text - read_back = await read_file.invoke(arguments={"file_name": "plan.md"}) + read_back = await read_file.invoke(arguments={"file_name": "plan.md"}) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert read_back[0].text == "stomp\nERROR replaced" - listed = await list_files.invoke() + listed = await list_files.invoke() # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert json.loads(listed[0].text) == ["plan.md"] # The list tool should accept an optional directory argument so agents can # enumerate nested folders (not only the root). - await save_file.invoke(arguments={"file_name": "reports/2024.md", "content": "annual"}) - listed_nested = await list_files.invoke(arguments={"directory": "reports"}) + await save_file.invoke(arguments={"file_name": "reports/2024.md", "content": "annual"}) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + listed_nested = await list_files.invoke(arguments={"directory": "reports"}) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert json.loads(listed_nested[0].text) == ["2024.md"] # Blank / whitespace directory should fall back to the root listing. - listed_blank = await list_files.invoke(arguments={"directory": " "}) + listed_blank = await list_files.invoke(arguments={"directory": " "}) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert sorted(json.loads(listed_blank[0].text)) == ["plan.md"] - missing = await read_file.invoke(arguments={"file_name": "missing.md"}) + missing = await read_file.invoke(arguments={"file_name": "missing.md"}) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert "not found" in missing[0].text - search_payload = await search_files.invoke(arguments={"regex_pattern": "error", "file_pattern": "*.md"}) + search_payload = await search_files.invoke(arguments={"regex_pattern": "error", "file_pattern": "*.md"}) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] parsed = json.loads(search_payload[0].text) assert parsed[0]["file_name"] == "plan.md" assert parsed[0]["matching_lines"][0]["line"] == "ERROR replaced" # The search tool should likewise accept an optional directory argument so # agents can scope a search to a subfolder. - await save_file.invoke(arguments={"file_name": "reports/issues.md", "content": "ERROR nested"}) - scoped = await search_files.invoke( + await save_file.invoke(arguments={"file_name": "reports/issues.md", "content": "ERROR nested"}) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + scoped = await search_files.invoke( # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] arguments={"regex_pattern": "error", "file_pattern": "*.md", "directory": "reports"} ) scoped_parsed = json.loads(scoped[0].text) assert [entry["file_name"] for entry in scoped_parsed] == ["issues.md"] - deleted = await delete_file.invoke(arguments={"file_name": "plan.md"}) + deleted = await delete_file.invoke(arguments={"file_name": "plan.md"}) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert "deleted" in deleted[0].text - missing_delete = await delete_file.invoke(arguments={"file_name": "plan.md"}) + missing_delete = await delete_file.invoke(arguments={"file_name": "plan.md"}) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert "not found" in missing_delete[0].text @@ -518,12 +518,12 @@ def boom(self: Path) -> bool: def test_file_access_harness_classes_are_marked_experimental() -> None: """File-access harness public classes should expose HARNESS experimental metadata.""" - assert AgentFileStore.__feature_id__ == ExperimentalFeature.HARNESS.value - assert InMemoryAgentFileStore.__feature_id__ == ExperimentalFeature.HARNESS.value - assert FileSystemAgentFileStore.__feature_id__ == ExperimentalFeature.HARNESS.value - assert FileSearchMatch.__feature_id__ == ExperimentalFeature.HARNESS.value - assert FileSearchResult.__feature_id__ == ExperimentalFeature.HARNESS.value - assert FileAccessProvider.__feature_id__ == ExperimentalFeature.HARNESS.value + assert AgentFileStore.__feature_id__ == ExperimentalFeature.HARNESS.value # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert InMemoryAgentFileStore.__feature_id__ == ExperimentalFeature.HARNESS.value # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert FileSystemAgentFileStore.__feature_id__ == ExperimentalFeature.HARNESS.value # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert FileSearchMatch.__feature_id__ == ExperimentalFeature.HARNESS.value # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert FileSearchResult.__feature_id__ == ExperimentalFeature.HARNESS.value # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert FileAccessProvider.__feature_id__ == ExperimentalFeature.HARNESS.value # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert ".. warning:: Experimental" in (FileAccessProvider.__doc__ or "") @@ -596,7 +596,7 @@ async def test_file_access_tool_wrappers_surface_value_error_as_message( provider = FileAccessProvider(store=store) agent = Agent(client=chat_client_base, context_providers=[provider]) - _, options = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + _, options = await agent._prepare_session_and_messages( # pyright: ignore[reportPrivateUsage] session=session, input_messages=[Message(role="user", contents=["work with files"])], ) @@ -610,18 +610,18 @@ async def test_file_access_tool_wrappers_surface_value_error_as_message( search_files = _tool_by_name(tools, "file_access_search_files") # Path-traversal attempts on each tool should return a clean string, not raise. - saved = await save_file.invoke(arguments={"file_name": "../escape.txt", "content": "x"}) + saved = await save_file.invoke(arguments={"file_name": "../escape.txt", "content": "x"}) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert "Could not save" in saved[0].text and "escape" in saved[0].text.lower() - read = await read_file.invoke(arguments={"file_name": "../escape.txt"}) + read = await read_file.invoke(arguments={"file_name": "../escape.txt"}) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert "Could not read" in read[0].text - deleted = await delete_file.invoke(arguments={"file_name": "../escape.txt"}) + deleted = await delete_file.invoke(arguments={"file_name": "../escape.txt"}) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert "Could not delete" in deleted[0].text - listed = await list_files.invoke(arguments={"directory": "../escape"}) + listed = await list_files.invoke(arguments={"directory": "../escape"}) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert "Could not list" in listed[0].text # Regex length cap should also be returned to the model as text. too_long = "a" * 1024 - searched = await search_files.invoke(arguments={"regex_pattern": too_long}) + searched = await search_files.invoke(arguments={"regex_pattern": too_long}) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert "Could not search files" in searched[0].text @@ -636,12 +636,12 @@ async def test_file_access_tool_read_file_wrapper_surfaces_non_utf8( provider = FileAccessProvider(store=store) agent = Agent(client=chat_client_base, context_providers=[provider]) - _, options = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + _, options = await agent._prepare_session_and_messages( # pyright: ignore[reportPrivateUsage] session=session, input_messages=[Message(role="user", contents=["read it"])], ) read_file = _tool_by_name(options["tools"], "file_access_read_file") - response = await read_file.invoke(arguments={"file_name": "blob.bin"}) + response = await read_file.invoke(arguments={"file_name": "blob.bin"}) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert "Could not read" in response[0].text and "UTF-8" in response[0].text diff --git a/python/packages/core/tests/core/test_harness_memory.py b/python/packages/core/tests/core/test_harness_memory.py index 9d4c7c71d15..00a74856cea 100644 --- a/python/packages/core/tests/core/test_harness_memory.py +++ b/python/packages/core/tests/core/test_harness_memory.py @@ -14,6 +14,7 @@ DEFAULT_MEMORY_SOURCE_ID, Agent, AgentSession, + ChatOptions, ChatResponse, Content, ExperimentalFeature, @@ -27,6 +28,10 @@ ) +def _no_store_options() -> ChatOptions: + return {"store": False} + + def _tool_by_name(tools: list[object], name: str) -> object: """Return the tool with the requested name from a prepared tool list.""" for tool in tools: @@ -120,7 +125,7 @@ def test_memory_topic_record_round_trips_through_dict_and_markdown() -> None: record = MemoryTopicRecord.from_dict(raw_record) reparsed_record = MemoryTopicRecord.from_markdown(record.to_markdown()) - assert record == MemoryTopicRecord(**raw_record) + assert record == MemoryTopicRecord(**raw_record) # type: ignore[arg-type] assert record.to_dict() == raw_record assert reparsed_record == record assert "MemoryTopicRecord(" in repr(record) @@ -294,19 +299,19 @@ async def test_memory_context_provider_does_not_rewrite_unchanged_index(tmp_path session.state["owner_id"] = "alice" store = MemoryFileStore(tmp_path, owner_state_key="owner_id") agent = Agent( - client=_MemoryHarnessClient(), + client=_MemoryHarnessClient(), # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] context_providers=[MemoryContextProvider(store=store)], - default_options={"store": False}, + default_options=_no_store_options(), ) - await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + await agent._prepare_session_and_messages( # pyright: ignore[reportPrivateUsage] session=session, input_messages=[Message(role="user", contents=["Current question"])], ) index_path = next(tmp_path.rglob("MEMORY.md")) first_mtime_ns = index_path.stat().st_mtime_ns - await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + await agent._prepare_session_and_messages( # pyright: ignore[reportPrivateUsage] session=session, input_messages=[Message(role="user", contents=["Current question"])], ) @@ -332,12 +337,12 @@ async def test_memory_context_provider_tools_and_automation(tmp_path) -> None: consolidation_interval=timedelta(0), ) agent = Agent( - client=_MemoryHarnessClient(), + client=_MemoryHarnessClient(), # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] context_providers=[provider], - default_options={"store": False}, + default_options=_no_store_options(), ) - _, options = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + _, options = await agent._prepare_session_and_messages( # pyright: ignore[reportPrivateUsage] session=session, input_messages=[Message(role="user", contents=["Remember this."])], ) @@ -349,11 +354,11 @@ async def test_memory_context_provider_tools_and_automation(tmp_path) -> None: search_memory_transcripts = _tool_by_name(tools, "search_memory_transcripts") consolidate_memories = _tool_by_name(tools, "consolidate_memories") - write_result = await write_memory.invoke(arguments={"topic": "travel", "memory": "Visit Oslo in June."}) + write_result = await write_memory.invoke(arguments={"topic": "travel", "memory": "Visit Oslo in June."}) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] created_topic = json.loads(write_result[0].text) assert created_topic["topic"] == "travel" - list_result = await list_memory_topics.invoke() + list_result = await list_memory_topics.invoke() # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert [entry["topic"] for entry in json.loads(list_result[0].text)] == ["travel"] await agent.run("Please remember that I prefer concise answers.", session=session) @@ -365,12 +370,12 @@ async def test_memory_context_provider_tools_and_automation(tmp_path) -> None: assert preferences_topic.summary == "Prefers concise answers." assert preferences_topic.memories == ["Prefers concise answers."] - transcript_search_result = await search_memory_transcripts.invoke(arguments={"query": "concise", "limit": 5}) + transcript_search_result = await search_memory_transcripts.invoke(arguments={"query": "concise", "limit": 5}) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] search_payload = json.loads(transcript_search_result[0].text) assert search_payload[0]["role"] == "user" assert "concise answers" in search_payload[0]["text"] - consolidate_result = await consolidate_memories.invoke() + consolidate_result = await consolidate_memories.invoke() # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert json.loads(consolidate_result[0].text)["consolidated_topics"] >= 1 @@ -401,12 +406,12 @@ async def test_memory_context_provider_injects_recent_turns(tmp_path) -> None: state=provider_state, ) agent = Agent( - client=_MemoryHarnessClient(), + client=_MemoryHarnessClient(), # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] context_providers=[provider], - default_options={"store": False}, + default_options=_no_store_options(), ) - session_context, _ = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + session_context, _ = await agent._prepare_session_and_messages( # pyright: ignore[reportPrivateUsage] session=session, input_messages=[Message(role="user", contents=["Current question"])], ) @@ -457,21 +462,21 @@ async def test_memory_context_provider_recent_turns_can_skip_tool_call_groups(tm state=provider_state, ) with_tools_agent = Agent( - client=_MemoryHarnessClient(), + client=_MemoryHarnessClient(), # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] context_providers=[MemoryContextProvider(store=store, recent_turns=2, load_tool_turns=True)], - default_options={"store": False}, + default_options=_no_store_options(), ) without_tools_agent = Agent( - client=_MemoryHarnessClient(), + client=_MemoryHarnessClient(), # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] context_providers=[MemoryContextProvider(store=store, recent_turns=2, load_tool_turns=False)], - default_options={"store": False}, + default_options=_no_store_options(), ) - with_tools_context, _ = await with_tools_agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + with_tools_context, _ = await with_tools_agent._prepare_session_and_messages( # pyright: ignore[reportPrivateUsage] session=session, input_messages=[Message(role="user", contents=["Current question"])], ) - without_tools_context, _ = await without_tools_agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + without_tools_context, _ = await without_tools_agent._prepare_session_and_messages( # pyright: ignore[reportPrivateUsage] session=session, input_messages=[Message(role="user", contents=["Current question"])], ) @@ -521,15 +526,15 @@ async def test_memory_context_provider_uses_explicit_consolidation_client(tmp_pa ) provider = MemoryContextProvider( store=store, - consolidation_client=consolidation_client, + consolidation_client=consolidation_client, # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] ) agent = Agent( - client=main_client, + client=main_client, # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] context_providers=[provider], - default_options={"store": False}, + default_options=_no_store_options(), ) - _, options = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + _, options = await agent._prepare_session_and_messages( # pyright: ignore[reportPrivateUsage] session=session, input_messages=[Message(role="user", contents=["Remember this."])], ) @@ -539,8 +544,8 @@ async def test_memory_context_provider_uses_explicit_consolidation_client(tmp_pa write_memory = _tool_by_name(tools, "write_memory") consolidate_memories = _tool_by_name(tools, "consolidate_memories") - await write_memory.invoke(arguments={"topic": "travel", "memory": "Visit Oslo in June."}) - await consolidate_memories.invoke() + await write_memory.invoke(arguments={"topic": "travel", "memory": "Visit Oslo in June."}) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + await consolidate_memories.invoke() # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] travel_topic = store.get_topic(session, source_id=DEFAULT_MEMORY_SOURCE_ID, topic="travel") assert travel_topic.summary == "Consolidated by the cheaper client." @@ -554,9 +559,9 @@ async def test_memory_context_provider_preserves_concurrent_writes_to_same_topic session.state["owner_id"] = "alice" store = MemoryFileStore(tmp_path, owner_state_key="owner_id") provider = MemoryContextProvider(store=store) - agent = Agent(client=_MemoryHarnessClient(), context_providers=[provider], default_options={"store": False}) + agent = Agent(client=_MemoryHarnessClient(), context_providers=[provider], default_options=_no_store_options()) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] - _, options = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + _, options = await agent._prepare_session_and_messages( # pyright: ignore[reportPrivateUsage] session=session, input_messages=[Message(role="user", contents=["Remember these."])], ) @@ -566,7 +571,7 @@ async def test_memory_context_provider_preserves_concurrent_writes_to_same_topic memories = [f"Concurrent memory {index}." for index in range(20)] await asyncio.gather( - *(write_memory.invoke(arguments={"topic": "preferences", "memory": memory}) for memory in memories) + *(write_memory.invoke(arguments={"topic": "preferences", "memory": memory}) for memory in memories) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ) topic = store.get_topic(session, source_id=DEFAULT_MEMORY_SOURCE_ID, topic="preferences") @@ -575,12 +580,12 @@ async def test_memory_context_provider_preserves_concurrent_writes_to_same_topic def test_memory_harness_classes_are_marked_experimental() -> None: """Memory harness public classes should expose HARNESS experimental metadata.""" - assert MemoryIndexEntry.__feature_id__ == ExperimentalFeature.HARNESS.value - assert MemoryTopicRecord.__feature_id__ == ExperimentalFeature.HARNESS.value - assert MemoryStore.__feature_id__ == ExperimentalFeature.HARNESS.value - assert MemoryFileStore.__feature_id__ == ExperimentalFeature.HARNESS.value - assert MemoryContextProvider.__feature_id__ == ExperimentalFeature.HARNESS.value - assert ".. warning:: Experimental" in MemoryContextProvider.__doc__ + assert MemoryIndexEntry.__feature_id__ == ExperimentalFeature.HARNESS.value # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert MemoryTopicRecord.__feature_id__ == ExperimentalFeature.HARNESS.value # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert MemoryStore.__feature_id__ == ExperimentalFeature.HARNESS.value # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert MemoryFileStore.__feature_id__ == ExperimentalFeature.HARNESS.value # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert MemoryContextProvider.__feature_id__ == ExperimentalFeature.HARNESS.value # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert ".. warning:: Experimental" in MemoryContextProvider.__doc__ # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] def test_memory_topic_record_round_trips_when_text_contains_section_markers() -> None: @@ -713,7 +718,7 @@ async def test_memory_consolidation_transient_failure_preserves_state(tmp_path) session.state["owner_id"] = "alice" store = MemoryFileStore(tmp_path, owner_state_key="owner_id") raising_client = _RaisingMemoryClient() - provider = MemoryContextProvider(store=store, consolidation_client=raising_client) + provider = MemoryContextProvider(store=store, consolidation_client=raising_client) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] pre_state = { "last_consolidated_at": "2026-04-20T09:00:00+00:00", "sessions_since_consolidation": ["queued-session"], @@ -731,8 +736,8 @@ async def test_memory_consolidation_transient_failure_preserves_state(tmp_path) source_id=DEFAULT_MEMORY_SOURCE_ID, ) - consolidated_count = await provider._run_consolidation( # type: ignore[reportPrivateUsage] - client=raising_client, + consolidated_count = await provider._run_consolidation( # pyright: ignore[reportPrivateUsage] + client=raising_client, # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] session=session, force=True, now=datetime(2026, 4, 22, tzinfo=timezone.utc), @@ -759,11 +764,11 @@ async def test_memory_extraction_propagates_programmer_errors(tmp_path) -> None: context = SessionContext( input_messages=[Message(role="user", contents=["q"])], ) - context._response = AgentResponse(messages=[Message(role="assistant", contents=["a"])]) # type: ignore[reportPrivateUsage] + context._response = AgentResponse(messages=[Message(role="assistant", contents=["a"])]) # pyright: ignore[reportPrivateUsage] with pytest.raises(AttributeError, match="misconfigured client"): - await provider._extract_memories( # type: ignore[reportPrivateUsage] - client=bad_client, + await provider._extract_memories( # pyright: ignore[reportPrivateUsage] + client=bad_client, # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] session=session, context=context, now=datetime(2026, 4, 22, tzinfo=timezone.utc), diff --git a/python/packages/core/tests/core/test_harness_mode.py b/python/packages/core/tests/core/test_harness_mode.py index 915379fbd85..12d31b1de26 100644 --- a/python/packages/core/tests/core/test_harness_mode.py +++ b/python/packages/core/tests/core/test_harness_mode.py @@ -69,10 +69,10 @@ def test_agent_mode_context_provider_validates_configuration_and_is_experimental with pytest.raises(ValueError, match="Invalid mode"): AgentModeProvider(default_mode="ship") - assert AgentModeProvider.__feature_id__ == ExperimentalFeature.HARNESS.value - assert get_agent_mode.__feature_id__ == ExperimentalFeature.HARNESS.value - assert set_agent_mode.__feature_id__ == ExperimentalFeature.HARNESS.value - assert ".. warning:: Experimental" in AgentModeProvider.__doc__ + assert AgentModeProvider.__feature_id__ == ExperimentalFeature.HARNESS.value # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert get_agent_mode.__feature_id__ == ExperimentalFeature.HARNESS.value # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert set_agent_mode.__feature_id__ == ExperimentalFeature.HARNESS.value # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert ".. warning:: Experimental" in AgentModeProvider.__doc__ # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] assert get_agent_mode.__doc__ is not None assert ".. warning:: Experimental" in get_agent_mode.__doc__ assert set_agent_mode.__doc__ is not None @@ -89,7 +89,7 @@ async def test_agent_mode_context_provider_normalizes_custom_modes( ) agent = Agent(client=chat_client_base, context_providers=[provider]) - _, options = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + _, options = await agent._prepare_session_and_messages( # pyright: ignore[reportPrivateUsage] session=session, input_messages=[Message(role="user", contents=["Start drafting"])], ) @@ -121,7 +121,7 @@ async def test_agent_mode_context_provider_serializes_tool_outputs_as_json( provider = AgentModeProvider(default_mode=mode_name, mode_descriptions={mode_name: "Preview edits."}) agent = Agent(client=chat_client_base, context_providers=[provider]) - _, options = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + _, options = await agent._prepare_session_and_messages( # pyright: ignore[reportPrivateUsage] session=session, input_messages=[Message(role="user", contents=["Preview edits"])], ) @@ -130,10 +130,10 @@ async def test_agent_mode_context_provider_serializes_tool_outputs_as_json( get_mode_tool = _tool_by_name(tools, "mode_get") set_mode_tool = _tool_by_name(tools, "mode_set") - initial_mode = await get_mode_tool.invoke() + initial_mode = await get_mode_tool.invoke() # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert json.loads(initial_mode[0].text) == {"mode": mode_name} - set_result = await set_mode_tool.invoke(arguments={"mode": mode_name}) + set_result = await set_mode_tool.invoke(arguments={"mode": mode_name}) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert json.loads(set_result[0].text) == {"mode": mode_name, "message": f"Mode changed to '{mode_name}'."} @@ -145,7 +145,7 @@ async def test_agent_mode_context_provider_updates_agent_mode( provider = AgentModeProvider() agent = Agent(client=chat_client_base, context_providers=[provider]) - _, options = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + _, options = await agent._prepare_session_and_messages( # pyright: ignore[reportPrivateUsage] session=session, input_messages=[Message(role="user", contents=["Start planning"])], ) @@ -162,10 +162,10 @@ async def test_agent_mode_context_provider_updates_agent_mode( get_mode_tool = _tool_by_name(tools, "mode_get") set_mode_tool = _tool_by_name(tools, "mode_set") - initial_mode = await get_mode_tool.invoke() + initial_mode = await get_mode_tool.invoke() # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert json.loads(initial_mode[0].text) == {"mode": "plan"} - set_result = await set_mode_tool.invoke(arguments={"mode": "execute"}) + set_result = await set_mode_tool.invoke(arguments={"mode": "execute"}) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert json.loads(set_result[0].text) == {"mode": "execute", "message": "Mode changed to 'execute'."} assert get_agent_mode(session, source_id=provider.source_id) == "execute" assert set_agent_mode(session, "plan", source_id=provider.source_id) == "plan" @@ -222,12 +222,12 @@ async def test_agent_mode_provider_injects_user_message_after_external_change( # First run: agent uses mode_set tool to switch to execute. The tool path must NOT queue a # notification because the agent already saw its own tool call in the chat history. - _, first_options = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + _, first_options = await agent._prepare_session_and_messages( # pyright: ignore[reportPrivateUsage] session=session, input_messages=[Message(role="user", contents=["Plan first."])], ) set_mode_tool = _tool_by_name(first_options["tools"], "mode_set") - await set_mode_tool.invoke(arguments={"mode": "execute"}) + await set_mode_tool.invoke(arguments={"mode": "execute"}) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert "previous_mode_for_notification" not in session.state[provider.source_id] # Now an external caller (e.g., a /mode slash command) switches the mode back to plan. @@ -235,7 +235,7 @@ async def test_agent_mode_provider_injects_user_message_after_external_change( assert session.state[provider.source_id]["previous_mode_for_notification"] == "execute" # Next run: the provider should inject a user message announcing the change and clear the flag. - second_context, second_options = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + second_context, second_options = await agent._prepare_session_and_messages( # pyright: ignore[reportPrivateUsage] session=session, input_messages=[Message(role="user", contents=["Carry on."])], ) @@ -252,7 +252,7 @@ async def test_agent_mode_provider_injects_user_message_after_external_change( assert "previous_mode_for_notification" not in session.state[provider.source_id] # Third run with no further external change must not re-inject the notification. - third_context, _ = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + third_context, _ = await agent._prepare_session_and_messages( # pyright: ignore[reportPrivateUsage] session=session, input_messages=[Message(role="user", contents=["Status?"])], ) diff --git a/python/packages/core/tests/core/test_harness_todo.py b/python/packages/core/tests/core/test_harness_todo.py index ed14b8493cf..cac4e9c6fd3 100644 --- a/python/packages/core/tests/core/test_harness_todo.py +++ b/python/packages/core/tests/core/test_harness_todo.py @@ -43,7 +43,7 @@ def test_todo_item_round_trips_with_value_equality() -> None: item = TodoItem.from_dict(raw_item) - assert item == TodoItem(**raw_item) + assert item == TodoItem(**raw_item) # type: ignore[arg-type] assert item.to_dict() == raw_item assert json.loads(item.to_json()) == raw_item assert "TodoItem(" in repr(item) @@ -203,12 +203,12 @@ async def test_todo_provider_evicts_locks_when_session_is_garbage_collected() -> provider = TodoProvider() session = AgentSession(session_id="session-1") - provider._mutation_lock(session) # type: ignore[reportPrivateUsage] - assert len(provider._mutation_locks) == 1 # type: ignore[reportPrivateUsage] + provider._mutation_lock(session) # pyright: ignore[reportPrivateUsage] + assert len(provider._mutation_locks) == 1 # pyright: ignore[reportPrivateUsage] del session gc.collect() - assert len(provider._mutation_locks) == 0 # type: ignore[reportPrivateUsage] + assert len(provider._mutation_locks) == 0 # pyright: ignore[reportPrivateUsage] async def test_todo_file_store_rejects_session_path_traversal(tmp_path: Path) -> None: @@ -245,7 +245,7 @@ async def test_todo_provider_runs_with_file_store(tmp_path: Path, chat_client_ba provider = TodoProvider(store=TodoFileStore(tmp_path)) agent = Agent(client=chat_client_base, context_providers=[provider]) - _, options = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + _, options = await agent._prepare_session_and_messages( # pyright: ignore[reportPrivateUsage] session=session, input_messages=[Message(role="user", contents=["Track this work"])], ) @@ -255,14 +255,14 @@ async def test_todo_provider_runs_with_file_store(tmp_path: Path, chat_client_ba add_todos = _tool_by_name(tools, "todos_add") get_all_todos = _tool_by_name(tools, "todos_get_all") - await add_todos.invoke(arguments={"todos": [{"title": "Persist me"}]}) + await add_todos.invoke(arguments={"todos": [{"title": "Persist me"}]}) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] state_path = tmp_path / "session-1" / "todos.todo.json" assert state_path.exists() persisted = json.loads(state_path.read_text(encoding="utf-8")) assert persisted["items"] == [{"id": 1, "title": "Persist me", "description": None, "is_complete": False}] assert persisted["next_id"] == 2 - get_all_result = await get_all_todos.invoke() + get_all_result = await get_all_todos.invoke() # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert json.loads(get_all_result[0].text) == [ {"id": 1, "title": "Persist me", "description": None, "is_complete": False} ] @@ -276,7 +276,7 @@ async def test_todo_provider_tools_manage_session_state( provider = TodoProvider() agent = Agent(client=chat_client_base, context_providers=[provider]) - _, options = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + _, options = await agent._prepare_session_and_messages( # pyright: ignore[reportPrivateUsage] session=session, input_messages=[Message(role="user", contents=["Track this work"])], ) @@ -289,7 +289,7 @@ async def test_todo_provider_tools_manage_session_state( get_remaining_todos = _tool_by_name(tools, "todos_get_remaining") get_all_todos = _tool_by_name(tools, "todos_get_all") - add_result = await add_todos.invoke( + add_result = await add_todos.invoke( # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] arguments={ "todos": [ {"title": " Write tests ", "description": " Cover stores "}, @@ -302,18 +302,18 @@ async def test_todo_provider_tools_manage_session_state( {"id": 2, "title": "Ship feature", "description": None, "is_complete": False}, ] - complete_result = await complete_todos.invoke(arguments={"items": [{"id": 1, "reason": "Tests written"}]}) + complete_result = await complete_todos.invoke(arguments={"items": [{"id": 1, "reason": "Tests written"}]}) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert json.loads(complete_result[0].text) == {"completed": 1} - remaining_result = await get_remaining_todos.invoke() + remaining_result = await get_remaining_todos.invoke() # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert json.loads(remaining_result[0].text) == [ {"id": 2, "title": "Ship feature", "description": None, "is_complete": False} ] - remove_result = await remove_todos.invoke(arguments={"ids": [2]}) + remove_result = await remove_todos.invoke(arguments={"ids": [2]}) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert json.loads(remove_result[0].text) == {"removed": 1} - get_all_result = await get_all_todos.invoke() + get_all_result = await get_all_todos.invoke() # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert json.loads(get_all_result[0].text) == [ {"id": 1, "title": "Write tests", "description": "Cover stores", "is_complete": True} ] @@ -327,7 +327,7 @@ async def test_todo_provider_serializes_concurrent_mutations( provider = TodoProvider() agent = Agent(client=chat_client_base, context_providers=[provider]) - _, options = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage] + _, options = await agent._prepare_session_and_messages( # pyright: ignore[reportPrivateUsage] session=session, input_messages=[Message(role="user", contents=["Track this work"])], ) @@ -338,15 +338,15 @@ async def test_todo_provider_serializes_concurrent_mutations( complete_todos = _tool_by_name(tools, "todos_complete") get_all_todos = _tool_by_name(tools, "todos_get_all") - await add_todos.invoke(arguments={"todos": [{"title": f"Existing {index}"} for index in range(1, 6)]}) + await add_todos.invoke(arguments={"todos": [{"title": f"Existing {index}"} for index in range(1, 6)]}) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] await asyncio.gather( - add_todos.invoke(arguments={"todos": [{"title": "Add A1"}, {"title": "Add A2"}]}), - add_todos.invoke(arguments={"todos": [{"title": "Add B1"}, {"title": "Add B2"}]}), - complete_todos.invoke(arguments={"items": [{"id": i, "reason": "Done"} for i in range(1, 6)]}), + add_todos.invoke(arguments={"todos": [{"title": "Add A1"}, {"title": "Add A2"}]}), # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + add_todos.invoke(arguments={"todos": [{"title": "Add B1"}, {"title": "Add B2"}]}), # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + complete_todos.invoke(arguments={"items": [{"id": i, "reason": "Done"} for i in range(1, 6)]}), # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] ) - get_all_result = await get_all_todos.invoke() + get_all_result = await get_all_todos.invoke() # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] payload = json.loads(get_all_result[0].text) ids = [item["id"] for item in payload] @@ -368,10 +368,10 @@ async def test_todo_provider_serializes_concurrent_mutations( def test_todo_harness_classes_are_marked_experimental() -> None: """Todo harness public classes should expose HARNESS experimental metadata.""" - assert TodoStore.__feature_id__ == ExperimentalFeature.HARNESS.value - assert TodoItem.__feature_id__ == ExperimentalFeature.HARNESS.value - assert TodoInput.__feature_id__ == ExperimentalFeature.HARNESS.value - assert TodoSessionStore.__feature_id__ == ExperimentalFeature.HARNESS.value - assert TodoFileStore.__feature_id__ == ExperimentalFeature.HARNESS.value - assert TodoProvider.__feature_id__ == ExperimentalFeature.HARNESS.value - assert ".. warning:: Experimental" in TodoProvider.__doc__ + assert TodoStore.__feature_id__ == ExperimentalFeature.HARNESS.value # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert TodoItem.__feature_id__ == ExperimentalFeature.HARNESS.value # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert TodoInput.__feature_id__ == ExperimentalFeature.HARNESS.value # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert TodoSessionStore.__feature_id__ == ExperimentalFeature.HARNESS.value # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert TodoFileStore.__feature_id__ == ExperimentalFeature.HARNESS.value # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert TodoProvider.__feature_id__ == ExperimentalFeature.HARNESS.value # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert ".. warning:: Experimental" in TodoProvider.__doc__ # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] diff --git a/python/packages/core/tests/core/test_hyperlight_namespace.py b/python/packages/core/tests/core/test_hyperlight_namespace.py index f76d6180b9c..1a09a883996 100644 --- a/python/packages/core/tests/core/test_hyperlight_namespace.py +++ b/python/packages/core/tests/core/test_hyperlight_namespace.py @@ -24,7 +24,7 @@ def test_hyperlight_namespace_dir_lists_lazy_exports() -> None: def test_hyperlight_namespace_lazy_loads_known_attribute(monkeypatch: pytest.MonkeyPatch) -> None: sentinel = object() fake_module = ModuleType("agent_framework_hyperlight") - fake_module.HyperlightCodeActProvider = sentinel # type: ignore[attr-defined] + fake_module.HyperlightCodeActProvider = sentinel # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] monkeypatch.setitem(sys.modules, "agent_framework_hyperlight", fake_module) assert hyperlight.HyperlightCodeActProvider is sentinel diff --git a/python/packages/core/tests/core/test_local_eval.py b/python/packages/core/tests/core/test_local_eval.py index e60fb35d514..a4c2d3c21a5 100644 --- a/python/packages/core/tests/core/test_local_eval.py +++ b/python/packages/core/tests/core/test_local_eval.py @@ -62,7 +62,7 @@ async def test_bool_return_true(self): def has_temperature(query: str, response: str) -> bool: return "°F" in response - result = await has_temperature(_make_item()) + result = await has_temperature(_make_item()) # type: ignore[misc] # pyrefly: ignore[not-async] # ty: ignore[invalid-await] assert result.passed is True assert result.check_name == "has_temperature" @@ -72,7 +72,7 @@ async def test_bool_return_false(self): def has_celsius(query: str, response: str) -> bool: return "°C" in response - result = await has_celsius(_make_item()) + result = await has_celsius(_make_item()) # type: ignore[misc] # pyrefly: ignore[not-async] # ty: ignore[invalid-await] assert result.passed is False @pytest.mark.asyncio @@ -81,7 +81,7 @@ async def test_float_return_passing(self): def length_score(response: str) -> float: return min(len(response) / 10, 1.0) - result = await length_score(_make_item()) + result = await length_score(_make_item()) # type: ignore[misc] # pyrefly: ignore[not-async] # ty: ignore[invalid-await] assert result.passed is True assert "score=" in result.reason @@ -91,7 +91,7 @@ async def test_float_return_failing(self): def always_low(response: str) -> float: return 0.1 - result = await always_low(_make_item()) + result = await always_low(_make_item()) # type: ignore[misc] # pyrefly: ignore[not-async] # ty: ignore[invalid-await] assert result.passed is False @pytest.mark.asyncio @@ -102,7 +102,7 @@ async def test_response_only(self): def is_short(response: str) -> bool: return len(response) < 1000 - result = await is_short(_make_item()) + result = await is_short(_make_item()) # type: ignore[misc] # pyrefly: ignore[not-async] # ty: ignore[invalid-await] assert result.passed is True @pytest.mark.asyncio @@ -113,7 +113,7 @@ async def test_query_only(self): def is_question(query: str) -> bool: return "?" in query - result = await is_question(_make_item()) + result = await is_question(_make_item()) # type: ignore[misc] # pyrefly: ignore[not-async] # ty: ignore[invalid-await] assert result.passed is True @@ -130,10 +130,10 @@ def exact_match(response: str, expected_output: str) -> bool: return response.strip() == expected_output.strip() item = _make_item(response="42", expected_output="42") - assert (await exact_match(item)).passed is True + assert (await exact_match(item)).passed is True # type: ignore[misc] # pyrefly: ignore[not-async] # ty: ignore[invalid-await] item2 = _make_item(response="43", expected_output="42") - assert (await exact_match(item2)).passed is False + assert (await exact_match(item2)).passed is False # type: ignore[misc] # pyrefly: ignore[not-async] # ty: ignore[invalid-await] @pytest.mark.asyncio async def test_expected_output_defaults_to_empty(self): @@ -143,7 +143,7 @@ async def test_expected_output_defaults_to_empty(self): def check_expected(expected_output: str) -> bool: return expected_output == "" - result = await check_expected(_make_item(expected_output=None)) + result = await check_expected(_make_item(expected_output=None)) # type: ignore[misc] # pyrefly: ignore[not-async] # ty: ignore[invalid-await] assert result.passed is True @pytest.mark.asyncio @@ -157,7 +157,7 @@ def word_overlap(response: str, expected_output: str) -> float: return len(r_words & e_words) / len(e_words) item = _make_item(response="sunny warm day", expected_output="warm sunny afternoon") - result = await word_overlap(item) + result = await word_overlap(item) # type: ignore[misc] # pyrefly: ignore[not-async] # ty: ignore[invalid-await] assert result.passed is True # 2/3 overlap ≥ 0.5 @@ -174,10 +174,10 @@ def multi_turn(query: str, response: str, *, conversation: list) -> bool: return len(conversation) >= 2 item = _make_item(conversation=[Message("user", []), Message("assistant", [])]) - assert (await multi_turn(item)).passed is True + assert (await multi_turn(item)).passed is True # type: ignore[misc] # pyrefly: ignore[not-async] # ty: ignore[invalid-await] item2 = _make_item(conversation=[Message("user", [])]) - assert (await multi_turn(item2)).passed is False + assert (await multi_turn(item2)).passed is False # type: ignore[misc] # pyrefly: ignore[not-async] # ty: ignore[invalid-await] @pytest.mark.asyncio async def test_tools_access(self): @@ -191,7 +191,7 @@ def has_tools(tools: list) -> bool: {"name": "get_weather", "description": "Get weather", "parameters": lambda self: {}}, )() item = _make_item(tools=[mock_tool]) - assert (await has_tools(item)).passed is True + assert (await has_tools(item)).passed is True # type: ignore[misc] # pyrefly: ignore[not-async] # ty: ignore[invalid-await] @pytest.mark.asyncio async def test_context_access(self): @@ -202,7 +202,7 @@ def grounded(response: str, context: str) -> bool: return any(word in response.lower() for word in context.lower().split()) item = _make_item(response="It's sunny", context="sunny warm") - assert (await grounded(item)).passed is True + assert (await grounded(item)).passed is True # type: ignore[misc] # pyrefly: ignore[not-async] # ty: ignore[invalid-await] @pytest.mark.asyncio async def test_all_params(self): @@ -218,7 +218,7 @@ def full_check( return all([query, response, expected_output is not None, isinstance(conversation, list)]) item = _make_item(expected_output="foo", context="bar") - assert (await full_check(item)).passed is True + assert (await full_check(item)).passed is True # type: ignore[misc] # pyrefly: ignore[not-async] # ty: ignore[invalid-await] # --------------------------------------------------------------------------- @@ -233,7 +233,7 @@ async def test_dict_with_score(self): def scored(response: str) -> dict: return {"score": 0.9, "reason": "good answer"} - result = await scored(_make_item()) + result = await scored(_make_item()) # type: ignore[misc] # pyrefly: ignore[not-async] # ty: ignore[invalid-await] assert result.passed is True assert result.reason == "good answer" @@ -243,7 +243,7 @@ async def test_dict_with_score_below_threshold(self): def low_scored(response: str) -> dict: return {"score": 0.3} - result = await low_scored(_make_item()) + result = await low_scored(_make_item()) # type: ignore[misc] # pyrefly: ignore[not-async] # ty: ignore[invalid-await] assert result.passed is False @pytest.mark.asyncio @@ -252,7 +252,7 @@ async def test_dict_with_custom_threshold(self): def custom_threshold(response: str) -> dict: return {"score": 0.3, "threshold": 0.2} - result = await custom_threshold(_make_item()) + result = await custom_threshold(_make_item()) # type: ignore[misc] # pyrefly: ignore[not-async] # ty: ignore[invalid-await] assert result.passed is True @pytest.mark.asyncio @@ -261,7 +261,7 @@ async def test_dict_with_passed(self): def explicit_pass(response: str) -> dict: return {"passed": True, "reason": "all good"} - result = await explicit_pass(_make_item()) + result = await explicit_pass(_make_item()) # type: ignore[misc] # pyrefly: ignore[not-async] # ty: ignore[invalid-await] assert result.passed is True assert result.reason == "all good" @@ -271,7 +271,7 @@ async def test_check_result_passthrough(self): def returns_check_result(response: str) -> CheckResult: return CheckResult(True, "direct result", "custom") - result = await returns_check_result(_make_item()) + result = await returns_check_result(_make_item()) # type: ignore[misc] # pyrefly: ignore[not-async] # ty: ignore[invalid-await] assert result.passed is True assert result.reason == "direct result" assert result.check_name == "custom" @@ -283,7 +283,7 @@ def bad_return(response: str) -> str: return "oops" with pytest.raises(TypeError, match="unsupported type"): - await bad_return(_make_item()) + await bad_return(_make_item()) # type: ignore[misc] # pyrefly: ignore[not-async] # ty: ignore[invalid-await] @pytest.mark.asyncio async def test_int_return(self): @@ -291,7 +291,7 @@ async def test_int_return(self): def int_score(response: str) -> int: return 1 - result = await int_score(_make_item()) + result = await int_score(_make_item()) # type: ignore[misc] # pyrefly: ignore[not-async] # ty: ignore[invalid-await] assert result.passed is True @@ -307,7 +307,7 @@ async def test_decorator_no_parens(self): def my_check(response: str) -> bool: return True - assert (await my_check(_make_item())).passed is True + assert (await my_check(_make_item())).passed is True # type: ignore[misc] # pyrefly: ignore[not-async] # ty: ignore[invalid-await] @pytest.mark.asyncio async def test_decorator_with_name(self): @@ -316,7 +316,7 @@ def my_check(response: str) -> bool: return True assert my_check.__name__ == "custom_name" - result = await my_check(_make_item()) + result = await my_check(_make_item()) # type: ignore[misc] # pyrefly: ignore[not-async] # ty: ignore[invalid-await] assert result.check_name == "custom_name" @pytest.mark.asyncio @@ -324,7 +324,7 @@ async def test_direct_call(self): def raw_fn(query: str, response: str) -> bool: return len(response) > 0 - check = evaluator(raw_fn, name="direct") + check = evaluator(raw_fn, name="direct") # type: ignore[call-overload] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] result = await check(_make_item()) assert result.passed is True assert result.check_name == "direct" @@ -350,7 +350,7 @@ async def test_unknown_optional_param_ok(self): def optional_unknown(query: str, foo: str = "default") -> bool: return foo == "default" - result = await optional_unknown(_make_item()) + result = await optional_unknown(_make_item()) # type: ignore[misc] # pyrefly: ignore[not-async] # ty: ignore[invalid-await] assert result.passed is True @pytest.mark.asyncio @@ -365,7 +365,7 @@ async def async_fn(response: str) -> bool: # Should return an awaitable assert inspect.isawaitable(result) check_result = await result - assert check_result.passed is True + assert check_result.passed is True # ty: ignore[unresolved-attribute] # --------------------------------------------------------------------------- @@ -390,8 +390,8 @@ def length_ok(response: str) -> bool: results = await local.evaluate(items, eval_name="mixed test") assert results.status == "completed" - assert results.result_counts["passed"] == 1 - assert results.result_counts["failed"] == 0 + assert results.result_counts["passed"] == 1 # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert results.result_counts["failed"] == 0 # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] @pytest.mark.asyncio async def test_evaluator_failure_counted(self): @@ -402,7 +402,7 @@ def always_fail(response: str) -> bool: local = LocalEvaluator(always_fail) results = await local.evaluate([_make_item()]) - assert results.result_counts["failed"] == 1 + assert results.result_counts["failed"] == 1 # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] @pytest.mark.asyncio async def test_multiple_evaluators(self): @@ -421,7 +421,7 @@ def check_c(response: str, conversation: list) -> dict: local = LocalEvaluator(check_a, check_b, check_c) results = await local.evaluate([_make_item(expected_output="test")]) - assert results.result_counts["passed"] == 1 + assert results.result_counts["passed"] == 1 # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] assert "check_a" in results.per_evaluator assert "check_b" in results.per_evaluator assert "check_c" in results.per_evaluator @@ -441,7 +441,7 @@ async def async_check(query: str, response: str) -> bool: local = LocalEvaluator(async_check) results = await local.evaluate([_make_item()]) - assert results.result_counts["passed"] == 1 + assert results.result_counts["passed"] == 1 # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] @pytest.mark.asyncio async def test_async_with_name(self): @@ -449,7 +449,7 @@ async def test_async_with_name(self): async def my_async(response: str) -> float: return 0.75 - result = await my_async(_make_item()) + result = await my_async(_make_item()) # type: ignore[misc] # pyrefly: ignore[not-async] # ty: ignore[invalid-await] assert result.passed is True assert result.check_name == "named_async" @@ -472,7 +472,7 @@ def is_long(response: str) -> bool: items = [_make_item(response="It is sunny and warm today")] results = await _run_evaluators(is_long, items, eval_name="test") assert len(results) == 1 - assert results[0].result_counts["passed"] == 1 + assert results[0].result_counts["passed"] == 1 # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] @pytest.mark.asyncio async def test_mixed_evaluators_and_checks(self): @@ -488,7 +488,7 @@ def has_words(response: str) -> bool: items = [_make_item(response="It is sunny")] results = await _run_evaluators([local, has_words], items, eval_name="test") assert len(results) == 2 - assert all(r.result_counts["passed"] == 1 for r in results) + assert all(r.result_counts["passed"] == 1 for r in results) # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] @pytest.mark.asyncio async def test_adjacent_checks_grouped(self): @@ -507,7 +507,7 @@ def check_b(response: str) -> bool: results = await _run_evaluators([check_a, check_b], items, eval_name="test") # Two adjacent checks → one LocalEvaluator → one result assert len(results) == 1 - assert results[0].result_counts["passed"] == 1 + assert results[0].result_counts["passed"] == 1 # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] # --------------------------------------------------------------------------- @@ -642,7 +642,7 @@ def check_tools(expected_tool_calls: list) -> bool: calls=[], expected=[ExpectedToolCall("a"), ExpectedToolCall("b")], ) - result = await check_tools(item) + result = await check_tools(item) # type: ignore[misc] # pyrefly: ignore[not-async] # ty: ignore[invalid-await] assert result.passed is True @pytest.mark.asyncio @@ -652,7 +652,7 @@ def check_tools(expected_tool_calls: list) -> bool: return len(expected_tool_calls) == 0 item = _make_tool_call_item(calls=[]) - result = await check_tools(item) + result = await check_tools(item) # type: ignore[misc] # pyrefly: ignore[not-async] # ty: ignore[invalid-await] assert result.passed is True @@ -861,7 +861,7 @@ async def test_any_mode_one_tool_called(self): ) check = tool_called_check("tool_a", "tool_b", mode="any") result = check(item) - assert result.passed is True + assert result.passed is True # type: ignore[union-attr] # ty: ignore[unresolved-attribute] async def test_any_mode_none_called(self): """mode='any' fails when no expected tools are called.""" @@ -873,8 +873,8 @@ async def test_any_mode_none_called(self): ) check = tool_called_check("tool_a", "tool_b", mode="any") result = check(item) - assert result.passed is False - assert "None of expected tools" in result.reason + assert result.passed is False # type: ignore[union-attr] # ty: ignore[unresolved-attribute] + assert "None of expected tools" in result.reason # type: ignore[union-attr] # ty: ignore[unresolved-attribute] class TestCoerceResultScoreError: diff --git a/python/packages/core/tests/core/test_mcp.py b/python/packages/core/tests/core/test_mcp.py index 7c45296cbb0..72a26adfa76 100644 --- a/python/packages/core/tests/core/test_mcp.py +++ b/python/packages/core/tests/core/test_mcp.py @@ -1,5 +1,5 @@ # Copyright (c) Microsoft. All rights reserved. -# type: ignore[reportPrivateUsage] +# pyright: ignore[reportPrivateUsage] import asyncio import contextlib import json @@ -53,7 +53,7 @@ def _mcp_result_to_text(result: str | list[Content]) -> str: return text or str(result) -_HELPER_MCP_TOOL = MCPTool(name="helper") +_HELPER_MCP_TOOL = MCPTool(name="helper") # type: ignore[abstract] # Helper function tests @@ -93,7 +93,7 @@ def test_mcp_transport_subclasses_accept_tool_name_prefix() -> None: async def test_load_tools_with_tool_name_prefix_preserves_matching_configuration(): """Prefixed MCP tool names should still honor unprefixed allow/approval configuration.""" - tool = MCPTool( + tool = MCPTool( # type: ignore[abstract] name="docs", tool_name_prefix="docs", allowed_tools=["search_docs"], @@ -124,7 +124,7 @@ async def test_load_tools_with_tool_name_prefix_preserves_matching_configuration async def test_load_prompts_with_tool_name_prefix() -> None: """Prefixed MCP prompt names should be exposed with the configured prefix.""" - tool = MCPTool(name="docs", tool_name_prefix="docs") + tool = MCPTool(name="docs", tool_name_prefix="docs") # type: ignore[abstract] mock_session = AsyncMock() tool.session = mock_session @@ -160,7 +160,7 @@ def test_mcp_prompt_message_to_ai_content(): def test_mcp_tool_str_and_parse_prompt_result_rich_content() -> None: - tool = MCPTool(name="helper", description="Helper MCP tool") + tool = MCPTool(name="helper", description="Helper MCP tool") # type: ignore[abstract] prompt_result = types.GetPromptResult( messages=[ types.PromptMessage(role="user", content=types.TextContent(type="text", text="Hello")), @@ -228,12 +228,12 @@ def test_parse_tool_result_from_mcp(): assert result[0].text == "Result text" assert result[1].type == "data" assert result[1].media_type == "image/png" - assert "eHl6" in result[1].uri + assert "eHl6" in result[1].uri # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] assert result[2].type == "text" assert result[2].text == "After image" assert result[3].type == "data" assert result[3].media_type == "image/webp" - assert "YWJj" in result[3].uri + assert "YWJj" in result[3].uri # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] def test_parse_tool_result_from_mcp_single_text(): @@ -287,7 +287,7 @@ def test_parse_tool_result_from_mcp_audio_content(): assert len(result) == 1 assert result[0].type == "data" assert result[0].media_type == "audio/wav" - assert "YXVkaW8=" in result[0].uri + assert "YXVkaW8=" in result[0].uri # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] def test_parse_tool_result_from_mcp_blob_plain_base64(): @@ -310,7 +310,7 @@ def test_parse_tool_result_from_mcp_blob_plain_base64(): assert len(result) == 1 assert result[0].type == "data" assert result[0].media_type == "application/pdf" - assert "dGVzdCBkYXRh" in result[0].uri + assert "dGVzdCBkYXRh" in result[0].uri # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] def test_parse_tool_result_from_mcp_resource_link_text_resource_and_unknown(): @@ -489,7 +489,7 @@ def test_ai_content_to_mcp_content_types_data_binary(): assert isinstance(mcp_content, types.EmbeddedResource) assert mcp_content.type == "resource" - assert mcp_content.resource.blob == "data:application/octet-stream;base64,xyz" + assert mcp_content.resource.blob == "data:application/octet-stream;base64,xyz" # type: ignore[union-attr] # ty: ignore[unresolved-attribute] assert mcp_content.resource.mimeType == "application/octet-stream" @@ -519,7 +519,7 @@ def test_prepare_message_for_mcp(): def test_prepare_message_for_mcp_skips_unsupported_content() -> None: - unsupported = Content(type="annotations", text="ignored") + unsupported = Content(type="annotations", text="ignored") # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] assert _HELPER_MCP_TOOL._prepare_content_for_mcp(unsupported) is None @@ -534,7 +534,7 @@ def test_prepare_message_for_mcp_skips_unsupported_content() -> None: "test_id,input_schema", [ (test_id, input_schema) - for test_id, input_schema, _, _, _, _ in [ + for test_id, input_schema, _, _, _, _ in [ # type: ignore[assignment] # Basic types with required/optional fields ( "basic_types", @@ -894,7 +894,7 @@ def test_get_input_model_from_mcp_prompt_without_arguments(): # MCPTool tests async def test_local_mcp_server_initialization(): """Test MCPTool initialization.""" - server = MCPTool(name="test_server") + server = MCPTool(name="test_server") # type: ignore[abstract] # MCPTool has the same core attributes as FunctionTool assert hasattr(server, "name") assert hasattr(server, "description") @@ -908,12 +908,12 @@ async def test_local_mcp_server_context_manager(): """Test MCPTool as context manager.""" class TestServer(MCPTool): - async def connect(self): + async def connect(self): # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] # Mock connection self.session = Mock(spec=ClientSession) def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: - return None + return None # type: ignore[return-value] # pyrefly: ignore[bad-return] # ty: ignore[invalid-return-type] server = TestServer(name="test_server") async with server: @@ -926,7 +926,7 @@ async def test_local_mcp_server_load_functions(): """Test loading functions from MCP server.""" class TestServer(MCPTool): - async def connect(self): + async def connect(self): # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] self.session = Mock(spec=ClientSession) # Mock tools list response self.session.list_tools = AsyncMock( @@ -946,7 +946,7 @@ async def connect(self): ) def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: - return None + return None # type: ignore[return-value] # pyrefly: ignore[bad-return] # ty: ignore[invalid-return-type] server = TestServer(name="test_server") # MCPTool has the same core attributes as FunctionTool @@ -962,7 +962,7 @@ async def test_local_mcp_server_load_prompts(): """Test loading prompts from MCP server.""" class TestServer(MCPTool): - async def connect(self): + async def connect(self): # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] self.session = Mock(spec=ClientSession) # Mock prompts list response self.session.list_prompts = AsyncMock( @@ -978,7 +978,7 @@ async def connect(self): ) def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: - return None + return None # type: ignore[return-value] # pyrefly: ignore[bad-return] # ty: ignore[invalid-return-type] server = TestServer(name="test_server") async with server: @@ -991,7 +991,7 @@ async def test_mcp_tool_call_tool_with_meta_integration(): """Test that call_tool method properly integrates with enhanced metadata extraction.""" class TestServer(MCPTool): - async def connect(self): + async def connect(self): # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] self.session = Mock(spec=ClientSession) self.session.list_tools = AsyncMock( return_value=types.ListToolsResult( @@ -1018,7 +1018,7 @@ async def connect(self): self.session.call_tool = AsyncMock(return_value=tool_result) def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: - return None + return None # type: ignore[return-value] # pyrefly: ignore[bad-return] # ty: ignore[invalid-return-type] server = TestServer(name="test_server") async with server: @@ -1036,7 +1036,7 @@ async def test_local_mcp_server_function_execution(): """Test function execution through MCP server.""" class TestServer(MCPTool): - async def connect(self): + async def connect(self): # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] self.session = Mock(spec=ClientSession) self.session.list_tools = AsyncMock( return_value=types.ListToolsResult( @@ -1060,7 +1060,7 @@ async def connect(self): ) def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: - return None + return None # type: ignore[return-value] # pyrefly: ignore[bad-return] # ty: ignore[invalid-return-type] server = TestServer(name="test_server") async with server: @@ -1076,7 +1076,7 @@ async def test_local_mcp_server_function_execution_with_nested_object(): """Test function execution through MCP server with nested object arguments.""" class TestServer(MCPTool): - async def connect(self): + async def connect(self): # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] self.session = Mock(spec=ClientSession) self.session.list_tools = AsyncMock( return_value=types.ListToolsResult( @@ -1106,7 +1106,7 @@ async def connect(self): ) def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: - return None + return None # type: ignore[return-value] # pyrefly: ignore[bad-return] # ty: ignore[invalid-return-type] server = TestServer(name="test_server") async with server: @@ -1120,8 +1120,8 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: assert result[0].text == '{"name": "John Doe", "id": 251}' # Verify the session.call_tool was called with the correct nested structure - server.session.call_tool.assert_called_once() - call_args = server.session.call_tool.call_args + server.session.call_tool.assert_called_once() # type: ignore[union-attr] # ty: ignore[unresolved-attribute] + call_args = server.session.call_tool.call_args # type: ignore[union-attr] # ty: ignore[unresolved-attribute] assert call_args.kwargs["arguments"] == {"params": {"customer_id": 251}} @@ -1129,7 +1129,7 @@ async def test_local_mcp_server_function_execution_error(): """Test function execution error handling.""" class TestServer(MCPTool): - async def connect(self): + async def connect(self): # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] self.session = Mock(spec=ClientSession) self.session.list_tools = AsyncMock( return_value=types.ListToolsResult( @@ -1152,7 +1152,7 @@ async def connect(self): ) def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: - return None + return None # type: ignore[return-value] # pyrefly: ignore[bad-return] # ty: ignore[invalid-return-type] server = TestServer(name="test_server") async with server: @@ -1187,7 +1187,7 @@ async def connect(self, *, reset: bool = False) -> None: self.is_connected = True def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: - return None + return None # type: ignore[return-value] # pyrefly: ignore[bad-return] # ty: ignore[invalid-return-type] server = TestServer(name="test_server") await server.connect() @@ -1204,7 +1204,7 @@ async def test_mcp_tool_call_tool_raises_on_is_error(): """Test that call_tool raises ToolExecutionException when MCP returns isError=True.""" class TestServer(MCPTool): - async def connect(self): + async def connect(self): # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] self.session = Mock(spec=ClientSession) self.session.list_tools = AsyncMock( return_value=types.ListToolsResult( @@ -1229,7 +1229,7 @@ async def connect(self): ) def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: - return None + return None # type: ignore[return-value] # pyrefly: ignore[bad-return] # ty: ignore[invalid-return-type] server = TestServer(name="test_server") async with server: @@ -1244,7 +1244,7 @@ async def test_mcp_tool_call_tool_succeeds_when_is_error_false(): """Test that call_tool returns normally when MCP returns isError=False.""" class TestServer(MCPTool): - async def connect(self): + async def connect(self): # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] self.session = Mock(spec=ClientSession) self.session.list_tools = AsyncMock( return_value=types.ListToolsResult( @@ -1269,7 +1269,7 @@ async def connect(self): ) def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: - return None + return None # type: ignore[return-value] # pyrefly: ignore[bad-return] # ty: ignore[invalid-return-type] server = TestServer(name="test_server") async with server: @@ -1294,7 +1294,7 @@ async def process(self, context: FunctionInvocationContext, call_next): raise class TestServer(MCPTool): - async def connect(self): + async def connect(self): # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] self.session = Mock(spec=ClientSession) self.session.list_tools = AsyncMock( return_value=types.ListToolsResult( @@ -1319,7 +1319,7 @@ async def connect(self): ) def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: - return None + return None # type: ignore[return-value] # pyrefly: ignore[bad-return] # ty: ignore[invalid-return-type] server = TestServer(name="test_server") async with server: @@ -1346,7 +1346,7 @@ async def test_local_mcp_server_prompt_execution(): """Test prompt execution through MCP server.""" class TestMCPTool(MCPTool): - async def connect(self): + async def connect(self): # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] self.session = Mock(spec=ClientSession) self.session.list_prompts = AsyncMock( return_value=types.ListPromptsResult( @@ -1372,7 +1372,7 @@ async def connect(self): ) def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: - return None + return None # type: ignore[return-value] # pyrefly: ignore[bad-return] # ty: ignore[invalid-return-type] server = TestMCPTool(name="test_server") async with server: @@ -1409,7 +1409,7 @@ async def test_mcp_tool_approval_mode(approval_mode, expected_approvals): """ class TestServer(MCPTool): - async def connect(self): + async def connect(self): # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] self.session = Mock(spec=ClientSession) self.session.list_tools = AsyncMock( return_value=types.ListToolsResult( @@ -1435,7 +1435,7 @@ async def connect(self): ) def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: - return None + return None # type: ignore[return-value] # pyrefly: ignore[bad-return] # ty: ignore[invalid-return-type] server = TestServer(name="test_server", approval_mode=approval_mode) async with server: @@ -1448,7 +1448,7 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: def test_mcp_tool_approval_mode_returns_none_for_unmatched_names() -> None: - tool = MCPTool( + tool = MCPTool( # type: ignore[abstract] name="test_tool", approval_mode={ "always_require_approval": ["tool_one"], @@ -1485,7 +1485,7 @@ async def test_mcp_tool_allowed_tools(allowed_tools, expected_count, expected_na """ class TestServer(MCPTool): - async def connect(self): + async def connect(self): # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] self.session = Mock(spec=ClientSession) self.session.list_tools = AsyncMock( return_value=types.ListToolsResult( @@ -1519,7 +1519,7 @@ async def connect(self): ) def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: - return None + return None # type: ignore[return-value] # pyrefly: ignore[bad-return] # ty: ignore[invalid-return-type] server = TestServer(name="test_server", allowed_tools=allowed_tools) async with server: @@ -1599,7 +1599,7 @@ async def test_mcp_connection_reset_integration(): """ url = os.environ.get("LOCAL_MCP_URL") - tool = MCPStreamableHTTPTool(name="integration_test", url=url, approval_mode="never_require") + tool = MCPStreamableHTTPTool(name="integration_test", url=url, approval_mode="never_require") # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] async with tool: # Verify initial connection @@ -1632,7 +1632,7 @@ async def call_tool_with_error(*args, **kwargs): # After reconnection, delegate to the original method return await original_call_tool(*args, **kwargs) - tool.session.call_tool = call_tool_with_error + tool.session.call_tool = call_tool_with_error # type: ignore[method-assign] # ty: ignore[invalid-assignment] # Invoke the function again - this should trigger automatic reconnection on ClosedResourceError second_result = _mcp_result_to_text(await func.invoke(query="What is Agent Framework?")) @@ -1659,39 +1659,39 @@ async def test_mcp_tool_message_handler_notification(): tool = MCPStdioTool(name="test_tool", command="python") # Mock the load_tools and load_prompts methods - tool.load_tools = AsyncMock() - tool.load_prompts = AsyncMock() + tool.load_tools = AsyncMock() # type: ignore[method-assign] # ty: ignore[invalid-assignment] + tool.load_prompts = AsyncMock() # type: ignore[method-assign] # ty: ignore[invalid-assignment] # Test tools list changed notification tools_notification = Mock(spec=types.ServerNotification) tools_notification.root = Mock() tools_notification.root.method = "notifications/tools/list_changed" - result = await tool.message_handler(tools_notification) + result = await tool.message_handler(tools_notification) # type: ignore[func-returns-value] assert result is None # The reload is scheduled as a background task; let it run. await asyncio.sleep(0) - tool.load_tools.assert_called_once() + tool.load_tools.assert_called_once() # ty: ignore[unresolved-attribute] # Reset mock - tool.load_tools.reset_mock() + tool.load_tools.reset_mock() # ty: ignore[unresolved-attribute] # Test prompts list changed notification prompts_notification = Mock(spec=types.ServerNotification) prompts_notification.root = Mock() prompts_notification.root.method = "notifications/prompts/list_changed" - result = await tool.message_handler(prompts_notification) + result = await tool.message_handler(prompts_notification) # type: ignore[func-returns-value] assert result is None await asyncio.sleep(0) - tool.load_prompts.assert_called_once() + tool.load_prompts.assert_called_once() # ty: ignore[unresolved-attribute] # Test unhandled notification unknown_notification = Mock(spec=types.ServerNotification) unknown_notification.root = Mock() unknown_notification.root.method = "notifications/unknown" - result = await tool.message_handler(unknown_notification) + result = await tool.message_handler(unknown_notification) # type: ignore[func-returns-value] assert result is None @@ -1703,7 +1703,7 @@ async def test_mcp_tool_message_handler_error(): test_exception = RuntimeError("Test error message") # The message handler should log the error and return None - result = await tool.message_handler(test_exception) + result = await tool.message_handler(test_exception) # type: ignore[func-returns-value] assert result is None @@ -1726,7 +1726,7 @@ async def test_mcp_tool_message_handler_does_not_block_receive_loop(): async def slow_load_tools(): await release.wait() - tool.load_tools = slow_load_tools # type: ignore[assignment] + tool.load_tools = slow_load_tools # type: ignore[assignment] # ty: ignore[invalid-assignment] tools_notification = Mock(spec=types.ServerNotification) tools_notification.root = Mock() @@ -1750,7 +1750,7 @@ async def slow_load_tools(): async def test_mcp_tool_message_handler_reload_failure_is_logged(caplog: pytest.LogCaptureFixture): """Background reload errors are logged, not raised into the receive loop.""" tool = MCPStdioTool(name="test_tool", command="python") - tool.load_tools = AsyncMock(side_effect=RuntimeError("connection lost")) + tool.load_tools = AsyncMock(side_effect=RuntimeError("connection lost")) # type: ignore[method-assign] # ty: ignore[invalid-assignment] tools_notification = Mock(spec=types.ServerNotification) tools_notification.root = Mock() @@ -1762,7 +1762,7 @@ async def test_mcp_tool_message_handler_reload_failure_is_logged(caplog: pytest. pending = list(tool._pending_reload_tasks) if pending: await asyncio.wait_for(asyncio.gather(*pending, return_exceptions=True), timeout=1) - tool.load_tools.assert_called_once() + tool.load_tools.assert_called_once() # ty: ignore[unresolved-attribute] assert len(tool._pending_reload_tasks) == 0 # Verify the warning was actually logged with exception info. @@ -1784,7 +1784,7 @@ async def blocking_load_tools(): call_count += 1 await release.wait() - tool.load_tools = blocking_load_tools # type: ignore[assignment] + tool.load_tools = blocking_load_tools # type: ignore[assignment] # ty: ignore[invalid-assignment] notification = Mock(spec=types.ServerNotification) notification.root = Mock() @@ -2259,7 +2259,7 @@ async def test_connect_sampling_capabilities_with_client(): mock_context_manager = Mock() mock_context_manager.__aenter__ = AsyncMock(return_value=mock_transport) mock_context_manager.__aexit__ = AsyncMock(return_value=None) - tool.get_mcp_client = Mock(return_value=mock_context_manager) + tool.get_mcp_client = Mock(return_value=mock_context_manager) # type: ignore[method-assign] # ty: ignore[invalid-assignment] with patch("mcp.client.session.ClientSession") as mock_session_class: mock_session = AsyncMock() @@ -2289,7 +2289,7 @@ async def test_connect_no_sampling_capabilities_without_client(): mock_context_manager = Mock() mock_context_manager.__aenter__ = AsyncMock(return_value=mock_transport) mock_context_manager.__aexit__ = AsyncMock(return_value=None) - tool.get_mcp_client = Mock(return_value=mock_context_manager) + tool.get_mcp_client = Mock(return_value=mock_context_manager) # type: ignore[method-assign] # ty: ignore[invalid-assignment] with patch("mcp.client.session.ClientSession") as mock_session_class: mock_session = AsyncMock() @@ -2318,7 +2318,7 @@ async def test_connect_session_creation_failure(): mock_context_manager = Mock() mock_context_manager.__aenter__ = AsyncMock(return_value=mock_transport) mock_context_manager.__aexit__ = AsyncMock(return_value=None) - tool.get_mcp_client = Mock(return_value=mock_context_manager) + tool.get_mcp_client = Mock(return_value=mock_context_manager) # type: ignore[method-assign] # ty: ignore[invalid-assignment] # Mock ClientSession to raise an exception with patch("mcp.client.session.ClientSession") as mock_session_class: @@ -2341,7 +2341,7 @@ async def test_connect_initialization_failure_http_no_command(): mock_context_manager = Mock() mock_context_manager.__aenter__ = AsyncMock(return_value=mock_transport) mock_context_manager.__aexit__ = AsyncMock(return_value=None) - tool.get_mcp_client = Mock(return_value=mock_context_manager) + tool.get_mcp_client = Mock(return_value=mock_context_manager) # type: ignore[method-assign] # ty: ignore[invalid-assignment] # Mock successful session creation but failed initialization mock_session = Mock() @@ -2364,28 +2364,28 @@ async def test_connect_cleanup_on_transport_failure(): tool = MCPStdioTool(name="test", command="test-command") # Mock _exit_stack.aclose to verify it's called - tool._exit_stack.aclose = AsyncMock() + tool._exit_stack.aclose = AsyncMock() # type: ignore[method-assign] # ty: ignore[invalid-assignment] # Mock get_mcp_client to raise an exception - tool.get_mcp_client = Mock(side_effect=RuntimeError("Transport failed")) + tool.get_mcp_client = Mock(side_effect=RuntimeError("Transport failed")) # type: ignore[method-assign] # ty: ignore[invalid-assignment] with pytest.raises(ToolException): await tool.connect() # Verify cleanup was called - tool._exit_stack.aclose.assert_called_once() + tool._exit_stack.aclose.assert_called_once() # ty: ignore[unresolved-attribute] async def test_connect_cleanup_on_transport_failure_http_uses_generic_message(): """Test HTTP transport failures use the generic connection message when no command exists.""" tool = MCPStreamableHTTPTool(name="test", url="https://example.com/mcp") - tool._exit_stack.aclose = AsyncMock() - tool.get_mcp_client = Mock(side_effect=RuntimeError("Transport failed")) + tool._exit_stack.aclose = AsyncMock() # type: ignore[method-assign] # ty: ignore[invalid-assignment] + tool.get_mcp_client = Mock(side_effect=RuntimeError("Transport failed")) # type: ignore[method-assign] # ty: ignore[invalid-assignment] with pytest.raises(ToolException, match="Failed to connect to MCP server: Transport failed"): await tool.connect() - tool._exit_stack.aclose.assert_called_once() + tool._exit_stack.aclose.assert_called_once() # ty: ignore[unresolved-attribute] async def test_connect_cleanup_on_initialization_failure(): @@ -2393,14 +2393,14 @@ async def test_connect_cleanup_on_initialization_failure(): tool = MCPStdioTool(name="test", command="test-command") # Mock _exit_stack.aclose to verify it's called - tool._exit_stack.aclose = AsyncMock() + tool._exit_stack.aclose = AsyncMock() # type: ignore[method-assign] # ty: ignore[invalid-assignment] # Mock successful transport creation mock_transport = (Mock(), Mock()) mock_context_manager = Mock() mock_context_manager.__aenter__ = AsyncMock(return_value=mock_transport) mock_context_manager.__aexit__ = AsyncMock(return_value=None) - tool.get_mcp_client = Mock(return_value=mock_context_manager) + tool.get_mcp_client = Mock(return_value=mock_context_manager) # type: ignore[method-assign] # ty: ignore[invalid-assignment] # Mock successful session creation but failed initialization mock_session = Mock() @@ -2414,31 +2414,31 @@ async def test_connect_cleanup_on_initialization_failure(): await tool.connect() # Verify cleanup was called - tool._exit_stack.aclose.assert_called_once() + tool._exit_stack.aclose.assert_called_once() # ty: ignore[unresolved-attribute] async def test_connect_cancelled_error_during_transport_creation_raises_tool_exception(): """Test that CancelledError from transport creation is wrapped in ToolException.""" tool = MCPStreamableHTTPTool(name="test", url="http://example.com") - tool._exit_stack.aclose = AsyncMock() - tool.get_mcp_client = Mock(side_effect=asyncio.CancelledError("cancel scope")) + tool._exit_stack.aclose = AsyncMock() # type: ignore[method-assign] # ty: ignore[invalid-assignment] + tool.get_mcp_client = Mock(side_effect=asyncio.CancelledError("cancel scope")) # type: ignore[method-assign] # ty: ignore[invalid-assignment] with pytest.raises(ToolException, match="Failed to connect to MCP server"): await tool.connect() - tool._exit_stack.aclose.assert_called_once() + tool._exit_stack.aclose.assert_called_once() # ty: ignore[unresolved-attribute] async def test_connect_cancelled_error_during_transport_creation_stdio_raises_tool_exception(): """Test that CancelledError from transport creation uses the command-specific message for MCPStdioTool.""" tool = MCPStdioTool(name="test", command="my-server") - tool._exit_stack.aclose = AsyncMock() - tool.get_mcp_client = Mock(side_effect=asyncio.CancelledError("cancel scope")) + tool._exit_stack.aclose = AsyncMock() # type: ignore[method-assign] # ty: ignore[invalid-assignment] + tool.get_mcp_client = Mock(side_effect=asyncio.CancelledError("cancel scope")) # type: ignore[method-assign] # ty: ignore[invalid-assignment] with pytest.raises(ToolException, match="Failed to start MCP server 'my-server'"): await tool.connect() - tool._exit_stack.aclose.assert_called_once() + tool._exit_stack.aclose.assert_called_once() # ty: ignore[unresolved-attribute] async def test_connect_cancelled_error_during_session_creation_raises_tool_exception(): @@ -2449,7 +2449,7 @@ async def test_connect_cancelled_error_during_session_creation_raises_tool_excep mock_context_manager = Mock() mock_context_manager.__aenter__ = AsyncMock(return_value=mock_transport) mock_context_manager.__aexit__ = AsyncMock(return_value=None) - tool.get_mcp_client = Mock(return_value=mock_context_manager) + tool.get_mcp_client = Mock(return_value=mock_context_manager) # type: ignore[method-assign] # ty: ignore[invalid-assignment] with patch("mcp.client.session.ClientSession") as mock_session_class: mock_session_class.return_value.__aenter__ = AsyncMock(side_effect=asyncio.CancelledError("cancel scope")) @@ -2472,7 +2472,7 @@ async def test_connect_cancelled_error_during_initialize_raises_tool_exception() mock_context_manager = Mock() mock_context_manager.__aenter__ = AsyncMock(return_value=mock_transport) mock_context_manager.__aexit__ = AsyncMock(return_value=None) - tool.get_mcp_client = Mock(return_value=mock_context_manager) + tool.get_mcp_client = Mock(return_value=mock_context_manager) # type: ignore[method-assign] # ty: ignore[invalid-assignment] mock_session = Mock() mock_session.initialize = AsyncMock(side_effect=asyncio.CancelledError("Cancelled via cancel scope")) @@ -2493,7 +2493,7 @@ async def test_connect_cancelled_error_during_initialize_stdio_raises_tool_excep mock_context_manager = Mock() mock_context_manager.__aenter__ = AsyncMock(return_value=mock_transport) mock_context_manager.__aexit__ = AsyncMock(return_value=None) - tool.get_mcp_client = Mock(return_value=mock_context_manager) + tool.get_mcp_client = Mock(return_value=mock_context_manager) # type: ignore[method-assign] # ty: ignore[invalid-assignment] mock_session = Mock() mock_session.initialize = AsyncMock(side_effect=asyncio.CancelledError("Cancelled via cancel scope")) @@ -2510,30 +2510,30 @@ async def test_connect_cancelled_error_during_initialize_stdio_raises_tool_excep async def test_connect_genuine_cancellation_during_transport_creation_propagates(): """Test that genuine task cancellation (task.cancelling() > 0) propagates as CancelledError.""" tool = MCPStreamableHTTPTool(name="test", url="http://example.com") - tool._exit_stack.aclose = AsyncMock() + tool._exit_stack.aclose = AsyncMock() # type: ignore[method-assign] # ty: ignore[invalid-assignment] mock_cancelled_task = Mock() mock_cancelled_task.cancelling.return_value = 1 with patch("asyncio.current_task", return_value=mock_cancelled_task): - tool.get_mcp_client = Mock(side_effect=asyncio.CancelledError("task cancelled")) + tool.get_mcp_client = Mock(side_effect=asyncio.CancelledError("task cancelled")) # type: ignore[method-assign] # ty: ignore[invalid-assignment] with pytest.raises(asyncio.CancelledError): await tool.connect() - tool._exit_stack.aclose.assert_called_once() + tool._exit_stack.aclose.assert_called_once() # ty: ignore[unresolved-attribute] @pytest.mark.skipif(sys.version_info < (3, 11), reason="task.cancelling() requires Python >= 3.11") async def test_connect_genuine_cancellation_during_initialize_propagates(): """Test that genuine task cancellation during initialize() propagates as CancelledError.""" tool = MCPStreamableHTTPTool(name="test", url="http://example.com") - tool._exit_stack.aclose = AsyncMock() + tool._exit_stack.aclose = AsyncMock() # type: ignore[method-assign] # ty: ignore[invalid-assignment] mock_transport = (Mock(), Mock()) mock_context_manager = Mock() mock_context_manager.__aenter__ = AsyncMock(return_value=mock_transport) mock_context_manager.__aexit__ = AsyncMock(return_value=None) - tool.get_mcp_client = Mock(return_value=mock_context_manager) + tool.get_mcp_client = Mock(return_value=mock_context_manager) # type: ignore[method-assign] # ty: ignore[invalid-assignment] mock_session = Mock() mock_session.initialize = AsyncMock(side_effect=asyncio.CancelledError("task cancelled")) @@ -2551,20 +2551,20 @@ async def test_connect_genuine_cancellation_during_initialize_propagates(): with pytest.raises(asyncio.CancelledError): await tool.connect() - tool._exit_stack.aclose.assert_called_once() + tool._exit_stack.aclose.assert_called_once() # ty: ignore[unresolved-attribute] @pytest.mark.skipif(sys.version_info < (3, 11), reason="task.cancelling() requires Python >= 3.11") async def test_connect_genuine_cancellation_during_session_creation_propagates(): """Test that genuine task cancellation during session creation propagates as CancelledError.""" tool = MCPStreamableHTTPTool(name="test", url="http://example.com") - tool._exit_stack.aclose = AsyncMock() + tool._exit_stack.aclose = AsyncMock() # type: ignore[method-assign] # ty: ignore[invalid-assignment] mock_transport = (Mock(), Mock()) mock_context_manager = Mock() mock_context_manager.__aenter__ = AsyncMock(return_value=mock_transport) mock_context_manager.__aexit__ = AsyncMock(return_value=None) - tool.get_mcp_client = Mock(return_value=mock_context_manager) + tool.get_mcp_client = Mock(return_value=mock_context_manager) # type: ignore[method-assign] # ty: ignore[invalid-assignment] mock_cancelled_task = Mock() mock_cancelled_task.cancelling.return_value = 1 @@ -2579,7 +2579,7 @@ async def test_connect_genuine_cancellation_during_session_creation_propagates() with pytest.raises(asyncio.CancelledError): await tool.connect() - tool._exit_stack.aclose.assert_called_once() + tool._exit_stack.aclose.assert_called_once() # ty: ignore[unresolved-attribute] async def test_aenter_cancelled_error_during_connect_is_catchable_as_exception(): @@ -2597,7 +2597,7 @@ async def test_aenter_cancelled_error_during_connect_is_catchable_as_exception() mock_context_manager = Mock() mock_context_manager.__aenter__ = AsyncMock(return_value=mock_transport) mock_context_manager.__aexit__ = AsyncMock(return_value=None) - tool.get_mcp_client = Mock(return_value=mock_context_manager) + tool.get_mcp_client = Mock(return_value=mock_context_manager) # type: ignore[method-assign] # ty: ignore[invalid-assignment] with patch("mcp.client.session.ClientSession") as mock_session_class: mock_session_class.return_value.__aenter__ = AsyncMock(return_value=mock_session) @@ -2650,7 +2650,7 @@ async def test_connect_cancelled_error_during_session_creation_includes_exceptio mock_context_manager = Mock() mock_context_manager.__aenter__ = AsyncMock(return_value=mock_transport) mock_context_manager.__aexit__ = AsyncMock(return_value=None) - tool.get_mcp_client = Mock(return_value=mock_context_manager) + tool.get_mcp_client = Mock(return_value=mock_context_manager) # type: ignore[method-assign] # ty: ignore[invalid-assignment] with patch("mcp.client.session.ClientSession") as mock_session_class: mock_session_class.return_value.__aenter__ = AsyncMock( @@ -2673,7 +2673,7 @@ async def test_connect_cancelled_error_during_session_creation_logs_with_exc_inf mock_context_manager = Mock() mock_context_manager.__aenter__ = AsyncMock(return_value=mock_transport) mock_context_manager.__aexit__ = AsyncMock(return_value=None) - tool.get_mcp_client = Mock(return_value=mock_context_manager) + tool.get_mcp_client = Mock(return_value=mock_context_manager) # type: ignore[method-assign] # ty: ignore[invalid-assignment] with patch("mcp.client.session.ClientSession") as mock_session_class: mock_session_class.return_value.__aenter__ = AsyncMock(side_effect=asyncio.CancelledError("cancel scope")) @@ -2767,7 +2767,7 @@ async def test_mcp_tool_deduplication(): from agent_framework._tools import FunctionTool # Create MCPStreamableHTTPTool instance - tool = MCPTool(name="test_mcp_tool") + tool = MCPTool(name="test_mcp_tool") # type: ignore[abstract] # Manually set up functions list tool._functions = [] @@ -2828,7 +2828,7 @@ async def test_load_tools_prevents_multiple_calls(): from agent_framework._mcp import MCPTool - tool = MCPTool(name="test_tool") + tool = MCPTool(name="test_tool") # type: ignore[abstract] # Verify initial state assert tool._tools_loaded is False @@ -2867,7 +2867,7 @@ async def test_load_prompts_prevents_multiple_calls(): from agent_framework._mcp import MCPTool - tool = MCPTool(name="test_tool") + tool = MCPTool(name="test_tool") # type: ignore[abstract] # Verify initial state assert tool._prompts_loaded is False @@ -2962,7 +2962,7 @@ async def test_load_tools_with_pagination(): from agent_framework._mcp import MCPTool - tool = MCPTool(name="test_tool") + tool = MCPTool(name="test_tool") # type: ignore[abstract] # Mock the session mock_session = AsyncMock() @@ -3038,7 +3038,7 @@ async def test_load_tools_adds_properties_to_zero_arg_tool_schema(): from agent_framework._mcp import MCPTool - tool = MCPTool(name="test_tool") + tool = MCPTool(name="test_tool") # type: ignore[abstract] mock_session = AsyncMock() tool.session = mock_session @@ -3126,7 +3126,7 @@ async def test_load_prompts_with_pagination(): from agent_framework._mcp import MCPTool - tool = MCPTool(name="test_tool") + tool = MCPTool(name="test_tool") # type: ignore[abstract] # Mock the session mock_session = AsyncMock() @@ -3184,7 +3184,7 @@ async def test_load_tools_pagination_with_duplicates(): from agent_framework._mcp import MCPTool - tool = MCPTool(name="test_tool") + tool = MCPTool(name="test_tool") # type: ignore[abstract] # Mock the session mock_session = AsyncMock() @@ -3247,7 +3247,7 @@ async def test_load_prompts_pagination_with_duplicates(): from agent_framework._mcp import MCPTool - tool = MCPTool(name="test_tool") + tool = MCPTool(name="test_tool") # type: ignore[abstract] # Mock the session mock_session = AsyncMock() @@ -3301,7 +3301,7 @@ async def mock_list_prompts(params=None): async def test_load_tools_concurrent_reload_does_not_duplicate_tools_and_preserves_meta(): """Concurrent tool reloads should not duplicate functions or lose tools/list metadata.""" - tool = MCPTool(name="test_tool") + tool = MCPTool(name="test_tool") # type: ignore[abstract] mock_session = AsyncMock() tool.session = mock_session tool.load_tools_flag = True @@ -3333,7 +3333,7 @@ async def mock_list_tools(params: Any = None) -> Any: async def test_load_prompts_concurrent_reload_does_not_duplicate_prompts(): """Concurrent prompt reloads should not duplicate functions.""" - tool = MCPTool(name="test_tool") + tool = MCPTool(name="test_tool") # type: ignore[abstract] mock_session = AsyncMock() tool.session = mock_session tool.load_prompts_flag = True @@ -3367,7 +3367,7 @@ async def test_load_tools_pagination_exception_handling(): from agent_framework._mcp import MCPTool - tool = MCPTool(name="test_tool") + tool = MCPTool(name="test_tool") # type: ignore[abstract] # Mock the session mock_session = AsyncMock() @@ -3392,7 +3392,7 @@ async def test_load_prompts_pagination_exception_handling(): from agent_framework._mcp import MCPTool - tool = MCPTool(name="test_tool") + tool = MCPTool(name="test_tool") # type: ignore[abstract] # Mock the session mock_session = AsyncMock() @@ -3417,7 +3417,7 @@ async def test_load_tools_empty_pagination(): from agent_framework._mcp import MCPTool - tool = MCPTool(name="test_tool") + tool = MCPTool(name="test_tool") # type: ignore[abstract] # Mock the session mock_session = AsyncMock() @@ -3445,7 +3445,7 @@ async def test_load_prompts_empty_pagination(): from agent_framework._mcp import MCPTool - tool = MCPTool(name="test_tool") + tool = MCPTool(name="test_tool") # type: ignore[abstract] # Mock the session mock_session = AsyncMock() @@ -3495,7 +3495,7 @@ async def test_mcp_tool_connection_properly_invalidated_after_closed_resource_er # Mock _exit_stack.aclose to track cleanup calls original_exit_stack = tool._exit_stack - tool._exit_stack.aclose = AsyncMock() + tool._exit_stack.aclose = AsyncMock() # type: ignore[method-assign] # ty: ignore[invalid-assignment] # Mock connect() to avoid trying to start actual process with patch.object(tool, "connect", new_callable=AsyncMock) as mock_connect: @@ -3539,13 +3539,13 @@ async def call_tool_with_error(*args, **kwargs): assert mock_connect.call_count >= 1 mock_connect.assert_called_with(reset=True) # Verify _exit_stack.aclose was called during reconnection - original_exit_stack.aclose.assert_called() + original_exit_stack.aclose.assert_called() # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] # Test Case 2: Reconnection failure # Reset counters call_count = 0 mock_connect.reset_mock() - original_exit_stack.aclose.reset_mock() + original_exit_stack.aclose.reset_mock() # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] # Make call_tool always raise ClosedResourceError async def always_fail(*args, **kwargs): @@ -3594,7 +3594,7 @@ async def test_mcp_tool_get_prompt_reconnection_on_closed_resource_error(): # Mock _exit_stack.aclose to track cleanup calls original_exit_stack = tool._exit_stack - tool._exit_stack.aclose = AsyncMock() + tool._exit_stack.aclose = AsyncMock() # type: ignore[method-assign] # ty: ignore[invalid-assignment] # Mock connect() to avoid trying to start actual process with patch.object(tool, "connect", new_callable=AsyncMock) as mock_connect: @@ -3638,13 +3638,13 @@ async def get_prompt_with_error(*args, **kwargs): assert mock_connect.call_count >= 1 mock_connect.assert_called_with(reset=True) # Verify _exit_stack.aclose was called during reconnection - original_exit_stack.aclose.assert_called() + original_exit_stack.aclose.assert_called() # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] # Test Case 2: Reconnection failure # Reset counters call_count = 0 mock_connect.reset_mock() - original_exit_stack.aclose.reset_mock() + original_exit_stack.aclose.reset_mock() # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] # Make get_prompt always raise ClosedResourceError async def always_fail(*args, **kwargs): @@ -3666,14 +3666,14 @@ async def always_fail(*args, **kwargs): async def test_mcp_tool_call_tool_requires_loaded_tools() -> None: - tool = MCPTool(name="test_tool", load_tools=False) + tool = MCPTool(name="test_tool", load_tools=False) # type: ignore[abstract] with pytest.raises(ToolExecutionException, match="Tools are not loaded"): await tool.call_tool("remote_tool") async def test_mcp_tool_get_prompt_requires_loaded_prompts() -> None: - tool = MCPTool(name="test_tool", load_prompts=False) + tool = MCPTool(name="test_tool", load_prompts=False) # type: ignore[abstract] with pytest.raises(ToolExecutionException, match="Prompts are not loaded"): await tool.get_prompt("remote_prompt") @@ -3682,7 +3682,7 @@ async def test_mcp_tool_get_prompt_requires_loaded_prompts() -> None: async def test_mcp_tool_call_tool_raises_after_reconnection_still_fails() -> None: from anyio.streams.memory import ClosedResourceError - tool = MCPTool(name="test_tool", load_tools=True) + tool = MCPTool(name="test_tool", load_tools=True) # type: ignore[abstract] tool.session = Mock(call_tool=AsyncMock(side_effect=[ClosedResourceError(), ClosedResourceError()])) with ( @@ -3699,7 +3699,7 @@ async def test_mcp_tool_call_tool_raises_after_reconnection_still_fails() -> Non async def test_mcp_tool_get_prompt_raises_after_reconnection_still_fails() -> None: from anyio.streams.memory import ClosedResourceError - tool = MCPTool(name="test_tool", load_prompts=True) + tool = MCPTool(name="test_tool", load_prompts=True) # type: ignore[abstract] tool.session = Mock(get_prompt=AsyncMock(side_effect=[ClosedResourceError(), ClosedResourceError()])) with ( @@ -3714,7 +3714,7 @@ async def test_mcp_tool_get_prompt_raises_after_reconnection_still_fails() -> No async def test_mcp_tool_wraps_unexpected_call_tool_and_get_prompt_errors() -> None: - tool = MCPTool(name="test_tool", load_tools=True, load_prompts=True) + tool = MCPTool(name="test_tool", load_tools=True, load_prompts=True) # type: ignore[abstract] tool.session = Mock() tool.session.call_tool = AsyncMock(side_effect=RuntimeError("tool boom")) tool.session.get_prompt = AsyncMock(side_effect=RuntimeError("prompt boom")) @@ -3750,11 +3750,11 @@ def __init__(self) -> None: self.closed_cleanly = False async def __aenter__(self): - self.enter_task = asyncio.current_task() + self.enter_task = asyncio.current_task() # type: ignore[assignment] return (Mock(), Mock()) async def __aexit__(self, exc_type, exc, tb): - self.exit_task = asyncio.current_task() + self.exit_task = asyncio.current_task() # type: ignore[assignment] if self.exit_task is not self.enter_task: raise RuntimeError("Attempted to exit cancel scope in a different task than it was entered in") self.closed_cleanly = True @@ -3802,11 +3802,11 @@ def __init__(self) -> None: self.closed_cleanly = False async def __aenter__(self): - self.enter_task = asyncio.current_task() + self.enter_task = asyncio.current_task() # type: ignore[assignment] return (Mock(), Mock()) async def __aexit__(self, exc_type, exc, tb): - self.exit_task = asyncio.current_task() + self.exit_task = asyncio.current_task() # type: ignore[assignment] if self.exit_task is not self.enter_task: raise RuntimeError("Attempted to exit cancel scope in a different task than it was entered in") self.closed_cleanly = True @@ -4137,7 +4137,7 @@ async def test_connect_handles_set_logging_level_exception(): async def test_connect_reinitializes_existing_session_and_loads_tools_and_prompts() -> None: - tool = MCPTool(name="test_tool", load_tools=True, load_prompts=True) + tool = MCPTool(name="test_tool", load_tools=True, load_prompts=True) # type: ignore[abstract] tool.is_connected = True tool.session = Mock() tool.session._request_id = 0 @@ -4158,7 +4158,7 @@ async def test_connect_reinitializes_existing_session_and_loads_tools_and_prompt async def test_connect_skips_tools_and_prompts_when_server_does_not_advertise_capabilities() -> None: - tool = MCPTool(name="test_tool", load_tools=True, load_prompts=True) + tool = MCPTool(name="test_tool", load_tools=True, load_prompts=True) # type: ignore[abstract] tool.is_connected = True tool.session = Mock() tool.session._request_id = 0 @@ -4189,7 +4189,7 @@ async def test_connect_skips_tools_and_prompts_when_server_does_not_advertise_ca async def test_connect_treats_missing_capabilities_as_unsupported() -> None: - tool = MCPTool(name="test_tool", load_tools=True, load_prompts=True) + tool = MCPTool(name="test_tool", load_tools=True, load_prompts=True) # type: ignore[abstract] tool.is_connected = True tool.session = Mock() tool.session._request_id = 0 @@ -4208,7 +4208,7 @@ async def test_connect_treats_missing_capabilities_as_unsupported() -> None: async def test_connect_sets_logging_level_when_server_advertises_logging() -> None: - tool = MCPTool(name="test_tool", load_tools=False, load_prompts=False) + tool = MCPTool(name="test_tool", load_tools=False, load_prompts=False) # type: ignore[abstract] tool.is_connected = True tool.session = Mock() tool.session._request_id = 0 @@ -4229,7 +4229,7 @@ async def test_connect_sets_logging_level_when_server_advertises_logging() -> No async def test_ensure_connected_skips_future_pings_when_ping_is_not_available() -> None: - tool = MCPTool(name="test_tool") + tool = MCPTool(name="test_tool") # type: ignore[abstract] tool.session = Mock( send_ping=AsyncMock( side_effect=McpError(types.ErrorData(code=-32601, message="Method 'ping' is not available.")) @@ -4246,7 +4246,7 @@ async def test_ensure_connected_skips_future_pings_when_ping_is_not_available() async def test_ensure_connected_reconnects_on_failed_ping() -> None: - tool = MCPTool(name="test_tool") + tool = MCPTool(name="test_tool") # type: ignore[abstract] tool.session = Mock(send_ping=AsyncMock(side_effect=RuntimeError("closed"))) with patch.object(tool, "_reconnect_without_loading", AsyncMock()) as mock_reconnect: @@ -4256,7 +4256,7 @@ async def test_ensure_connected_reconnects_on_failed_ping() -> None: async def test_ensure_connected_wraps_reconnect_failure() -> None: - tool = MCPTool(name="test_tool") + tool = MCPTool(name="test_tool") # type: ignore[abstract] tool.session = Mock(send_ping=AsyncMock(side_effect=RuntimeError("closed"))) with ( @@ -4269,7 +4269,7 @@ async def test_ensure_connected_wraps_reconnect_failure() -> None: async def test_load_tools_reconnects_on_closed_resource_when_ping_is_unavailable() -> None: from anyio import ClosedResourceError - tool = MCPTool(name="test_tool", load_tools=True) + tool = MCPTool(name="test_tool", load_tools=True) # type: ignore[abstract] tool._ping_available = False first_session = Mock() @@ -4298,7 +4298,7 @@ async def reconnect() -> None: async def test_load_prompts_reconnects_on_closed_resource_when_ping_is_unavailable() -> None: from anyio import ClosedResourceError - tool = MCPTool(name="test_tool", load_prompts=True) + tool = MCPTool(name="test_tool", load_prompts=True) # type: ignore[abstract] tool._ping_available = False first_session = Mock() @@ -4333,7 +4333,7 @@ async def test_mcp_tool_filters_framework_kwargs(): """ class TestServer(MCPTool): - async def connect(self): + async def connect(self): # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] self.session = Mock(spec=ClientSession) self.session.list_tools = AsyncMock( return_value=types.ListToolsResult( @@ -4356,7 +4356,7 @@ async def connect(self): ) def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: - return None + return None # type: ignore[return-value] # pyrefly: ignore[bad-return] # ty: ignore[invalid-return-type] # Create a mock Pydantic model class to use as response_format class MockResponseFormat(BaseModel): @@ -4385,8 +4385,8 @@ class MockResponseFormat(BaseModel): ) # Verify call_tool was called with only the valid argument - server.session.call_tool.assert_called_once() - call_args = server.session.call_tool.call_args + server.session.call_tool.assert_called_once() # type: ignore[union-attr] # ty: ignore[unresolved-attribute] + call_args = server.session.call_tool.call_args # type: ignore[union-attr] # ty: ignore[unresolved-attribute] # Check that the arguments dict only contains 'param' and none of the framework kwargs arguments = call_args.kwargs.get("arguments", call_args[1] if len(call_args) > 1 else {}) @@ -4417,7 +4417,7 @@ async def test_mcp_tool_call_tool_otel_meta(use_span, expect_traceparent, span_e from opentelemetry import trace class TestServer(MCPTool): - async def connect(self): + async def connect(self): # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] self.session = Mock(spec=ClientSession) self.session.list_tools = AsyncMock( return_value=types.ListToolsResult( @@ -4439,7 +4439,7 @@ async def connect(self): ) def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: - return None + return None # type: ignore[return-value] # pyrefly: ignore[bad-return] # ty: ignore[invalid-return-type] server = TestServer(name="test_server") async with server: @@ -4455,7 +4455,7 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: with trace.use_span(trace.NonRecordingSpan(trace.INVALID_SPAN_CONTEXT)): await server.call_tool("test_tool", param="test_value") - meta = server.session.call_tool.call_args.kwargs.get("meta") + meta = server.session.call_tool.call_args.kwargs.get("meta") # type: ignore[union-attr] # ty: ignore[unresolved-attribute] if expect_traceparent: # When a valid span is active, we expect some propagation fields to be injected, # but we do not assume any specific header name to keep this test propagator-agnostic. @@ -4478,7 +4478,7 @@ async def test_mcp_tool_call_tool_forwards_tool_list_meta(): } class TestServer(MCPTool): - async def connect(self): + async def connect(self): # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] self.session = Mock(spec=ClientSession) self.session.list_tools = AsyncMock( return_value=types.ListToolsResult( @@ -4502,7 +4502,7 @@ async def connect(self): self.session.list_prompts = AsyncMock(return_value=types.ListPromptsResult(prompts=[])) def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: - return None + return None # type: ignore[return-value] # pyrefly: ignore[bad-return] # ty: ignore[invalid-return-type] server = TestServer(name="test_server") async with server: @@ -4512,7 +4512,7 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: with trace.use_span(trace.NonRecordingSpan(trace.INVALID_SPAN_CONTEXT)): await server.call_tool("WorkIQSharePoint.readSmallBinaryFile", fileId="file-1") - assert server.session.call_tool.call_args.kwargs["meta"] == tool_meta + assert server.session.call_tool.call_args.kwargs["meta"] == tool_meta # type: ignore[union-attr] # ty: ignore[unresolved-attribute] async def test_mcp_tool_call_tool_user_meta_merges_with_tool_list_meta(): @@ -4523,7 +4523,7 @@ async def test_mcp_tool_call_tool_user_meta_merges_with_tool_list_meta(): user_meta = {"from_user": "user-value", "shared": "user-value"} class TestServer(MCPTool): - async def connect(self) -> None: + async def connect(self) -> None: # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] self.session = Mock(spec=ClientSession) self.session.list_tools = AsyncMock( return_value=types.ListToolsResult( @@ -4542,7 +4542,7 @@ async def connect(self) -> None: ) def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: - return None + return None # type: ignore[return-value] # pyrefly: ignore[bad-return] # ty: ignore[invalid-return-type] server = TestServer(name="test_server") async with server: @@ -4551,7 +4551,7 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: with trace.use_span(trace.NonRecordingSpan(trace.INVALID_SPAN_CONTEXT)): await server.call_tool("test_tool", param="test_value", _meta=user_meta) - call_kwargs = server.session.call_tool.call_args.kwargs + call_kwargs = server.session.call_tool.call_args.kwargs # type: ignore[union-attr] # ty: ignore[unresolved-attribute] assert call_kwargs["arguments"] == {"param": "test_value"} assert call_kwargs["meta"] == { "from_tool": "tool-value", @@ -4580,7 +4580,7 @@ async def test_mcp_streamable_http_tool_hook_not_duplicated_on_repeated_get_mcp_ assert len(hooks) == 1, f"Expected exactly one hook, got {len(hooks)}" finally: if getattr(tool, "_httpx_client", None) is not None: - await tool._httpx_client.aclose() + await tool._httpx_client.aclose() # type: ignore[union-attr] # endregion @@ -4597,7 +4597,7 @@ async def test_mcp_streamable_http_tool_header_provider_injects_headers(): """ class _TestServer(MCPStreamableHTTPTool): - async def connect(self): + async def connect(self): # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] self.session = Mock(spec=ClientSession) self.session.list_tools = AsyncMock( return_value=types.ListToolsResult( @@ -4620,7 +4620,7 @@ async def connect(self): self.session.send_ping = AsyncMock() self.is_connected = True - def get_mcp_client(self): + def get_mcp_client(self): # pyrefly: ignore[bad-override] return None def provider(kwargs): @@ -4638,7 +4638,7 @@ def provider(kwargs): await server.call_tool("greet", name="Alice", some_token="my-secret") # Verify the MCP session.call_tool was called - server.session.call_tool.assert_called_once() + server.session.call_tool.assert_called_once() # type: ignore[union-attr] # ty: ignore[unresolved-attribute] async def test_mcp_streamable_http_tool_header_provider_sets_contextvar(): @@ -4657,7 +4657,7 @@ async def spy_call_tool(self, tool_name, **kwargs): return await original_call_tool(self, tool_name, **kwargs) class _TestServer(MCPStreamableHTTPTool): - async def connect(self): + async def connect(self): # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] self.session = Mock(spec=ClientSession) self.session.list_tools = AsyncMock( return_value=types.ListToolsResult( @@ -4676,7 +4676,7 @@ async def connect(self): self.session.send_ping = AsyncMock() self.is_connected = True - def get_mcp_client(self): + def get_mcp_client(self): # pyrefly: ignore[bad-override] return None server = _TestServer( @@ -4699,7 +4699,7 @@ async def test_mcp_streamable_http_tool_header_provider_contextvar_reset_after_c from agent_framework._mcp import _mcp_call_headers class _TestServer(MCPStreamableHTTPTool): - async def connect(self): + async def connect(self): # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] self.session = Mock(spec=ClientSession) self.session.list_tools = AsyncMock( return_value=types.ListToolsResult( @@ -4718,7 +4718,7 @@ async def connect(self): self.session.send_ping = AsyncMock() self.is_connected = True - def get_mcp_client(self): + def get_mcp_client(self): # pyrefly: ignore[bad-override] return None server = _TestServer( @@ -4739,7 +4739,7 @@ async def test_mcp_streamable_http_tool_without_header_provider(): """Test that call_tool works normally when no header_provider is configured.""" class _TestServer(MCPStreamableHTTPTool): - async def connect(self): + async def connect(self): # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] self.session = Mock(spec=ClientSession) self.session.list_tools = AsyncMock( return_value=types.ListToolsResult( @@ -4758,7 +4758,7 @@ async def connect(self): self.session.send_ping = AsyncMock() self.is_connected = True - def get_mcp_client(self): + def get_mcp_client(self): # pyrefly: ignore[bad-override] return None server = _TestServer( @@ -4768,7 +4768,7 @@ def get_mcp_client(self): async with server: await server.load_tools() await server.call_tool("greet", name="Alice") - server.session.call_tool.assert_called_once() + server.session.call_tool.assert_called_once() # type: ignore[union-attr] # ty: ignore[unresolved-attribute] # Without header_provider, call_tool should delegate directly to MCPTool assert server._header_provider is None @@ -4810,7 +4810,7 @@ async def test_mcp_streamable_http_tool_header_provider_with_httpx_event_hook(): finally: # Ensure any created httpx client is properly closed if getattr(tool, "_httpx_client", None) is not None: - await tool._httpx_client.aclose() + await tool._httpx_client.aclose() # type: ignore[union-attr] async def test_mcp_streamable_http_tool_header_provider_skips_cross_origin_redirect(): @@ -4846,7 +4846,7 @@ async def test_mcp_streamable_http_tool_header_provider_skips_cross_origin_redir _mcp_call_headers.reset(token) finally: if getattr(tool, "_httpx_client", None) is not None: - await tool._httpx_client.aclose() + await tool._httpx_client.aclose() # type: ignore[union-attr] async def test_mcp_streamable_http_tool_header_provider_with_user_httpx_client(): @@ -4905,7 +4905,7 @@ async def spy_call_tool(self, tool_name, **kwargs): return result class _TestServer(MCPStreamableHTTPTool): - async def connect(self): + async def connect(self): # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] self.session = Mock(spec=ClientSession) self.session.list_tools = AsyncMock( return_value=types.ListToolsResult( @@ -4928,7 +4928,7 @@ async def connect(self): self.session.send_ping = AsyncMock() self.is_connected = True - def get_mcp_client(self): + def get_mcp_client(self): # pyrefly: ignore[bad-override] return None provider_received: list[dict] = [] @@ -4965,8 +4965,8 @@ def provider(kwargs): assert provider_received[0]["some_token"] == "my-secret" # Verify session.call_tool was called with the tool arguments (not the runtime kwargs) - server.session.call_tool.assert_called_once() - call_args = server.session.call_tool.call_args + server.session.call_tool.assert_called_once() # type: ignore[union-attr] # ty: ignore[unresolved-attribute] + call_args = server.session.call_tool.call_args # type: ignore[union-attr] # ty: ignore[unresolved-attribute] assert call_args.kwargs.get("arguments", {}).get("name") == "Alice" @@ -4992,7 +4992,7 @@ def _make_task_snapshot( now = _utc_now() return types.GetTaskResult( taskId=task_id, - status=status, # type: ignore[arg-type] + status=status, # type: ignore[arg-type] # ty: ignore[invalid-argument-type] statusMessage=status_message, createdAt=now, lastUpdatedAt=now, @@ -5030,7 +5030,7 @@ def _make_task_tool( ) -> MCPTool: from agent_framework import MCPTaskOptions - tool = MCPTool( + tool = MCPTool( # type: ignore[abstract] name="lro", task_options=task_options if task_options is not None else MCPTaskOptions(), ) @@ -5055,7 +5055,7 @@ def _send_request_dispatcher(*responses_by_method: tuple[str, Any]) -> Any: async def _dispatch(request: Any, _result_type: Any, *_args: Any, **_kw: Any) -> Any: method = getattr(request.root, "method", None) or getattr(request, "method", None) - queue = queues.get(method) + queue = queues.get(method) # type: ignore[arg-type, call-overload] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] if not queue: raise AssertionError(f"No mocked send_request response for method '{method}'.") item = queue.pop(0) @@ -5088,7 +5088,7 @@ async def test_task_options_rejects_non_positive_default_ttl() -> None: async def test_load_tools_captures_task_support() -> None: - tool = MCPTool(name="lro") + tool = MCPTool(name="lro") # type: ignore[abstract] tool.session = AsyncMock() tool.load_tools_flag = True @@ -5120,7 +5120,7 @@ async def test_call_tool_routes_required_through_task_lifecycle(monkeypatch: pyt monkeypatch.setattr(_mcp_module, "_MCP_TASK_MIN_POLL_INTERVAL", _mcp_module.timedelta(milliseconds=1)) tool = _make_task_tool() - tool.session.send_request = AsyncMock( # type: ignore[union-attr] + tool.session.send_request = AsyncMock( # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] side_effect=_send_request_dispatcher( ("tools/call", _make_create_task_result()), ("tasks/get", _make_task_snapshot(status="working")), @@ -5133,7 +5133,7 @@ async def test_call_tool_routes_required_through_task_lifecycle(monkeypatch: pyt assert _mcp_result_to_text(result) == "hello task" # Plain session.call_tool must NOT be used for required tools. - tool.session.call_tool.assert_not_called() # type: ignore[union-attr] + tool.session.call_tool.assert_not_called() # type: ignore[union-attr] # ty: ignore[unresolved-attribute] async def test_call_tool_as_task_default_ttl_propagates() -> None: @@ -5156,7 +5156,7 @@ async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> An return _make_payload("ok") raise AssertionError(method) - tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[union-attr] + tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] await tool.call_tool("slow_op") @@ -5184,7 +5184,7 @@ async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> An return _make_payload("ok") raise AssertionError(method) - tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[union-attr] + tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] await tool.call_tool("slow_op") @@ -5197,10 +5197,10 @@ async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> An async def test_call_tool_skips_task_path_for_optional_and_forbidden() -> None: for support in ("optional", "forbidden", None): tool = _make_task_tool(task_support=support) - tool.session.call_tool = AsyncMock( # type: ignore[union-attr] + tool.session.call_tool = AsyncMock( # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] return_value=types.CallToolResult(content=[types.TextContent(type="text", text="plain")]) ) - tool.session.send_request = AsyncMock(side_effect=AssertionError("task path should not be used")) # type: ignore[union-attr] + tool.session.send_request = AsyncMock(side_effect=AssertionError("task path should not be used")) # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] result = await tool.call_tool("slow_op") assert _mcp_result_to_text(result) == "plain" @@ -5208,7 +5208,7 @@ async def test_call_tool_skips_task_path_for_optional_and_forbidden() -> None: async def test_call_tool_as_task_cancelled_status_raises() -> None: tool = _make_task_tool() - tool.session.send_request = AsyncMock( # type: ignore[union-attr] + tool.session.send_request = AsyncMock( # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] side_effect=_send_request_dispatcher( ("tools/call", _make_create_task_result()), ("tasks/get", _make_task_snapshot(status="cancelled", status_message="server stop")), @@ -5221,7 +5221,7 @@ async def test_call_tool_as_task_cancelled_status_raises() -> None: async def test_call_tool_as_task_failed_status_raises() -> None: tool = _make_task_tool() - tool.session.send_request = AsyncMock( # type: ignore[union-attr] + tool.session.send_request = AsyncMock( # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] side_effect=_send_request_dispatcher( ("tools/call", _make_create_task_result()), ("tasks/get", _make_task_snapshot(status="failed", status_message="boom")), @@ -5234,7 +5234,7 @@ async def test_call_tool_as_task_failed_status_raises() -> None: async def test_call_tool_as_task_input_required_raises() -> None: tool = _make_task_tool() - tool.session.send_request = AsyncMock( # type: ignore[union-attr] + tool.session.send_request = AsyncMock( # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] side_effect=_send_request_dispatcher( ("tools/call", _make_create_task_result()), ("tasks/get", _make_task_snapshot(status="input_required", status_message="need more")), @@ -5247,7 +5247,7 @@ async def test_call_tool_as_task_input_required_raises() -> None: async def test_call_tool_as_task_payload_iserror_raises() -> None: tool = _make_task_tool() - tool.session.send_request = AsyncMock( # type: ignore[union-attr] + tool.session.send_request = AsyncMock( # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] side_effect=_send_request_dispatcher( ("tools/call", _make_create_task_result()), ("tasks/get", _make_task_snapshot(status="completed")), @@ -5262,7 +5262,7 @@ async def test_call_tool_as_task_payload_iserror_raises() -> None: async def test_call_tool_as_task_malformed_payload_raises() -> None: tool = _make_task_tool() bad_payload = types.GetTaskPayloadResult.model_validate({"random": "stuff"}) - tool.session.send_request = AsyncMock( # type: ignore[union-attr] + tool.session.send_request = AsyncMock( # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] side_effect=_send_request_dispatcher( ("tools/call", _make_create_task_result(task_id="abc")), ("tasks/get", _make_task_snapshot(task_id="abc", status="completed")), @@ -5276,25 +5276,25 @@ async def test_call_tool_as_task_malformed_payload_raises() -> None: async def test_call_tool_as_task_method_not_found_falls_back() -> None: tool = _make_task_tool() - tool.session.send_request = AsyncMock( # type: ignore[union-attr] + tool.session.send_request = AsyncMock( # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] side_effect=McpError(types.ErrorData(code=types.METHOD_NOT_FOUND, message="no tasks here")) ) - tool.session.call_tool = AsyncMock( # type: ignore[union-attr] + tool.session.call_tool = AsyncMock( # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] return_value=types.CallToolResult(content=[types.TextContent(type="text", text="fell back")]) ) result = await tool.call_tool("slow_op") assert _mcp_result_to_text(result) == "fell back" - tool.session.call_tool.assert_awaited_once() # type: ignore[union-attr] + tool.session.call_tool.assert_awaited_once() # type: ignore[union-attr] # ty: ignore[unresolved-attribute] async def test_call_tool_as_task_invalid_params_falls_back() -> None: tool = _make_task_tool() - tool.session.send_request = AsyncMock( # type: ignore[union-attr] + tool.session.send_request = AsyncMock( # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] side_effect=McpError(types.ErrorData(code=types.INVALID_PARAMS, message="unknown field")) ) - tool.session.call_tool = AsyncMock( # type: ignore[union-attr] + tool.session.call_tool = AsyncMock( # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] return_value=types.CallToolResult(content=[types.TextContent(type="text", text="plain ok")]) ) @@ -5312,13 +5312,13 @@ async def test_call_tool_as_task_legacy_calltoolresult_response_used_directly() }) tool = _make_task_tool() - tool.session.send_request = AsyncMock(return_value=legacy_payload) # type: ignore[union-attr] + tool.session.send_request = AsyncMock(return_value=legacy_payload) # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] result = await tool.call_tool("slow_op") assert _mcp_result_to_text(result) == "legacy ok" # Polling must not occur: a single tools/call was enough. - assert tool.session.send_request.call_count == 1 # type: ignore[union-attr] + assert tool.session.send_request.call_count == 1 # type: ignore[union-attr] # ty: ignore[unresolved-attribute] async def test_call_tool_as_task_poll_interval_is_clamped(monkeypatch: pytest.MonkeyPatch) -> None: @@ -5335,7 +5335,7 @@ async def fake_sleep(delay: float) -> None: monkeypatch.setattr(_mcp_module.asyncio, "sleep", fake_sleep) tool = _make_task_tool() - tool.session.send_request = AsyncMock( # type: ignore[union-attr] + tool.session.send_request = AsyncMock( # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] side_effect=_send_request_dispatcher( ("tools/call", _make_create_task_result()), ("tasks/get", _make_task_snapshot(status="working", poll_interval_ms=50)), # below 500ms min @@ -5382,10 +5382,10 @@ async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> An return _make_task_snapshot(status="working") if method == "tasks/cancel": cancel_seen.set() - return types.CancelTaskResult() + return types.CancelTaskResult() # type: ignore[call-arg] # pyrefly: ignore[missing-argument] # ty: ignore[missing-argument] raise AssertionError(method) - tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[union-attr] + tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] task = asyncio.create_task(tool.call_tool("slow_op")) await asyncio.wait_for(create_seen.wait(), timeout=1.0) @@ -5429,10 +5429,10 @@ async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> An return _make_task_snapshot(status="working") if method == "tasks/cancel": cancel_called = True - return types.CancelTaskResult() + return types.CancelTaskResult() # type: ignore[call-arg] # pyrefly: ignore[missing-argument] # ty: ignore[missing-argument] raise AssertionError(method) - tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[union-attr] + tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] task = asyncio.create_task(tool.call_tool("slow_op")) await asyncio.wait_for(create_seen.wait(), timeout=1.0) @@ -5472,7 +5472,7 @@ async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> An return _make_payload("recovered") raise AssertionError(method) - tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[union-attr] + tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] reconnect_calls = 0 @@ -5489,8 +5489,8 @@ async def fake_connect(reset: bool = False) -> None: # Critically, tools/call must NOT be re-issued after task_id is known. assert ( sum( - 1 - for c in tool.session.send_request.await_args_list # type: ignore[union-attr] + 1 # type: ignore[misc] + for c in tool.session.send_request.await_args_list # type: ignore[union-attr] # ty: ignore[unresolved-attribute] if c.args[0].root.method == "tools/call" ) == 1 @@ -5516,7 +5516,7 @@ async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> An raise ClosedResourceError raise AssertionError(method) - tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[union-attr] + tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] with ( patch.object(MCPTool, "connect", new=AsyncMock(return_value=None)), @@ -5542,7 +5542,7 @@ async def fake_send(_request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> A send_calls += 1 raise ClosedResourceError - tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[union-attr] + tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] reconnect_mock = AsyncMock(return_value=None) with ( @@ -5578,7 +5578,7 @@ async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> An return _make_payload("fetched after reconnect") raise AssertionError(method) - tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[union-attr] + tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] reconnect_calls = 0 @@ -5613,10 +5613,10 @@ async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> An raise ClosedResourceError if method == "tasks/cancel": cancel_called = True - return types.CancelTaskResult() + return types.CancelTaskResult() # type: ignore[call-arg] # pyrefly: ignore[missing-argument] # ty: ignore[missing-argument] raise AssertionError(method) - tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[union-attr] + tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] with ( patch.object(MCPTool, "connect", new=AsyncMock(return_value=None)), @@ -5637,14 +5637,14 @@ async def test_call_tool_as_task_create_unparseable_success_raises() -> None: unparseable = types.Result.model_validate({"foo": "bar"}) tool = _make_task_tool() - tool.session.send_request = AsyncMock(return_value=unparseable) # type: ignore[union-attr] - tool.session.call_tool = AsyncMock(return_value=types.CallToolResult(content=[])) # type: ignore[union-attr] + tool.session.send_request = AsyncMock(return_value=unparseable) # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] + tool.session.call_tool = AsyncMock(return_value=types.CallToolResult(content=[])) # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] with pytest.raises(ToolExecutionException, match="unparseable response"): await tool.call_tool("slow_op") # Critically: no plain tools/call fallback (would risk double execution). - tool.session.call_tool.assert_not_called() # type: ignore[union-attr] + tool.session.call_tool.assert_not_called() # type: ignore[union-attr] # ty: ignore[unresolved-attribute] async def test_call_tool_as_task_max_wait_exceeded_raises_and_cancels(monkeypatch: pytest.MonkeyPatch) -> None: @@ -5666,10 +5666,10 @@ async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> An return _make_task_snapshot(task_id="mw", status="working") if method == "tasks/cancel": cancel_called = True - return types.CancelTaskResult() + return types.CancelTaskResult() # type: ignore[call-arg] # pyrefly: ignore[missing-argument] # ty: ignore[missing-argument] raise AssertionError(method) - tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[union-attr] + tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] with pytest.raises(ToolExecutionException, match="exceeded max_task_wait"): await tool.call_tool("slow_op") @@ -5707,10 +5707,10 @@ async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> An return _make_task_snapshot(task_id="mw2", status="working") if method == "tasks/cancel": cancel_called = True - return types.CancelTaskResult() + return types.CancelTaskResult() # type: ignore[call-arg] # pyrefly: ignore[missing-argument] # ty: ignore[missing-argument] raise AssertionError(method) - tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[union-attr] + tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] with pytest.raises(ToolExecutionException, match="exceeded max_task_wait"): await tool.call_tool("slow_op") @@ -5749,10 +5749,10 @@ async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> An return _make_payload("recovered after transient") if method == "tasks/cancel": cancel_called = True - return types.CancelTaskResult() + return types.CancelTaskResult() # type: ignore[call-arg] # pyrefly: ignore[missing-argument] # ty: ignore[missing-argument] raise AssertionError(method) - tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[union-attr] + tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] result = await tool.call_tool("slow_op") assert _mcp_result_to_text(result) == "recovered after transient" @@ -5784,10 +5784,10 @@ async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> An raise McpError(types.ErrorData(code=types.INVALID_PARAMS, message="bad task id")) if method == "tasks/cancel": cancel_called = True - return types.CancelTaskResult() + return types.CancelTaskResult() # type: ignore[call-arg] # pyrefly: ignore[missing-argument] # ty: ignore[missing-argument] raise AssertionError(method) - tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[union-attr] + tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] with pytest.raises(ToolExecutionException, match="bad task id"): await tool.call_tool("slow_op") @@ -5822,10 +5822,10 @@ async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> An return malformed if method == "tasks/cancel": cancel_called = True - return types.CancelTaskResult() + return types.CancelTaskResult() # type: ignore[call-arg] # pyrefly: ignore[missing-argument] # ty: ignore[missing-argument] raise AssertionError(method) - tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[union-attr] + tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] with pytest.raises(ToolExecutionException, match="malformed tasks/get"): await tool.call_tool("slow_op") @@ -5855,10 +5855,10 @@ async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> An return _make_task_snapshot(task_id="f1", status="failed", status_message="boom") if method == "tasks/cancel": cancel_called = True - return types.CancelTaskResult() + return types.CancelTaskResult() # type: ignore[call-arg] # pyrefly: ignore[missing-argument] # ty: ignore[missing-argument] raise AssertionError(method) - tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[union-attr] + tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] with pytest.raises(ToolExecutionException, match="task failed: boom"): await tool.call_tool("slow_op") @@ -5881,7 +5881,7 @@ async def test_try_cancel_task_logs_warning_on_timeout( async def hang(*_a: Any, **_kw: Any) -> Any: await asyncio.sleep(10.0) - tool.session.send_request = AsyncMock(side_effect=hang) # type: ignore[union-attr] + tool.session.send_request = AsyncMock(side_effect=hang) # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] with caplog.at_level(logging.WARNING, logger=_mcp_module.logger.name): await tool._try_cancel_task("hang-1") @@ -5896,7 +5896,7 @@ async def test_mcp_task_options_is_frozen() -> None: opts = MCPTaskOptions() with pytest.raises(FrozenInstanceError): - opts.default_ttl = timedelta(seconds=5) # type: ignore[misc] + opts.default_ttl = timedelta(seconds=5) # type: ignore[misc] # ty: ignore[invalid-assignment] async def test_mcp_task_options_max_task_wait_rejects_non_positive() -> None: @@ -5925,10 +5925,10 @@ async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> An raise McpError(types.ErrorData(code=types.INTERNAL_ERROR, message="payload vanished")) if method == "tasks/cancel": cancel_called = True - return types.CancelTaskResult() + return types.CancelTaskResult() # type: ignore[call-arg] # pyrefly: ignore[missing-argument] # ty: ignore[missing-argument] raise AssertionError(method) - tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[union-attr] + tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] with pytest.raises(ToolExecutionException, match="payload vanished"): await tool.call_tool("slow_op") @@ -5966,10 +5966,10 @@ async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> An return _make_payload("ok") if method == "tasks/cancel": cancel_called = True - return types.CancelTaskResult() + return types.CancelTaskResult() # type: ignore[call-arg] # pyrefly: ignore[missing-argument] # ty: ignore[missing-argument] raise AssertionError(method) - tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[union-attr] + tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] with pytest.raises(asyncio.TimeoutError): await tool.call_tool("slow_op") @@ -6011,10 +6011,10 @@ async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> An return _make_payload("ok") if method == "tasks/cancel": cancel_called = True - return types.CancelTaskResult() + return types.CancelTaskResult() # type: ignore[call-arg] # pyrefly: ignore[missing-argument] # ty: ignore[missing-argument] raise AssertionError(method) - tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[union-attr] + tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] with pytest.raises(asyncio.TimeoutError, match="inner parser timeout"): await tool.call_tool("slow_op") @@ -6038,10 +6038,10 @@ async def fake_send(request: Any, _result_type: Any, *_a: Any, **_kw: Any) -> An # Suggest a 5s poll interval (gets clamped to MAX=5s); wait_for must cut through it. return _make_task_snapshot(task_id="ds", status="working", poll_interval_ms=5000) if method == "tasks/cancel": - return types.CancelTaskResult() + return types.CancelTaskResult() # type: ignore[call-arg] # pyrefly: ignore[missing-argument] # ty: ignore[missing-argument] raise AssertionError(method) - tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[union-attr] + tool.session.send_request = AsyncMock(side_effect=fake_send) # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] loop = asyncio.get_running_loop() started = loop.time() @@ -6103,7 +6103,7 @@ def test_normalize_additional_tool_argument_names_mapping_with_string_values() - def test_prepare_call_kwargs_strips_undeclared_arguments() -> None: - server = MCPTool(name="test_server") + server = MCPTool(name="test_server") # type: ignore[abstract] server._tool_param_names_by_name = {"test_tool": {"param"}} filtered, meta = server._prepare_call_kwargs( @@ -6116,7 +6116,7 @@ def test_prepare_call_kwargs_strips_undeclared_arguments() -> None: def test_prepare_call_kwargs_global_extras_allowed() -> None: - server = MCPTool(name="test_server", additional_tool_argument_names=["conversation_id"]) + server = MCPTool(name="test_server", additional_tool_argument_names=["conversation_id"]) # type: ignore[abstract] server._tool_param_names_by_name = {"test_tool": {"param"}} filtered, _ = server._prepare_call_kwargs( @@ -6128,7 +6128,7 @@ def test_prepare_call_kwargs_global_extras_allowed() -> None: def test_prepare_call_kwargs_per_tool_and_global_extras() -> None: - server = MCPTool( + server = MCPTool( # type: ignore[abstract] name="test_server", additional_tool_argument_names={"*": ["conversation_id"], "test_tool": ["custom"]}, ) @@ -6152,7 +6152,7 @@ def test_prepare_call_kwargs_denylist_guards_server_declared_names() -> None: # The denylist is a safety net for framework-named params a server *declares* in its # schema: they are dropped so internal objects never leak. Names explicitly opted in # via extras always win. - server = MCPTool(name="test_server", additional_tool_argument_names=["conversation_id"]) + server = MCPTool(name="test_server", additional_tool_argument_names=["conversation_id"]) # type: ignore[abstract] server._tool_param_names_by_name = {"test_tool": {"param", "thread"}} filtered, _ = server._prepare_call_kwargs( @@ -6166,7 +6166,7 @@ def test_prepare_call_kwargs_denylist_guards_server_declared_names() -> None: def test_prepare_call_kwargs_extras_override_denylist() -> None: # Opting a denylisted framework name back in via extras takes precedence over the # denylist safety net. "thread" is on the framework denylist, but an explicit extra wins. - server = MCPTool(name="test_server", additional_tool_argument_names=["thread"]) + server = MCPTool(name="test_server", additional_tool_argument_names=["thread"]) # type: ignore[abstract] server._tool_param_names_by_name = {"test_tool": {"param"}} sentinel = object() @@ -6180,7 +6180,7 @@ def test_prepare_call_kwargs_extras_override_denylist() -> None: def test_prepare_call_kwargs_zero_arg_tool_passes_no_arguments() -> None: - server = MCPTool(name="test_server") + server = MCPTool(name="test_server") # type: ignore[abstract] server._tool_param_names_by_name = {"test_tool": set()} filtered, _ = server._prepare_call_kwargs( @@ -6191,7 +6191,7 @@ def test_prepare_call_kwargs_zero_arg_tool_passes_no_arguments() -> None: def test_prepare_call_kwargs_unknown_tool_passes_only_global_extras() -> None: - server = MCPTool(name="test_server", additional_tool_argument_names=["conversation_id"]) + server = MCPTool(name="test_server", additional_tool_argument_names=["conversation_id"]) # type: ignore[abstract] # No entry in _tool_param_names_by_name for this tool name. filtered, _ = server._prepare_call_kwargs( @@ -6202,7 +6202,7 @@ def test_prepare_call_kwargs_unknown_tool_passes_only_global_extras() -> None: def test_prepare_call_kwargs_extracts_meta() -> None: - server = MCPTool(name="test_server") + server = MCPTool(name="test_server") # type: ignore[abstract] server._tool_param_names_by_name = {"test_tool": {"param"}} filtered, meta = server._prepare_call_kwargs( @@ -6218,7 +6218,7 @@ async def test_call_tool_forwards_only_declared_arguments() -> None: """End-to-end: framework runtime kwargs are stripped before reaching the server.""" class TestServer(MCPTool): - async def connect(self): + async def connect(self): # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] self.session = Mock(spec=ClientSession) self.session.list_tools = AsyncMock( return_value=types.ListToolsResult( @@ -6240,7 +6240,7 @@ async def connect(self): ) def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: - return None + return None # type: ignore[return-value] # pyrefly: ignore[bad-return] # ty: ignore[invalid-return-type] server = TestServer(name="test_server", additional_tool_argument_names=["conversation_id"]) async with server: @@ -6254,8 +6254,8 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: response_format=object(), ) - session_mock.call_tool.assert_called_once() - _, call_kwargs = session_mock.call_tool.call_args + session_mock.call_tool.assert_called_once() # type: ignore[union-attr] # ty: ignore[unresolved-attribute] + _, call_kwargs = session_mock.call_tool.call_args # type: ignore[union-attr] # ty: ignore[unresolved-attribute] assert call_kwargs["arguments"] == {"param": "value", "conversation_id": "c"} diff --git a/python/packages/core/tests/core/test_mcp_observability.py b/python/packages/core/tests/core/test_mcp_observability.py index 226e9761205..2d7136bfadf 100644 --- a/python/packages/core/tests/core/test_mcp_observability.py +++ b/python/packages/core/tests/core/test_mcp_observability.py @@ -32,7 +32,7 @@ def _make_connected_mcp_tool( supports_prompts: bool = True, ) -> MCPTool: """Create an MCPTool with a mocked session, ready for testing.""" - tool = MCPTool(name=name) + tool = MCPTool(name=name) # type: ignore[abstract] tool.session = AsyncMock() tool.is_connected = True tool._supports_tools = supports_tools @@ -100,7 +100,7 @@ def _make_get_prompt_result(text: str = "prompt result") -> types.GetPromptResul async def test_mcp_initialize_span(span_exporter: InMemorySpanExporter): """session.initialize() should produce an MCP CLIENT span named 'initialize'.""" - tool = MCPTool(name="test-server") + tool = MCPTool(name="test-server") # type: ignore[abstract] mock_session_cls = AsyncMock() init_result = Mock() @@ -146,8 +146,8 @@ async def patched_connect(self_: Any, *, reset: bool = False, load_configured: b assert len(init_spans) == 1 span = init_spans[0] assert span.kind == SpanKind.CLIENT - assert span.attributes[OtelAttr.MCP_METHOD_NAME] == "initialize" - assert span.attributes.get(OtelAttr.MCP_PROTOCOL_VERSION) == "2025-06-18" + assert span.attributes[OtelAttr.MCP_METHOD_NAME] == "initialize" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes.get(OtelAttr.MCP_PROTOCOL_VERSION) == "2025-06-18" # type: ignore[union-attr] # ty: ignore[unresolved-attribute] # endregion @@ -159,7 +159,7 @@ async def patched_connect(self_: Any, *, reset: bool = False, load_configured: b async def test_mcp_tools_list_span(span_exporter: InMemorySpanExporter): """session.list_tools() should produce an MCP CLIENT span named 'tools/list'.""" tool = _make_connected_mcp_tool() - tool.session.list_tools = AsyncMock(return_value=_make_tool_list_result()) + tool.session.list_tools = AsyncMock(return_value=_make_tool_list_result()) # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] span_exporter.clear() await tool.load_tools() @@ -169,7 +169,7 @@ async def test_mcp_tools_list_span(span_exporter: InMemorySpanExporter): assert len(list_spans) == 1 span = list_spans[0] assert span.kind == SpanKind.CLIENT - assert span.attributes[OtelAttr.MCP_METHOD_NAME] == "tools/list" + assert span.attributes[OtelAttr.MCP_METHOD_NAME] == "tools/list" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] # endregion @@ -181,7 +181,7 @@ async def test_mcp_tools_list_span(span_exporter: InMemorySpanExporter): async def test_mcp_prompts_list_span(span_exporter: InMemorySpanExporter): """session.list_prompts() should produce an MCP CLIENT span named 'prompts/list'.""" tool = _make_connected_mcp_tool() - tool.session.list_prompts = AsyncMock(return_value=_make_prompt_list_result()) + tool.session.list_prompts = AsyncMock(return_value=_make_prompt_list_result()) # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] span_exporter.clear() await tool.load_prompts() @@ -191,7 +191,7 @@ async def test_mcp_prompts_list_span(span_exporter: InMemorySpanExporter): assert len(list_spans) == 1 span = list_spans[0] assert span.kind == SpanKind.CLIENT - assert span.attributes[OtelAttr.MCP_METHOD_NAME] == "prompts/list" + assert span.attributes[OtelAttr.MCP_METHOD_NAME] == "prompts/list" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] # endregion @@ -203,7 +203,7 @@ async def test_mcp_prompts_list_span(span_exporter: InMemorySpanExporter): async def test_mcp_tools_call_creates_client_span_when_no_parent(span_exporter: InMemorySpanExporter): """Direct call_tool() without FunctionTool wrapper creates new MCP CLIENT span.""" tool = _make_connected_mcp_tool() - tool.session.call_tool = AsyncMock(return_value=_make_call_tool_result("hello")) + tool.session.call_tool = AsyncMock(return_value=_make_call_tool_result("hello")) # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] span_exporter.clear() result = await tool.call_tool("get-weather", city="Seattle") @@ -215,14 +215,14 @@ async def test_mcp_tools_call_creates_client_span_when_no_parent(span_exporter: span = call_spans[0] assert span.kind == SpanKind.CLIENT assert span.name == "tools/call get-weather" - assert span.attributes[OtelAttr.MCP_METHOD_NAME] == "tools/call" - assert span.attributes[OtelAttr.TOOL_NAME] == "get-weather" + assert span.attributes[OtelAttr.MCP_METHOD_NAME] == "tools/call" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.TOOL_NAME] == "get-weather" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] async def test_mcp_tools_call_tool_error_sets_error_type(span_exporter: InMemorySpanExporter): """When CallToolResult.isError is true, error.type should be 'tool_error' per MCP spec.""" tool = _make_connected_mcp_tool() - tool.session.call_tool = AsyncMock(return_value=_make_call_tool_result("bad input", is_error=True)) + tool.session.call_tool = AsyncMock(return_value=_make_call_tool_result("bad input", is_error=True)) # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] span_exporter.clear() with pytest.raises(ToolExecutionException): @@ -232,14 +232,14 @@ async def test_mcp_tools_call_tool_error_sets_error_type(span_exporter: InMemory call_spans = [s for s in spans if "tools/call" in s.name] assert len(call_spans) == 1 span = call_spans[0] - assert span.attributes.get(OtelAttr.ERROR_TYPE) == "tool_error" + assert span.attributes.get(OtelAttr.ERROR_TYPE) == "tool_error" # type: ignore[union-attr] # ty: ignore[unresolved-attribute] assert span.status.status_code == StatusCode.ERROR async def test_mcp_tools_call_mcp_error_sets_error_type(span_exporter: InMemorySpanExporter): """When session.call_tool() raises McpError, error.type should be the exception class name.""" tool = _make_connected_mcp_tool() - tool.session.call_tool = AsyncMock(side_effect=McpError(ErrorData(code=-32600, message="invalid request"))) + tool.session.call_tool = AsyncMock(side_effect=McpError(ErrorData(code=-32600, message="invalid request"))) # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] span_exporter.clear() with pytest.raises(ToolExecutionException): @@ -249,7 +249,7 @@ async def test_mcp_tools_call_mcp_error_sets_error_type(span_exporter: InMemoryS call_spans = [s for s in spans if "tools/call" in s.name] assert len(call_spans) == 1 span = call_spans[0] - assert span.attributes.get(OtelAttr.ERROR_TYPE) == "McpError" + assert span.attributes.get(OtelAttr.ERROR_TYPE) == "McpError" # type: ignore[union-attr] # ty: ignore[unresolved-attribute] assert span.status.status_code == StatusCode.ERROR @@ -262,7 +262,7 @@ async def test_mcp_tools_call_mcp_error_sets_error_type(span_exporter: InMemoryS async def test_mcp_prompts_get_creates_client_span(span_exporter: InMemorySpanExporter): """get_prompt() should always create a new MCP CLIENT span (not enrich execute_tool).""" tool = _make_connected_mcp_tool() - tool.session.get_prompt = AsyncMock(return_value=_make_get_prompt_result("code analysis")) + tool.session.get_prompt = AsyncMock(return_value=_make_get_prompt_result("code analysis")) # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] span_exporter.clear() result = await tool.get_prompt("analyze-code", language="python") @@ -274,14 +274,14 @@ async def test_mcp_prompts_get_creates_client_span(span_exporter: InMemorySpanEx span = prompt_spans[0] assert span.kind == SpanKind.CLIENT assert span.name == "prompts/get analyze-code" - assert span.attributes[OtelAttr.MCP_METHOD_NAME] == "prompts/get" - assert span.attributes[OtelAttr.PROMPT_NAME] == "analyze-code" + assert span.attributes[OtelAttr.MCP_METHOD_NAME] == "prompts/get" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.PROMPT_NAME] == "analyze-code" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] async def test_mcp_prompts_get_mcp_error_sets_error_type(span_exporter: InMemorySpanExporter): """When session.get_prompt() raises McpError, the span should have error.type and ERROR status.""" tool = _make_connected_mcp_tool() - tool.session.get_prompt = AsyncMock( + tool.session.get_prompt = AsyncMock( # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] side_effect=McpError(ErrorData(code=-32602, message="prompt not found")) ) @@ -293,7 +293,7 @@ async def test_mcp_prompts_get_mcp_error_sets_error_type(span_exporter: InMemory prompt_spans = [s for s in spans if "prompts/get" in s.name] assert len(prompt_spans) == 1 span = prompt_spans[0] - assert span.attributes.get(OtelAttr.ERROR_TYPE) == "McpError" + assert span.attributes.get(OtelAttr.ERROR_TYPE) == "McpError" # type: ignore[union-attr] # ty: ignore[unresolved-attribute] assert span.status.status_code == StatusCode.ERROR @@ -362,8 +362,8 @@ def test_mcp_websocket_tool_default_port(): async def test_mcp_spans_not_created_when_observability_disabled(span_exporter: InMemorySpanExporter): """No MCP spans should be created when observability is disabled.""" tool = _make_connected_mcp_tool() - tool.session.list_tools = AsyncMock(return_value=_make_tool_list_result()) - tool.session.call_tool = AsyncMock(return_value=_make_call_tool_result("ok")) + tool.session.list_tools = AsyncMock(return_value=_make_tool_list_result()) # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] + tool.session.call_tool = AsyncMock(return_value=_make_call_tool_result("ok")) # type: ignore[method-assign, union-attr] # ty: ignore[invalid-assignment] span_exporter.clear() await tool.load_tools() diff --git a/python/packages/core/tests/core/test_mcp_skills.py b/python/packages/core/tests/core/test_mcp_skills.py index 3e7c67662af..74993997d0e 100644 --- a/python/packages/core/tests/core/test_mcp_skills.py +++ b/python/packages/core/tests/core/test_mcp_skills.py @@ -35,26 +35,22 @@ Body content here. """ -SAMPLE_SKILL_INDEX = json.dumps( - { - "$schema": "https://schemas.agentskills.io/discovery/0.2.0/schema.json", - "skills": [ - { - "name": "unit-converter", - "type": "skill-md", - "description": "Convert between common units.", - "url": "skill://unit-converter/SKILL.md", - } - ], - } -) +SAMPLE_SKILL_INDEX = json.dumps({ + "$schema": "https://schemas.agentskills.io/discovery/0.2.0/schema.json", + "skills": [ + { + "name": "unit-converter", + "type": "skill-md", + "description": "Convert between common units.", + "url": "skill://unit-converter/SKILL.md", + } + ], +}) def _make_text_result(text: str, uri: str = "skill://test") -> ReadResourceResult: """Create a ReadResourceResult with a single TextResourceContents.""" - return ReadResourceResult( - contents=[TextResourceContents(uri=AnyUrl(uri), text=text, mimeType="text/markdown")] - ) + return ReadResourceResult(contents=[TextResourceContents(uri=AnyUrl(uri), text=text, mimeType="text/markdown")]) def _make_blob_result( @@ -230,12 +226,10 @@ async def test_get_content_raises_on_empty(self) -> None: @pytest.mark.asyncio async def test_get_resource_text(self) -> None: - client = _make_client( - **{ - "skill://unit-converter/SKILL.md": _make_text_result(SAMPLE_SKILL_MD), - "skill://unit-converter/references/checklist.md": _make_text_result("- check thing 1\n- check thing 2"), - } - ) + client = _make_client(**{ + "skill://unit-converter/SKILL.md": _make_text_result(SAMPLE_SKILL_MD), + "skill://unit-converter/references/checklist.md": _make_text_result("- check thing 1\n- check thing 2"), + }) from agent_framework import SkillFrontmatter fm = SkillFrontmatter(name="unit-converter", description="Convert between common units.") @@ -249,12 +243,10 @@ async def test_get_resource_text(self) -> None: @pytest.mark.asyncio async def test_get_resource_binary(self) -> None: data = bytes([0x01, 0x02, 0x03, 0x04]) - client = _make_client( - **{ - "skill://unit-converter/SKILL.md": _make_text_result(SAMPLE_SKILL_MD), - "skill://unit-converter/assets/icon.bin": _make_blob_result(data), - } - ) + client = _make_client(**{ + "skill://unit-converter/SKILL.md": _make_text_result(SAMPLE_SKILL_MD), + "skill://unit-converter/assets/icon.bin": _make_blob_result(data), + }) from agent_framework import SkillFrontmatter fm = SkillFrontmatter(name="unit-converter", description="Convert between common units.") @@ -345,12 +337,10 @@ class TestMCPSkillsSource: @pytest.mark.asyncio async def test_index_based_discovery_returns_skill(self) -> None: - client = _make_client( - **{ - "skill://index.json": _make_text_result(SAMPLE_SKILL_INDEX, uri="skill://index.json"), - "skill://unit-converter/SKILL.md": _make_text_result(SAMPLE_SKILL_MD), - } - ) + client = _make_client(**{ + "skill://index.json": _make_text_result(SAMPLE_SKILL_INDEX, uri="skill://index.json"), + "skill://unit-converter/SKILL.md": _make_text_result(SAMPLE_SKILL_MD), + }) source = MCPSkillsSource(client=client) skills = await source.get_skills() @@ -373,9 +363,7 @@ async def test_no_index_returns_empty(self) -> None: async def test_does_not_read_skill_md_during_discovery(self) -> None: # Index points to a skill, but SKILL.md is not registered on the server. # Discovery should succeed because it only reads the index. - client = _make_client( - **{"skill://index.json": _make_text_result(SAMPLE_SKILL_INDEX, uri="skill://index.json")} - ) + client = _make_client(**{"skill://index.json": _make_text_result(SAMPLE_SKILL_INDEX, uri="skill://index.json")}) source = MCPSkillsSource(client=client) skills = await source.get_skills() @@ -384,19 +372,17 @@ async def test_does_not_read_skill_md_during_discovery(self) -> None: @pytest.mark.asyncio async def test_invalid_name_is_skipped(self) -> None: - index_json = json.dumps( - { - "$schema": "https://schemas.agentskills.io/discovery/0.2.0/schema.json", - "skills": [ - { - "name": "UnitConverter", # Invalid: uppercase - "type": "skill-md", - "description": "Convert between common units.", - "url": "skill://UnitConverter/SKILL.md", - } - ], - } - ) + index_json = json.dumps({ + "$schema": "https://schemas.agentskills.io/discovery/0.2.0/schema.json", + "skills": [ + { + "name": "UnitConverter", # Invalid: uppercase + "type": "skill-md", + "description": "Convert between common units.", + "url": "skill://UnitConverter/SKILL.md", + } + ], + }) client = _make_client(**{"skill://index.json": _make_text_result(index_json, uri="skill://index.json")}) source = MCPSkillsSource(client=client) skills = await source.get_skills() @@ -404,18 +390,16 @@ async def test_invalid_name_is_skipped(self) -> None: @pytest.mark.asyncio async def test_missing_required_fields_is_skipped(self) -> None: - index_json = json.dumps( - { - "$schema": "https://schemas.agentskills.io/discovery/0.2.0/schema.json", - "skills": [ - { - "name": "unit-converter", - "type": "skill-md", - # Missing description and url - } - ], - } - ) + index_json = json.dumps({ + "$schema": "https://schemas.agentskills.io/discovery/0.2.0/schema.json", + "skills": [ + { + "name": "unit-converter", + "type": "skill-md", + # Missing description and url + } + ], + }) client = _make_client(**{"skill://index.json": _make_text_result(index_json, uri="skill://index.json")}) source = MCPSkillsSource(client=client) skills = await source.get_skills() @@ -423,19 +407,17 @@ async def test_missing_required_fields_is_skipped(self) -> None: @pytest.mark.asyncio async def test_unsupported_type_is_skipped(self) -> None: - index_json = json.dumps( - { - "$schema": "https://schemas.agentskills.io/discovery/0.2.0/schema.json", - "skills": [ - { - "name": "some-skill", - "type": "archive", - "description": "Packaged skill.", - "url": "skill://some-skill.tar.gz", - } - ], - } - ) + index_json = json.dumps({ + "$schema": "https://schemas.agentskills.io/discovery/0.2.0/schema.json", + "skills": [ + { + "name": "some-skill", + "type": "archive", + "description": "Packaged skill.", + "url": "skill://some-skill.tar.gz", + } + ], + }) client = _make_client(**{"skill://index.json": _make_text_result(index_json, uri="skill://index.json")}) source = MCPSkillsSource(client=client) skills = await source.get_skills() @@ -443,18 +425,16 @@ async def test_unsupported_type_is_skipped(self) -> None: @pytest.mark.asyncio async def test_template_type_is_skipped(self) -> None: - index_json = json.dumps( - { - "$schema": "https://schemas.agentskills.io/discovery/0.2.0/schema.json", - "skills": [ - { - "type": "mcp-resource-template", - "description": "Per-product documentation skill", - "url": "skill://docs/{product}/SKILL.md", - } - ], - } - ) + index_json = json.dumps({ + "$schema": "https://schemas.agentskills.io/discovery/0.2.0/schema.json", + "skills": [ + { + "type": "mcp-resource-template", + "description": "Per-product documentation skill", + "url": "skill://docs/{product}/SKILL.md", + } + ], + }) client = _make_client(**{"skill://index.json": _make_text_result(index_json, uri="skill://index.json")}) source = MCPSkillsSource(client=client) skills = await source.get_skills() @@ -462,31 +442,25 @@ async def test_template_type_is_skipped(self) -> None: @pytest.mark.asyncio async def test_empty_index_returns_empty(self) -> None: - client = _make_client( - **{"skill://index.json": _make_text_result('{"skills": []}', uri="skill://index.json")} - ) + client = _make_client(**{"skill://index.json": _make_text_result('{"skills": []}', uri="skill://index.json")}) source = MCPSkillsSource(client=client) skills = await source.get_skills() assert skills == [] @pytest.mark.asyncio async def test_malformed_index_json_returns_empty(self) -> None: - client = _make_client( - **{"skill://index.json": _make_text_result("not valid json", uri="skill://index.json")} - ) + client = _make_client(**{"skill://index.json": _make_text_result("not valid json", uri="skill://index.json")}) source = MCPSkillsSource(client=client) skills = await source.get_skills() assert skills == [] @pytest.mark.asyncio async def test_sibling_text_resource(self) -> None: - client = _make_client( - **{ - "skill://index.json": _make_text_result(SAMPLE_SKILL_INDEX, uri="skill://index.json"), - "skill://unit-converter/SKILL.md": _make_text_result(SAMPLE_SKILL_MD), - "skill://unit-converter/references/checklist.md": _make_text_result("- check thing 1\n- check thing 2"), - } - ) + client = _make_client(**{ + "skill://index.json": _make_text_result(SAMPLE_SKILL_INDEX, uri="skill://index.json"), + "skill://unit-converter/SKILL.md": _make_text_result(SAMPLE_SKILL_MD), + "skill://unit-converter/references/checklist.md": _make_text_result("- check thing 1\n- check thing 2"), + }) source = MCPSkillsSource(client=client) skill = (await source.get_skills())[0] resource = await skill.get_resource("references/checklist.md") @@ -497,13 +471,11 @@ async def test_sibling_text_resource(self) -> None: @pytest.mark.asyncio async def test_sibling_binary_resource(self) -> None: data = bytes([0x01, 0x02, 0x03, 0x04]) - client = _make_client( - **{ - "skill://index.json": _make_text_result(SAMPLE_SKILL_INDEX, uri="skill://index.json"), - "skill://unit-converter/SKILL.md": _make_text_result(SAMPLE_SKILL_MD), - "skill://unit-converter/assets/icon.bin": _make_blob_result(data), - } - ) + client = _make_client(**{ + "skill://index.json": _make_text_result(SAMPLE_SKILL_INDEX, uri="skill://index.json"), + "skill://unit-converter/SKILL.md": _make_text_result(SAMPLE_SKILL_MD), + "skill://unit-converter/assets/icon.bin": _make_blob_result(data), + }) source = MCPSkillsSource(client=client) skill = (await source.get_skills())[0] resource = await skill.get_resource("assets/icon.bin") @@ -649,9 +621,7 @@ async def test_get_resource_generic_mcp_error_propagates(self) -> None: from agent_framework import SkillFrontmatter client = AsyncMock() - client.read_resource = AsyncMock( - side_effect=McpError(error=ErrorData(code=0, message="Handler error")) - ) + client.read_resource = AsyncMock(side_effect=McpError(error=ErrorData(code=0, message="Handler error"))) fm = SkillFrontmatter(name="test-skill", description="Test.") skill = MCPSkill(frontmatter=fm, skill_md_uri="skill://test/SKILL.md", client=client) with pytest.raises(McpError): diff --git a/python/packages/core/tests/core/test_middleware.py b/python/packages/core/tests/core/test_middleware.py index ac08199b4a4..8522adde923 100644 --- a/python/packages/core/tests/core/test_middleware.py +++ b/python/packages/core/tests/core/test_middleware.py @@ -220,9 +220,9 @@ async def _stream() -> AsyncIterable[AgentResponseUpdate]: return ResponseStream(_stream()) updates: list[AgentResponseUpdate] = [] - stream = await pipeline.execute(context, final_handler) + stream = await pipeline.execute(context, final_handler) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] if stream is not None: - async for update in stream: + async for update in stream: # type: ignore[attr-defined, union-attr] # pyrefly: ignore[not-iterable] # ty: ignore[not-iterable] updates.append(update) assert len(updates) == 2 @@ -257,8 +257,8 @@ async def _stream() -> AsyncIterable[AgentResponseUpdate]: return ResponseStream(_stream()) updates: list[AgentResponseUpdate] = [] - stream = await pipeline.execute(context, final_handler) - async for update in stream: + stream = await pipeline.execute(context, final_handler) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] + async for update in stream: # type: ignore[attr-defined, union-attr] # pyrefly: ignore[not-iterable] # ty: ignore[not-iterable] updates.append(update) assert len(updates) == 2 @@ -298,8 +298,8 @@ async def final_handler(ctx: AgentContext) -> AgentResponse: response = await pipeline.execute(context, final_handler) assert response is not None - assert len(response.messages) == 1 - assert response.messages[0].text == "response" + assert len(response.messages) == 1 # type: ignore[union-attr] # ty: ignore[unresolved-attribute] + assert response.messages[0].text == "response" # type: ignore[union-attr] # pyrefly: ignore[bad-index] # ty: ignore[unresolved-attribute] assert execution_order == ["handler"] async def test_execute_stream_with_pre_next_termination(self, mock_agent: SupportsAgentRun) -> None: @@ -321,9 +321,9 @@ async def _stream() -> AsyncIterable[AgentResponseUpdate]: return ResponseStream(_stream()) updates: list[AgentResponseUpdate] = [] - stream = await pipeline.execute(context, final_handler) + stream = await pipeline.execute(context, final_handler) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] if stream is not None: - async for update in stream: + async for update in stream: # type: ignore[attr-defined, union-attr] # pyrefly: ignore[not-iterable] # ty: ignore[not-iterable] updates.append(update) # Handler should not be called when terminated before next() @@ -348,8 +348,8 @@ async def _stream() -> AsyncIterable[AgentResponseUpdate]: return ResponseStream(_stream()) updates: list[AgentResponseUpdate] = [] - stream = await pipeline.execute(context, final_handler) - async for update in stream: + stream = await pipeline.execute(context, final_handler) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] + async for update in stream: # type: ignore[attr-defined, union-attr] # pyrefly: ignore[not-iterable] # ty: ignore[not-iterable] updates.append(update) assert len(updates) == 2 @@ -366,7 +366,7 @@ async def test_execute_with_session_in_context(self, mock_agent: SupportsAgentRu class SessionCapturingMiddleware(AgentMiddleware): async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: nonlocal captured_session - captured_session = context.session + captured_session = context.session # type: ignore[assignment] await call_next() middleware = SessionCapturingMiddleware() @@ -391,7 +391,7 @@ async def test_execute_with_no_session_in_context(self, mock_agent: SupportsAgen class SessionCapturingMiddleware(AgentMiddleware): async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: nonlocal captured_session - captured_session = context.session + captured_session = context.session # type: ignore[assignment] await call_next() middleware = SessionCapturingMiddleware() @@ -620,7 +620,7 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]: updates: list[ChatResponseUpdate] = [] stream = await pipeline.execute(context, final_handler) - async for update in stream: + async for update in stream: # type: ignore[attr-defined, union-attr] # pyrefly: ignore[not-iterable] # ty: ignore[not-iterable] updates.append(update) assert len(updates) == 2 @@ -657,7 +657,7 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]: updates: list[ChatResponseUpdate] = [] stream = await pipeline.execute(context, final_handler) - async for update in stream: + async for update in stream: # type: ignore[attr-defined, union-attr] # pyrefly: ignore[not-iterable] # ty: ignore[not-iterable] updates.append(update) assert len(updates) == 2 @@ -699,8 +699,8 @@ async def final_handler(ctx: ChatContext) -> ChatResponse: response = await pipeline.execute(context, final_handler) assert response is not None - assert len(response.messages) == 1 - assert response.messages[0].text == "response" + assert len(response.messages) == 1 # type: ignore[union-attr] # ty: ignore[unresolved-attribute] + assert response.messages[0].text == "response" # type: ignore[union-attr] # pyrefly: ignore[bad-index] # ty: ignore[unresolved-attribute] assert execution_order == ["handler"] async def test_execute_stream_with_pre_next_termination(self, mock_chat_client: Any) -> None: @@ -748,7 +748,7 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]: updates: list[ChatResponseUpdate] = [] stream = await pipeline.execute(context, final_handler) - async for update in stream: + async for update in stream: # type: ignore[attr-defined, union-attr] # pyrefly: ignore[not-iterable] # ty: ignore[not-iterable] updates.append(update) assert len(updates) == 2 @@ -1261,8 +1261,8 @@ async def _stream() -> AsyncIterable[AgentResponseUpdate]: return ResponseStream(_stream()) updates: list[AgentResponseUpdate] = [] - stream = await pipeline.execute(context_stream, final_stream_handler) - async for update in stream: + stream = await pipeline.execute(context_stream, final_stream_handler) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] + async for update in stream: # type: ignore[attr-defined, union-attr] # pyrefly: ignore[not-iterable] # ty: ignore[not-iterable] updates.append(update) # Verify flags: [non-streaming middleware, non-streaming handler, streaming middleware, streaming handler] @@ -1295,8 +1295,8 @@ async def _stream() -> AsyncIterable[AgentResponseUpdate]: return ResponseStream(_stream()) updates: list[str] = [] - stream = await pipeline.execute(context, final_stream_handler) - async for update in stream: + stream = await pipeline.execute(context, final_stream_handler) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] + async for update in stream: # type: ignore[attr-defined, union-attr] # pyrefly: ignore[not-iterable] # ty: ignore[not-iterable] updates.append(update.text) assert updates == ["chunk1", "chunk2"] @@ -1344,7 +1344,7 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]: updates: list[ChatResponseUpdate] = [] stream = await pipeline.execute(context_stream, final_stream_handler) - async for update in stream: + async for update in stream: # type: ignore[attr-defined, union-attr] # pyrefly: ignore[not-iterable] # ty: ignore[not-iterable] updates.append(update) # Verify flags: [non-streaming middleware, non-streaming handler, streaming middleware, streaming handler] @@ -1379,7 +1379,7 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]: updates: list[str] = [] stream = await pipeline.execute(context, final_stream_handler) - async for update in stream: + async for update in stream: # type: ignore[attr-defined, union-attr] # pyrefly: ignore[not-iterable] # ty: ignore[not-iterable] updates.append(update.text) assert updates == ["chunk1", "chunk2"] @@ -1483,7 +1483,7 @@ async def _stream() -> AsyncIterable[AgentResponseUpdate]: return ResponseStream(_stream()) # When middleware doesn't call next(), result is None - stream = await pipeline.execute(context, final_handler) + stream = await pipeline.execute(context, final_handler) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] # Verify no execution happened - result is None since middleware didn't set it assert stream is None @@ -1613,7 +1613,7 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]: try: stream = await pipeline.execute(context, final_handler) if stream is not None: - async for update in stream: + async for update in stream: # type: ignore[attr-defined, union-attr] # pyrefly: ignore[not-iterable] # ty: ignore[not-iterable] updates.append(update) except ValueError: # Expected - streaming middleware requires a ResponseStream result but middleware didn't call next() @@ -1723,7 +1723,7 @@ def test_categorize_middleware_with_single_item(self) -> None: def test_categorize_middleware_with_string_does_not_decompose(self) -> None: """Test that a string is not decomposed character-by-character.""" - result = categorize_middleware("not_a_middleware") + result = categorize_middleware("not_a_middleware") # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] # String should be treated as a single item, not decomposed into characters total_items = len(result["chat"]) + len(result["function"]) + len(result["agent"]) assert total_items == 1 diff --git a/python/packages/core/tests/core/test_middleware_context_result.py b/python/packages/core/tests/core/test_middleware_context_result.py index 303feb0344c..75ebc9e2204 100644 --- a/python/packages/core/tests/core/test_middleware_context_result.py +++ b/python/packages/core/tests/core/test_middleware_context_result.py @@ -64,7 +64,7 @@ async def final_handler(ctx: AgentContext) -> AgentResponse: # Verify the overridden response is returned assert result is not None assert result == override_response - assert result.messages[0].text == "overridden response" + assert result.messages[0].text == "overridden response" # ty: ignore[unresolved-attribute] # type: ignore[union-attr] # Verify original handler was called since middleware called next() assert handler_called @@ -93,8 +93,8 @@ async def _stream() -> AsyncIterable[AgentResponseUpdate]: return ResponseStream(_stream()) updates: list[AgentResponseUpdate] = [] - stream = await pipeline.execute(context, final_handler) - async for update in stream: + stream = await pipeline.execute(context, final_handler) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] + async for update in stream: # type: ignore[attr-defined, union-attr] # pyrefly: ignore[not-iterable] # ty: ignore[not-iterable] updates.append(update) # Verify the overridden response stream is returned @@ -151,7 +151,7 @@ async def process(self, context: AgentContext, call_next: Callable[[], Awaitable # Create Agent with override middleware middleware = ChatAgentResponseOverrideMiddleware() - agent = Agent(client=mock_chat_client, middleware=[middleware]) + agent = Agent(client=mock_chat_client, middleware=[middleware]) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] # Test override case override_messages = [Message(role="user", contents=["Give me a special response"])] @@ -187,7 +187,7 @@ async def process(self, context: AgentContext, call_next: Callable[[], Awaitable # Create Agent with override middleware middleware = ChatAgentStreamOverrideMiddleware() - agent = Agent(client=mock_chat_client, middleware=[middleware]) + agent = Agent(client=mock_chat_client, middleware=[middleware]) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] # Test streaming override case override_messages = [Message(role="user", contents=["Give me a custom stream"])] @@ -248,7 +248,7 @@ async def final_handler(ctx: AgentContext) -> AgentResponse: execute_result = await pipeline.execute(execute_context, final_handler) assert execute_result is not None - assert execute_result.messages[0].text == "executed response" + assert execute_result.messages[0].text == "executed response" # type: ignore[union-attr] # ty: ignore[unresolved-attribute] assert handler_called async def test_function_middleware_conditional_no_next(self, mock_function: FunctionTool) -> None: @@ -399,7 +399,7 @@ async def final_handler(ctx: AgentContext) -> AgentResponse: # Verify response was modified after execution assert result is not None - assert result.messages[0].text == "modified after execution" + assert result.messages[0].text == "modified after execution" # type: ignore[union-attr] # ty: ignore[unresolved-attribute] async def test_function_middleware_post_execution_override(self, mock_function: FunctionTool) -> None: """Test that middleware can override function result after observing execution.""" diff --git a/python/packages/core/tests/core/test_middleware_with_agent.py b/python/packages/core/tests/core/test_middleware_with_agent.py index 9f567687700..0eadb1b5afb 100644 --- a/python/packages/core/tests/core/test_middleware_with_agent.py +++ b/python/packages/core/tests/core/test_middleware_with_agent.py @@ -81,7 +81,7 @@ async def process( await call_next() middleware = TrackingFunctionMiddleware() - Agent(client=client, middleware=[middleware]) + Agent(client=client, middleware=[middleware]) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] async def test_class_based_function_middleware_with_chat_agent_supported_client( self, chat_client_base: "MockBaseChatClient" @@ -131,7 +131,7 @@ async def process(self, context: AgentContext, call_next: Callable[[], Awaitable # Create Agent with terminating middleware middleware = PreTerminationMiddleware() - agent = Agent(client=client, middleware=[middleware]) + agent = Agent(client=client, middleware=[middleware]) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] # Execute the agent with multiple messages messages = [ @@ -155,11 +155,11 @@ async def process(self, context: AgentContext, call_next: Callable[[], Awaitable execution_order.append("middleware_before") await call_next() execution_order.append("middleware_after") - context.terminate = True + context.terminate = True # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] # Create Agent with terminating middleware middleware = PostTerminationMiddleware() - agent = Agent(client=client, middleware=[middleware]) + agent = Agent(client=client, middleware=[middleware]) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] # Execute the agent with multiple messages messages = [ @@ -192,12 +192,12 @@ async def process( call_next: Callable[[], Awaitable[None]], ) -> None: execution_order.append("middleware_before") - context.terminate = True + context.terminate = True # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] # We call next() but since terminate=True, subsequent middleware and handler should not execute await call_next() execution_order.append("middleware_after") - Agent(client=client, middleware=[PreTerminationFunctionMiddleware()], tools=[]) + Agent(client=client, middleware=[PreTerminationFunctionMiddleware()], tools=[]) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] async def test_function_middleware_with_post_termination(self, client: "MockChatClient") -> None: """Test that function middleware can terminate execution after calling next().""" @@ -212,9 +212,9 @@ async def process( execution_order.append("middleware_before") await call_next() execution_order.append("middleware_after") - context.terminate = True + context.terminate = True # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] - Agent(client=client, middleware=[PostTerminationFunctionMiddleware()], tools=[]) + Agent(client=client, middleware=[PostTerminationFunctionMiddleware()], tools=[]) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] async def test_function_based_agent_middleware_with_chat_agent(self, client: "MockChatClient") -> None: """Test function-based agent middleware with Agent.""" @@ -226,7 +226,7 @@ async def tracking_agent_middleware(context: AgentContext, call_next: Callable[[ execution_order.append("agent_function_after") # Create Agent with function middleware - agent = Agent(client=client, middleware=[tracking_agent_middleware]) + agent = Agent(client=client, middleware=[tracking_agent_middleware]) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] # Execute the agent messages = [Message(role="user", contents=["test message"])] @@ -250,7 +250,7 @@ async def tracking_function_middleware( ) -> None: await call_next() - Agent(client=client, middleware=[tracking_function_middleware]) + Agent(client=client, middleware=[tracking_function_middleware]) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] async def test_function_based_function_middleware_with_supported_client( self, chat_client_base: "MockBaseChatClient" @@ -292,7 +292,7 @@ async def process(self, context: AgentContext, call_next: Callable[[], Awaitable # Create Agent with middleware middleware = StreamingTrackingMiddleware() - agent = Agent(client=client, middleware=[middleware]) + agent = Agent(client=client, middleware=[middleware]) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] # Set up mock streaming responses client.streaming_responses = [ @@ -332,7 +332,7 @@ async def process(self, context: AgentContext, call_next: Callable[[], Awaitable # Create Agent with middleware middleware = FlagTrackingMiddleware() - agent = Agent(client=client, middleware=[middleware]) + agent = Agent(client=client, middleware=[middleware]) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] messages = [Message(role="user", contents=["test message"])] # Test non-streaming execution @@ -369,7 +369,7 @@ async def process(self, context: AgentContext, call_next: Callable[[], Awaitable middleware3 = OrderedMiddleware("third") # Create Agent with multiple middleware - agent = Agent(client=client, middleware=[middleware1, middleware2, middleware3]) + agent = Agent(client=client, middleware=[middleware1, middleware2, middleware3]) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] # Execute the agent messages = [Message(role="user", contents=["test message"])] @@ -796,7 +796,7 @@ async def kwargs_middleware( # Execute the agent with custom parameters passed as kwargs messages = [Message(role="user", contents=["test message"])] - response = await agent.run(messages, options={"additional_function_arguments": {"custom_param": "test_value"}}) + response = await agent.run(messages, options={"additional_function_arguments": {"custom_param": "test_value"}}) # type: ignore[call-overload, typeddict-unknown-key, var-annotated] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] # Verify response assert response is not None @@ -890,13 +890,13 @@ async def capture_middleware( agent = Agent(client=chat_client_base, middleware=[capture_middleware], tools=[sample_tool_function]) - await agent.run( + await agent.run( # type: ignore[call-overload] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] [Message(role="user", contents=["Get weather"])], function_invocation_kwargs={ "user_id": "from-kwargs", "tenant_id": "from-kwargs", }, - options={ + options={ # type: ignore[typeddict-unknown-key] "additional_function_arguments": { "user_id": "from-options", "extra_key": "only-in-options", @@ -1016,7 +1016,7 @@ async def test_middleware_dynamic_rebuild_non_streaming(self, client: "MockChatC # Create agent with initial middleware middleware1 = self.TrackingAgentMiddleware("middleware1", execution_log) - agent = Agent(client=client, middleware=[middleware1]) + agent = Agent(client=client, middleware=[middleware1]) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] # First execution - should use middleware1 await agent.run("Test message 1") @@ -1066,7 +1066,7 @@ async def test_middleware_dynamic_rebuild_streaming(self, client: "MockChatClien # Create agent with initial middleware middleware1 = self.TrackingAgentMiddleware("stream_middleware1", execution_log) - agent = Agent(client=client, middleware=[middleware1]) + agent = Agent(client=client, middleware=[middleware1]) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] # First streaming execution updates: list[AgentResponseUpdate] = [] @@ -1101,7 +1101,7 @@ async def test_middleware_order_change_detection(self, client: "MockChatClient") middleware2 = self.TrackingAgentMiddleware("second", execution_log) # Create agent with middleware in order [first, second] - agent = Agent(client=client, middleware=[middleware1, middleware2]) + agent = Agent(client=client, middleware=[middleware1, middleware2]) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] # First execution await agent.run("Test message 1") @@ -1138,7 +1138,7 @@ async def test_run_level_middleware_isolation(self, client: "MockChatClient") -> execution_log: list[str] = [] # Create agent without any agent-level middleware - agent = Agent(client=client) + agent = Agent(client=client) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] # Create run-level middleware run_middleware1 = self.TrackingAgentMiddleware("run1", execution_log) @@ -1203,7 +1203,7 @@ async def process(self, context: AgentContext, call_next: Callable[[], Awaitable # Create agent with agent-level middleware agent_middleware = MetadataAgentMiddleware("agent") - agent = Agent(client=client, middleware=[agent_middleware]) + agent = Agent(client=client, middleware=[agent_middleware]) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] # Create run-level middleware run_middleware = MetadataRunMiddleware("run") @@ -1223,7 +1223,7 @@ async def test_run_level_middleware_non_streaming(self, client: "MockChatClient" execution_log: list[str] = [] # Create agent without agent-level middleware - agent = Agent(client=client) + agent = Agent(client=client) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] # Create run-level middleware run_middleware = self.TrackingAgentMiddleware("run_nonstream", execution_log) @@ -1256,7 +1256,7 @@ async def process(self, context: AgentContext, call_next: Callable[[], Awaitable execution_log.append(f"{self.name}_end") # Create agent without agent-level middleware - agent = Agent(client=client) + agent = Agent(client=client) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] # Set up mock streaming responses client.streaming_responses = [ @@ -1447,7 +1447,7 @@ def custom_tool(message: str) -> str: ] ) final_response = ChatResponse(messages=[Message(role="assistant", contents=["Final response"])]) - chat_client_base.responses = [function_call_response, final_response] + chat_client_base.responses = [function_call_response, final_response] # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] # Should work without errors agent = Agent( @@ -1469,14 +1469,14 @@ async def test_decorator_and_type_mismatch(self, client: MockChatClient) -> None # Should raise MiddlewareException due to mismatch during agent creation with pytest.raises(MiddlewareException, match="MiddlewareTypes type mismatch"): - @agent_middleware # type: ignore[arg-type] + @agent_middleware # type: ignore[arg-type] # ty: ignore[invalid-argument-type] async def mismatched_middleware( context: FunctionInvocationContext, # Wrong type for @agent_middleware call_next: Any, ) -> None: await call_next() - agent = Agent(client=client, middleware=[mismatched_middleware]) + agent = Agent(client=client, middleware=[mismatched_middleware]) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] await agent.run([Message(role="user", contents=["test"])]) async def test_only_decorator_specified(self, chat_client_base: "MockBaseChatClient") -> None: @@ -1518,7 +1518,7 @@ def custom_tool(message: str) -> str: ] ) final_response = ChatResponse(messages=[Message(role="assistant", contents=["Final response"])]) - chat_client_base.responses = [function_call_response, final_response] + chat_client_base.responses = [function_call_response, final_response] # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] # Should work - relies on decorator agent = Agent( @@ -1574,7 +1574,7 @@ def custom_tool(message: str) -> str: ] ) final_response = ChatResponse(messages=[Message(role="assistant", contents=["Final response"])]) - chat_client_base.responses = [function_call_response, final_response] + chat_client_base.responses = [function_call_response, final_response] # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] # Should work - relies on type annotations agent = Agent( @@ -1605,7 +1605,7 @@ async def test_insufficient_parameters_error(self, client: Any) -> None: # Should raise MiddlewareException about insufficient parameters with pytest.raises(MiddlewareException, match="must have at least 2 parameters"): - @agent_middleware # type: ignore[arg-type] + @agent_middleware # type: ignore[arg-type] # ty: ignore[invalid-argument-type] async def insufficient_params_middleware(context: Any) -> None: # Missing 'next' parameter pass @@ -1676,7 +1676,7 @@ async def process(self, context: AgentContext, call_next: Callable[[], Awaitable # Create Agent with session tracking middleware middleware = SessionTrackingMiddleware() - agent = Agent(client=client, middleware=[middleware]) + agent = Agent(client=client, middleware=[middleware]) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] # Create a session that will persist messages between runs session = agent.create_session() @@ -1813,7 +1813,7 @@ async def message_modifier_middleware(context: ChatContext, call_next: Callable[ if msg.role == "system": continue original_text = msg.text or "" - context.messages[idx] = Message(role=msg.role, contents=[f"MODIFIED: {original_text}"]) + context.messages[idx] = Message(role=msg.role, contents=[f"MODIFIED: {original_text}"]) # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[invalid-assignment] break await call_next() @@ -1839,7 +1839,7 @@ async def response_override_middleware(context: ChatContext, call_next: Callable messages=[Message(role="assistant", contents=["MiddlewareTypes overridden response"])], response_id="middleware-response-123", ) - context.terminate = True + context.terminate = True # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] # Create Agent with response-overriding middleware client = MockBaseChatClient() @@ -1968,7 +1968,7 @@ async def process(self, context: ChatContext, call_next: Callable[[], Awaitable[ execution_order.append("middleware_before") await call_next() execution_order.append("middleware_after") - context.terminate = True + context.terminate = True # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] # Create Agent with terminating middleware client = MockBaseChatClient() @@ -2219,9 +2219,9 @@ async def kwargs_middleware(context: AgentContext, call_next: Callable[[], Await assert isinstance(context.options, dict) captured_options.update(context.options) - context.options["temperature"] = 0.9 - context.options["max_tokens"] = 500 - context.options["new_param"] = "added_by_middleware" + context.options["temperature"] = 0.9 # ty: ignore[invalid-assignment] + context.options["max_tokens"] = 500 # ty: ignore[invalid-assignment] + context.options["new_param"] = "added_by_middleware" # ty: ignore[invalid-assignment] modified_options.update(context.options) @@ -2233,9 +2233,9 @@ async def kwargs_middleware(context: AgentContext, call_next: Callable[[], Await # Execute the agent with runtime options messages = [Message(role="user", contents=["test message"])] - response = await agent.run( + response = await agent.run( # type: ignore[call-overload, var-annotated] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] messages, - options={"temperature": 0.7, "max_tokens": 100, "custom_param": "test_value"}, + options={"temperature": 0.7, "max_tokens": 100, "custom_param": "test_value"}, # type: ignore[typeddict-unknown-key] ) # Verify response diff --git a/python/packages/core/tests/core/test_middleware_with_chat.py b/python/packages/core/tests/core/test_middleware_with_chat.py index 14243c65786..c4aaffc7a88 100644 --- a/python/packages/core/tests/core/test_middleware_with_chat.py +++ b/python/packages/core/tests/core/test_middleware_with_chat.py @@ -1,13 +1,14 @@ # Copyright (c) Microsoft. All rights reserved. from collections.abc import Awaitable, Callable -from typing import Any +from typing import Any, cast from unittest.mock import patch from agent_framework import ( Agent, ChatContext, ChatMiddleware, + ChatMiddlewareTypes, ChatResponse, ChatResponseUpdate, Content, @@ -40,7 +41,7 @@ async def process( execution_order.append("chat_middleware_after") # Add middleware to chat client - chat_client_base.chat_middleware = [LoggingChatMiddleware()] + chat_client_base.chat_middleware = [LoggingChatMiddleware()] # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] # Execute chat client directly messages = [Message(role="user", contents=["test message"])] @@ -65,7 +66,7 @@ async def logging_chat_middleware(context: ChatContext, call_next: Callable[[], execution_order.append("function_middleware_after") # Add middleware to chat client - chat_client_base.chat_middleware = [logging_chat_middleware] + chat_client_base.chat_middleware = [cast(ChatMiddlewareTypes, logging_chat_middleware)] # Execute chat client directly messages = [Message(role="user", contents=["test message"])] @@ -87,11 +88,11 @@ async def message_modifier_middleware(context: ChatContext, call_next: Callable[ # Modify the first message by adding a prefix if context.messages and len(context.messages) > 0: original_text = context.messages[0].text or "" - context.messages[0] = Message(role=context.messages[0].role, contents=[f"MODIFIED: {original_text}"]) + context.messages[0] = Message(role=context.messages[0].role, contents=[f"MODIFIED: {original_text}"]) # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[invalid-assignment] await call_next() # Add middleware to chat client - chat_client_base.chat_middleware = [message_modifier_middleware] + chat_client_base.chat_middleware = [cast(ChatMiddlewareTypes, message_modifier_middleware)] # Execute chat client messages = [Message(role="user", contents=["test message"])] @@ -113,10 +114,10 @@ async def response_override_middleware(context: ChatContext, call_next: Callable messages=[Message(role="assistant", contents=["MiddlewareTypes overridden response"])], response_id="middleware-response-123", ) - context.terminate = True + context.terminate = True # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] # Add middleware to chat client - chat_client_base.chat_middleware = [response_override_middleware] + chat_client_base.chat_middleware = [cast(ChatMiddlewareTypes, response_override_middleware)] # Execute chat client messages = [Message(role="user", contents=["test message"])] @@ -145,7 +146,10 @@ async def second_middleware(context: ChatContext, call_next: Callable[[], Awaita execution_order.append("second_after") # Add middleware to chat client (order should be preserved) - chat_client_base.chat_middleware = [first_middleware, second_middleware] + chat_client_base.chat_middleware = [ + cast(ChatMiddlewareTypes, first_middleware), + cast(ChatMiddlewareTypes, second_middleware), + ] # Execute chat client messages = [Message(role="user", contents=["test message"])] @@ -241,7 +245,7 @@ async def streaming_middleware(context: ChatContext, call_next: Callable[[], Awa def upper_case_update(update: ChatResponseUpdate) -> ChatResponseUpdate: for content in update.contents: if content.type == "text": - content.text = content.text.upper() + content.text = content.text.upper() # type: ignore[union-attr] # ty: ignore[unresolved-attribute] return update context.stream_transform_hooks.append(upper_case_update) @@ -249,7 +253,7 @@ def upper_case_update(update: ChatResponseUpdate) -> ChatResponseUpdate: execution_order.append("streaming_after") # Add middleware to chat client - chat_client_base.chat_middleware = [streaming_middleware] + chat_client_base.chat_middleware = [cast(ChatMiddlewareTypes, streaming_middleware)] # Execute streaming response messages = [Message(role="user", contents=["test message"])] @@ -259,7 +263,7 @@ def upper_case_update(update: ChatResponseUpdate) -> ChatResponseUpdate: # Verify we got updates assert len(updates) > 0 - assert all(update.text == update.text.upper() for update in updates) + assert all(update.text == update.text.upper() for update in updates) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] # Verify middleware executed assert execution_order == ["streaming_before", "streaming_after"] @@ -338,22 +342,22 @@ async def kwargs_middleware(context: ChatContext, call_next: Callable[[], Awaita assert isinstance(context.options, dict) captured_options.update(context.options) - context.options["temperature"] = 0.9 - context.options["max_tokens"] = 500 - context.options["new_param"] = "added_by_middleware" + context.options["temperature"] = 0.9 # ty: ignore[invalid-assignment] + context.options["max_tokens"] = 500 # ty: ignore[invalid-assignment] + context.options["new_param"] = "added_by_middleware" # ty: ignore[invalid-assignment] modified_options.update(context.options) await call_next() # Add middleware to chat client - chat_client_base.chat_middleware = [kwargs_middleware] + chat_client_base.chat_middleware = [cast(ChatMiddlewareTypes, kwargs_middleware)] # Execute chat client with runtime options messages = [Message(role="user", contents=["test message"])] - response = await chat_client_base.get_response( + response = await chat_client_base.get_response( # type: ignore[call-overload, var-annotated] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] messages, - options={"temperature": 0.7, "max_tokens": 100, "custom_param": "test_value"}, + options={"temperature": 0.7, "max_tokens": 100, "custom_param": "test_value"}, # type: ignore[typeddict-unknown-key] ) # Verify response @@ -408,7 +412,7 @@ async def runtime_middleware(context: ChatContext, call_next: Callable[[], Await pipeline_no_base = chat_client_base._get_chat_middleware_pipeline([runtime_middleware]) # With base middleware - chat_client_base.chat_middleware = [base_middleware] + chat_client_base.chat_middleware = [cast(ChatMiddlewareTypes, base_middleware)] pipeline_with_base = chat_client_base._get_chat_middleware_pipeline([runtime_middleware]) assert pipeline_with_base is not pipeline_no_base diff --git a/python/packages/core/tests/core/test_observability.py b/python/packages/core/tests/core/test_observability.py index 46f6e2c1517..a4317c20009 100644 --- a/python/packages/core/tests/core/test_observability.py +++ b/python/packages/core/tests/core/test_observability.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. import logging -from collections.abc import AsyncIterable, Awaitable, MutableSequence, Sequence +from collections.abc import AsyncIterable, Awaitable, Mapping, Sequence from typing import Any from unittest.mock import Mock, patch @@ -113,7 +113,7 @@ def test_start_span_basic(span_exporter: InMemorySpanExporter): OtelAttr.TOOL_TYPE: "function", } span_exporter.clear() - with get_function_span(attributes) as function_span: + with get_function_span(attributes) as function_span: # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] assert function_span is not None function_span.set_attribute("test_attr", "test_value") @@ -121,10 +121,10 @@ def test_start_span_basic(span_exporter: InMemorySpanExporter): assert len(spans) == 1 span = spans[0] assert span.name == "execute_tool test_function" - assert span.attributes["test_attr"] == "test_value" - assert span.attributes[OtelAttr.OPERATION.value] == OtelAttr.TOOL_EXECUTION_OPERATION - assert span.attributes[OtelAttr.TOOL_NAME] == "test_function" - assert span.attributes[OtelAttr.TOOL_DESCRIPTION] == "Test function description" + assert span.attributes["test_attr"] == "test_value" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.OPERATION.value] == OtelAttr.TOOL_EXECUTION_OPERATION # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.TOOL_NAME] == "test_function" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.TOOL_DESCRIPTION] == "Test function description" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] def test_start_span_with_tool_call_id(span_exporter: InMemorySpanExporter): @@ -140,20 +140,20 @@ def test_start_span_with_tool_call_id(span_exporter: InMemorySpanExporter): } span_exporter.clear() - with get_function_span(attributes) as function_span: + with get_function_span(attributes) as function_span: # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] assert function_span is not None function_span.set_attribute("test_attr", "test_value") spans = span_exporter.get_finished_spans() assert len(spans) == 1 span = spans[0] assert span.name == "execute_tool test_function" - assert span.attributes["test_attr"] == "test_value" - assert span.attributes[OtelAttr.TOOL_CALL_ID] == tool_call_id + assert span.attributes["test_attr"] == "test_value" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.TOOL_CALL_ID] == tool_call_id # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] # Verify all attributes - assert span.attributes[OtelAttr.OPERATION.value] == OtelAttr.TOOL_EXECUTION_OPERATION - assert span.attributes[OtelAttr.TOOL_NAME] == "test_function" - assert span.attributes[OtelAttr.TOOL_DESCRIPTION] == "Test function" - assert span.attributes[OtelAttr.TOOL_TYPE] == "function" + assert span.attributes[OtelAttr.OPERATION.value] == OtelAttr.TOOL_EXECUTION_OPERATION # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.TOOL_NAME] == "test_function" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.TOOL_DESCRIPTION] == "Test function" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.TOOL_TYPE] == "function" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] @pytest.fixture @@ -164,8 +164,13 @@ class MockChatClient(ChatTelemetryLayer, BaseChatClient[Any]): def service_url(self): return "https://test.example.com" - def _inner_get_response( - self, *, messages: MutableSequence[Message], stream: bool, options: dict[str, Any], **kwargs: Any + def _inner_get_response( # pyrefly: ignore[bad-override] + self, + *, + messages: Sequence[Message], + stream: bool, + options: Mapping[str, Any], + **kwargs: Any, # type: ignore[override] ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]: if stream: return self._get_streaming_response(messages=messages, options=options, **kwargs) @@ -176,7 +181,7 @@ async def _get() -> ChatResponse: return _get() async def _get_non_streaming_response( - self, *, messages: MutableSequence[Message], options: dict[str, Any], **kwargs: Any + self, *, messages: Sequence[Message], options: Mapping[str, Any], **kwargs: Any ) -> ChatResponse: return ChatResponse( messages=[Message("assistant", ["Test response"])], @@ -185,7 +190,7 @@ async def _get_non_streaming_response( ) def _get_streaming_response( - self, *, messages: MutableSequence[Message], options: dict[str, Any], **kwargs: Any + self, *, messages: Sequence[Message], options: Mapping[str, Any], **kwargs: Any ) -> ResponseStream[ChatResponseUpdate, ChatResponse]: async def _stream() -> AsyncIterable[ChatResponseUpdate]: yield ChatResponseUpdate(contents=[Content.from_text("Hello")], role="assistant") @@ -212,13 +217,13 @@ async def test_chat_client_observability(mock_chat_client, span_exporter: InMemo assert len(spans) == 1 span = spans[0] assert span.name == "chat Test" - assert span.attributes[OtelAttr.OPERATION.value] == OtelAttr.CHAT_COMPLETION_OPERATION - assert span.attributes[OtelAttr.REQUEST_MODEL] == "Test" - assert span.attributes[OtelAttr.INPUT_TOKENS] == 10 - assert span.attributes[OtelAttr.OUTPUT_TOKENS] == 20 + assert span.attributes[OtelAttr.OPERATION.value] == OtelAttr.CHAT_COMPLETION_OPERATION # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.REQUEST_MODEL] == "Test" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.INPUT_TOKENS] == 10 # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.OUTPUT_TOKENS] == 20 # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] if enable_sensitive_data: - assert span.attributes[OtelAttr.INPUT_MESSAGES] is not None - assert span.attributes[OtelAttr.OUTPUT_MESSAGES] is not None + assert span.attributes[OtelAttr.INPUT_MESSAGES] is not None # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.OUTPUT_MESSAGES] is not None # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] @pytest.mark.parametrize("enable_sensitive_data", [True, False], indirect=True) @@ -235,7 +240,7 @@ async def test_chat_client_observability_accepts_model_option( spans = span_exporter.get_finished_spans() assert len(spans) == 1 span = spans[0] - assert span.attributes[OtelAttr.REQUEST_MODEL] == "Test" + assert span.attributes[OtelAttr.REQUEST_MODEL] == "Test" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] @pytest.mark.parametrize("enable_sensitive_data", [True, False], indirect=True) @@ -259,11 +264,11 @@ async def test_chat_client_streaming_observability( assert len(spans) == 1 span = spans[0] assert span.name == "chat Test" - assert span.attributes[OtelAttr.OPERATION.value] == OtelAttr.CHAT_COMPLETION_OPERATION - assert span.attributes[OtelAttr.REQUEST_MODEL] == "Test" + assert span.attributes[OtelAttr.OPERATION.value] == OtelAttr.CHAT_COMPLETION_OPERATION # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.REQUEST_MODEL] == "Test" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] if enable_sensitive_data: - assert span.attributes[OtelAttr.INPUT_MESSAGES] is not None - assert span.attributes[OtelAttr.OUTPUT_MESSAGES] is not None + assert span.attributes[OtelAttr.INPUT_MESSAGES] is not None # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.OUTPUT_MESSAGES] is not None # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] @pytest.mark.parametrize("enable_sensitive_data", [True], indirect=True) @@ -286,13 +291,13 @@ async def test_chat_client_observability_with_instructions( span = spans[0] # Verify system_instructions attribute is set - assert OtelAttr.SYSTEM_INSTRUCTIONS in span.attributes - system_instructions = json.loads(span.attributes[OtelAttr.SYSTEM_INSTRUCTIONS]) + assert OtelAttr.SYSTEM_INSTRUCTIONS in span.attributes # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] + system_instructions = json.loads(span.attributes[OtelAttr.SYSTEM_INSTRUCTIONS]) # type: ignore[arg-type, index] # pyrefly: ignore[bad-argument-type, unsupported-operation] # ty: ignore[invalid-argument-type, not-subscriptable] assert len(system_instructions) == 1 assert system_instructions[0]["content"] == "You are a helpful assistant." # Verify input_messages excludes system instructions - input_messages = json.loads(span.attributes[OtelAttr.INPUT_MESSAGES]) + input_messages = json.loads(span.attributes[OtelAttr.INPUT_MESSAGES]) # type: ignore[arg-type, index] # pyrefly: ignore[bad-argument-type, unsupported-operation] # ty: ignore[invalid-argument-type, not-subscriptable] assert [msg.get("role") for msg in input_messages] == ["user"] @@ -320,12 +325,12 @@ async def test_chat_client_streaming_observability_with_instructions( span = spans[0] # Verify system_instructions attribute is set - assert OtelAttr.SYSTEM_INSTRUCTIONS in span.attributes - system_instructions = json.loads(span.attributes[OtelAttr.SYSTEM_INSTRUCTIONS]) + assert OtelAttr.SYSTEM_INSTRUCTIONS in span.attributes # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] + system_instructions = json.loads(span.attributes[OtelAttr.SYSTEM_INSTRUCTIONS]) # type: ignore[arg-type, index] # pyrefly: ignore[bad-argument-type, unsupported-operation] # ty: ignore[invalid-argument-type, not-subscriptable] assert len(system_instructions) == 1 assert system_instructions[0]["content"] == "You are a helpful assistant." - input_messages = json.loads(span.attributes[OtelAttr.INPUT_MESSAGES]) + input_messages = json.loads(span.attributes[OtelAttr.INPUT_MESSAGES]) # type: ignore[arg-type, index] # pyrefly: ignore[bad-argument-type, unsupported-operation] # ty: ignore[invalid-argument-type, not-subscriptable] assert [msg.get("role") for msg in input_messages] == ["user"] @@ -351,10 +356,10 @@ async def test_chat_client_observability_with_system_message_and_instructions( assert len(spans) == 1 span = spans[0] - system_instructions = json.loads(span.attributes[OtelAttr.SYSTEM_INSTRUCTIONS]) + system_instructions = json.loads(span.attributes[OtelAttr.SYSTEM_INSTRUCTIONS]) # type: ignore[arg-type, index] # pyrefly: ignore[bad-argument-type, unsupported-operation] # ty: ignore[invalid-argument-type, not-subscriptable] assert system_instructions == [{"type": "text", "content": "Framework system instruction"}] - input_messages = json.loads(span.attributes[OtelAttr.INPUT_MESSAGES]) + input_messages = json.loads(span.attributes[OtelAttr.INPUT_MESSAGES]) # type: ignore[arg-type, index] # pyrefly: ignore[bad-argument-type, unsupported-operation] # ty: ignore[invalid-argument-type, not-subscriptable] assert [msg.get("role") for msg in input_messages] == ["system", "user"] assert input_messages[0]["parts"][0]["content"] == "Original system message" assert input_messages[1]["parts"][0]["content"] == "Test message" @@ -378,7 +383,7 @@ async def test_chat_client_observability_without_instructions( span = spans[0] # Verify system_instructions attribute is NOT set - assert OtelAttr.SYSTEM_INSTRUCTIONS not in span.attributes + assert OtelAttr.SYSTEM_INSTRUCTIONS not in span.attributes # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] @pytest.mark.parametrize("enable_sensitive_data", [True], indirect=True) @@ -399,7 +404,7 @@ async def test_chat_client_observability_with_empty_instructions( span = spans[0] # Empty string should not set system_instructions - assert OtelAttr.SYSTEM_INSTRUCTIONS not in span.attributes + assert OtelAttr.SYSTEM_INSTRUCTIONS not in span.attributes # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] @pytest.mark.parametrize("enable_sensitive_data", [True], indirect=True) @@ -422,8 +427,8 @@ async def test_chat_client_observability_with_list_instructions( span = spans[0] # Verify system_instructions attribute contains both instructions - assert OtelAttr.SYSTEM_INSTRUCTIONS in span.attributes - system_instructions = json.loads(span.attributes[OtelAttr.SYSTEM_INSTRUCTIONS]) + assert OtelAttr.SYSTEM_INSTRUCTIONS in span.attributes # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] + system_instructions = json.loads(span.attributes[OtelAttr.SYSTEM_INSTRUCTIONS]) # type: ignore[arg-type, index] # pyrefly: ignore[bad-argument-type, unsupported-operation] # ty: ignore[invalid-argument-type, not-subscriptable] assert len(system_instructions) == 2 assert system_instructions[0]["content"] == "Instruction 1" assert system_instructions[1]["content"] == "Instruction 2" @@ -442,8 +447,8 @@ async def test_chat_client_without_model_observability(mock_chat_client, span_ex span = spans[0] assert span.name == "chat unknown" - assert span.attributes[OtelAttr.OPERATION.value] == OtelAttr.CHAT_COMPLETION_OPERATION - assert span.attributes[OtelAttr.REQUEST_MODEL] == "unknown" + assert span.attributes[OtelAttr.OPERATION.value] == OtelAttr.CHAT_COMPLETION_OPERATION # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.REQUEST_MODEL] == "unknown" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] async def test_chat_client_streaming_without_model_observability(mock_chat_client, span_exporter: InMemorySpanExporter): @@ -464,8 +469,8 @@ async def test_chat_client_streaming_without_model_observability(mock_chat_clien assert len(spans) == 1 span = spans[0] assert span.name == "chat unknown" - assert span.attributes[OtelAttr.OPERATION.value] == OtelAttr.CHAT_COMPLETION_OPERATION - assert span.attributes[OtelAttr.REQUEST_MODEL] == "unknown" + assert span.attributes[OtelAttr.OPERATION.value] == OtelAttr.CHAT_COMPLETION_OPERATION # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.REQUEST_MODEL] == "unknown" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] def test_prepend_user_agent_with_none_value(): @@ -515,7 +520,7 @@ async def _stream(): finalizer=AgentResponse.from_updates, ) - class MockChatClientAgent(AgentTelemetryLayer, _MockChatClientAgent): + class MockChatClientAgent(AgentTelemetryLayer, _MockChatClientAgent): # type: ignore[misc] # pyrefly: ignore[inconsistent-inheritance] pass return MockChatClientAgent @@ -527,7 +532,7 @@ async def test_agent_span_captures_response_telemetry_without_inner_chat_span( ): """Agent spans should retain response telemetry when no inner chat span owns it.""" - agent = mock_chat_agent() + agent = mock_chat_agent() # type: ignore[operator] # pyrefly: ignore[not-callable] # ty: ignore[call-non-callable] span_exporter.clear() response = await agent.run("Test message") @@ -536,16 +541,16 @@ async def test_agent_span_captures_response_telemetry_without_inner_chat_span( assert len(spans) == 1 span = spans[0] assert span.name == "invoke_agent test_agent" - assert span.attributes[OtelAttr.OPERATION.value] == OtelAttr.AGENT_INVOKE_OPERATION - assert span.attributes[OtelAttr.AGENT_ID] == "test_agent_id" - assert span.attributes[OtelAttr.AGENT_NAME] == "test_agent" - assert span.attributes[OtelAttr.AGENT_DESCRIPTION] == "Test agent description" - assert span.attributes[OtelAttr.REQUEST_MODEL] == "TestModel" - assert span.attributes[OtelAttr.RESPONSE_ID] == "test_response_id" - assert span.attributes[OtelAttr.INPUT_TOKENS] == 15 - assert span.attributes[OtelAttr.OUTPUT_TOKENS] == 25 + assert span.attributes[OtelAttr.OPERATION.value] == OtelAttr.AGENT_INVOKE_OPERATION # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.AGENT_ID] == "test_agent_id" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.AGENT_NAME] == "test_agent" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.AGENT_DESCRIPTION] == "Test agent description" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.REQUEST_MODEL] == "TestModel" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.RESPONSE_ID] == "test_response_id" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.INPUT_TOKENS] == 15 # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.OUTPUT_TOKENS] == 25 # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] if enable_sensitive_data: - assert span.attributes[OtelAttr.OUTPUT_MESSAGES] is not None + assert span.attributes[OtelAttr.OUTPUT_MESSAGES] is not None # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] @pytest.mark.parametrize("enable_sensitive_data", [True, False], indirect=True) @@ -553,7 +558,7 @@ async def test_agent_streaming_response_with_diagnostics_enabled( mock_chat_agent: SupportsAgentRun, span_exporter: InMemorySpanExporter, enable_sensitive_data ): """Test agent streaming telemetry through the agent telemetry mixin.""" - agent = mock_chat_agent() + agent = mock_chat_agent() # type: ignore[operator] # pyrefly: ignore[not-callable] # ty: ignore[call-non-callable] span_exporter.clear() updates = [] stream = agent.run("Test message", stream=True) @@ -567,13 +572,13 @@ async def test_agent_streaming_response_with_diagnostics_enabled( assert len(spans) == 1 span = spans[0] assert span.name == "invoke_agent test_agent" - assert span.attributes[OtelAttr.OPERATION.value] == OtelAttr.AGENT_INVOKE_OPERATION - assert span.attributes[OtelAttr.AGENT_ID] == "test_agent_id" - assert span.attributes[OtelAttr.AGENT_NAME] == "test_agent" - assert span.attributes[OtelAttr.AGENT_DESCRIPTION] == "Test agent description" - assert span.attributes[OtelAttr.REQUEST_MODEL] == "TestModel" + assert span.attributes[OtelAttr.OPERATION.value] == OtelAttr.AGENT_INVOKE_OPERATION # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.AGENT_ID] == "test_agent_id" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.AGENT_NAME] == "test_agent" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.AGENT_DESCRIPTION] == "Test agent description" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.REQUEST_MODEL] == "TestModel" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] if enable_sensitive_data: - assert span.attributes.get(OtelAttr.OUTPUT_MESSAGES) is not None # Streaming, so no usage yet + assert span.attributes.get(OtelAttr.OUTPUT_MESSAGES) is not None # type: ignore[union-attr] # ty: ignore[unresolved-attribute] # Streaming, so no usage yet async def test_function_call_with_error_handling(span_exporter: InMemorySpanExporter): @@ -1008,7 +1013,7 @@ def test_create_otlp_exporters_grpc_missing_dependency(): ) def test_configure_otel_providers_with_views(monkeypatch): """Test configure_otel_providers accepts views parameter.""" - from opentelemetry.sdk.metrics import View + from opentelemetry.sdk.metrics import View # type: ignore[attr-defined] # ty: ignore[unresolved-import] from opentelemetry.sdk.metrics.view import DropAggregation from agent_framework.observability import configure_otel_providers @@ -1023,7 +1028,7 @@ def test_configure_otel_providers_with_views(monkeypatch): monkeypatch.delenv(key, raising=False) # Create a view that drops all metrics - views = [View(instrument_name="*", aggregation=DropAggregation())] + views = [View(instrument_name="*", aggregation=DropAggregation())] # pyrefly: ignore[not-callable] # Should not raise an error configure_otel_providers(views=views) @@ -1780,9 +1785,9 @@ def test_to_otel_part_data(): content = Content.from_data(data=data, media_type="application/octet-stream") result = _to_otel_part(content) - assert result["type"] == "blob" - assert result["mime_type"] == "application/octet-stream" - assert result["modality"] == "application" + assert result["type"] == "blob" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert result["mime_type"] == "application/octet-stream" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert result["modality"] == "application" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] def test_to_otel_part_function_call(): @@ -1868,8 +1873,8 @@ def test_to_otel_part_function_result(): content = Content(type="function_result", call_id="call_123", result="Success") result = _to_otel_part(content) - assert result["type"] == "tool_call_response" - assert result["id"] == "call_123" + assert result["type"] == "tool_call_response" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert result["id"] == "call_123" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] # region Test workflow observability functions @@ -1894,11 +1899,11 @@ def test_create_workflow_span(span_exporter): """Test create_workflow_span creates a span.""" from agent_framework.observability import create_workflow_span - span_exporter.clear() + span_exporter.clear() # type: ignore[attr-defined] with create_workflow_span("test_workflow", attributes={"key": "value"}): pass - spans = span_exporter.get_finished_spans() + spans = span_exporter.get_finished_spans() # type: ignore[attr-defined] assert len(spans) == 1 assert spans[0].name == "test_workflow" assert spans[0].attributes["key"] == "value" @@ -1908,7 +1913,7 @@ def test_create_processing_span(span_exporter): """Test create_processing_span creates a span with correct attributes.""" from agent_framework.observability import OtelAttr, create_processing_span - span_exporter.clear() + span_exporter.clear() # type: ignore[attr-defined] with create_processing_span( executor_id="exec_1", executor_type="TestExecutor", @@ -1917,7 +1922,7 @@ def test_create_processing_span(span_exporter): ): pass - spans = span_exporter.get_finished_spans() + spans = span_exporter.get_finished_spans() # type: ignore[attr-defined] assert len(spans) == 1 assert OtelAttr.EXECUTOR_PROCESS_SPAN in spans[0].name assert spans[0].attributes[OtelAttr.EXECUTOR_ID] == "exec_1" @@ -1928,7 +1933,7 @@ def test_create_edge_group_processing_span(span_exporter): """Test create_edge_group_processing_span creates correct span.""" from agent_framework.observability import OtelAttr, create_edge_group_processing_span - span_exporter.clear() + span_exporter.clear() # type: ignore[attr-defined] with create_edge_group_processing_span( edge_group_type="ConditionalEdge", edge_group_id="edge_1", @@ -1937,7 +1942,7 @@ def test_create_edge_group_processing_span(span_exporter): ): pass - spans = span_exporter.get_finished_spans() + spans = span_exporter.get_finished_spans() # type: ignore[attr-defined] assert len(spans) == 1 assert OtelAttr.EDGE_GROUP_PROCESS_SPAN in spans[0].name assert spans[0].attributes[OtelAttr.EDGE_GROUP_TYPE] == "ConditionalEdge" @@ -1950,7 +1955,7 @@ def test_create_edge_group_processing_span_invalid_link(span_exporter): """Test create_edge_group_processing_span handles invalid trace context gracefully.""" from agent_framework.observability import create_edge_group_processing_span - span_exporter.clear() + span_exporter.clear() # type: ignore[attr-defined] # Invalid trace context should be handled gracefully trace_contexts = [{"traceparent": "invalid-format"}] span_ids = ["invalid"] @@ -1962,7 +1967,7 @@ def test_create_edge_group_processing_span_invalid_link(span_exporter): ): pass - spans = span_exporter.get_finished_spans() + spans = span_exporter.get_finished_spans() # type: ignore[attr-defined] assert len(spans) == 1 # Should still create the span @@ -2092,7 +2097,7 @@ def test_get_response_attributes_with_response_id(): response.raw_representation = None response.usage_details = None - attrs = {} + attrs = {} # type: ignore[var-annotated] result = _get_response_attributes(attrs, response) assert result[OtelAttr.RESPONSE_ID] == "resp_123" @@ -2110,7 +2115,7 @@ def test_get_response_attributes_with_finish_reason(): response.raw_representation = None response.usage_details = None - attrs = {} + attrs = {} # type: ignore[var-annotated] result = _get_response_attributes(attrs, response) assert OtelAttr.FINISH_REASONS in result @@ -2129,7 +2134,7 @@ def test_get_response_attributes_with_model(): response.usage_details = None response.model = "gpt-4" - attrs = {} + attrs = {} # type: ignore[var-annotated] result = _get_response_attributes(attrs, response) assert result[OtelAttr.RESPONSE_MODEL] == "gpt-4" @@ -2147,7 +2152,7 @@ def test_get_response_attributes_with_usage(): response.raw_representation = None response.usage_details = {"input_token_count": 100, "output_token_count": 50} - attrs = {} + attrs = {} # type: ignore[var-annotated] result = _get_response_attributes(attrs, response) assert result[OtelAttr.INPUT_TOKENS] == 100 @@ -2166,7 +2171,7 @@ def test_get_response_attributes_capture_usage_false(): response.raw_representation = None response.usage_details = {"input_token_count": 100, "output_token_count": 50} - attrs = {} + attrs = {} # type: ignore[var-annotated] result = _get_response_attributes(attrs, response, capture_usage=False) assert OtelAttr.INPUT_TOKENS not in result @@ -2185,7 +2190,7 @@ def test_get_response_attributes_capture_response_id_false(): response.raw_representation = None response.usage_details = None - attrs = {} + attrs = {} # type: ignore[var-annotated] result = _get_response_attributes(attrs, response, capture_response_id=False) assert OtelAttr.RESPONSE_ID not in result @@ -2259,7 +2264,7 @@ def test_to_otel_part_generic(): from agent_framework.observability import _to_otel_part # Create a content with type that falls to default case - content = Content(type="annotations", text="some text") + content = Content(type="annotations", text="some text") # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] result = _to_otel_part(content) # Should return result from to_dict @@ -2285,7 +2290,7 @@ def test_get_response_attributes_finish_reason_from_raw(): response.raw_representation = raw_rep response.usage_details = None - attrs = {} + attrs = {} # type: ignore[var-annotated] result = _get_response_attributes(attrs, response) assert OtelAttr.FINISH_REASONS in result @@ -2305,7 +2310,7 @@ def __init__(self): self._id = "test_agent" self._name = "Test Agent" self._description = "A test agent" - self._default_options = {} + self._default_options = {} # type: ignore[var-annotated] @property def id(self): @@ -2349,7 +2354,7 @@ async def _run_stream( yield AgentResponseUpdate(contents=[Content.from_text("Test")], role="assistant") - class MockAgent(AgentTelemetryLayer, _MockAgent): + class MockAgent(AgentTelemetryLayer, _MockAgent): # type: ignore[misc] # pyrefly: ignore[inconsistent-inheritance] pass agent = MockAgent() @@ -2373,7 +2378,7 @@ def __init__(self): self._id = "failing_agent" self._name = "Failing Agent" self._description = "An agent that fails" - self._default_options = {} + self._default_options = {} # type: ignore[var-annotated] @property def id(self): @@ -2394,7 +2399,7 @@ def default_options(self): async def run(self, messages=None, *, stream: bool = False, session=None, **kwargs): raise RuntimeError("Agent failed") - class FailingAgent(AgentTelemetryLayer, _FailingAgent): + class FailingAgent(AgentTelemetryLayer, _FailingAgent): # type: ignore[misc] # pyrefly: ignore[inconsistent-inheritance] pass agent = FailingAgent() @@ -2423,7 +2428,7 @@ def __init__(self): self._id = "streaming_agent" self._name = "Streaming Agent" self._description = "A streaming test agent" - self._default_options = {} + self._default_options = {} # type: ignore[var-annotated] @property def id(self): @@ -2459,7 +2464,7 @@ async def _stream(): finalizer=AgentResponse.from_updates, ) - class StreamingAgent(AgentTelemetryLayer, _StreamingAgent): + class StreamingAgent(AgentTelemetryLayer, _StreamingAgent): # type: ignore[misc] # pyrefly: ignore[inconsistent-inheritance] pass agent = StreamingAgent() @@ -2549,7 +2554,7 @@ async def _inner_get_response(self, *, messages, options, **kwargs): span = spans[0] # Check output messages include finish_reason - output_messages = json.loads(span.attributes[OtelAttr.OUTPUT_MESSAGES]) + output_messages = json.loads(span.attributes[OtelAttr.OUTPUT_MESSAGES]) # type: ignore[arg-type, index] # pyrefly: ignore[bad-argument-type, unsupported-operation] # ty: ignore[invalid-argument-type, not-subscriptable] assert output_messages[-1].get("finish_reason") == "stop" @@ -2568,7 +2573,7 @@ def __init__(self): self._id = "failing_stream" self._name = "Failing Stream" self._description = "A failing streaming agent" - self._default_options = {} + self._default_options = {} # type: ignore[var-annotated] @property def id(self): @@ -2604,7 +2609,7 @@ async def _stream(): finalizer=AgentResponse.from_updates, ) - class FailingStreamingAgent(AgentTelemetryLayer, _FailingStreamingAgent): + class FailingStreamingAgent(AgentTelemetryLayer, _FailingStreamingAgent): # type: ignore[misc] # pyrefly: ignore[inconsistent-inheritance] pass agent = FailingStreamingAgent() @@ -2665,7 +2670,7 @@ def __init__(self): self._id = "test" self._name = "Test" self._description = "Test" - self._default_options = {} + self._default_options = {} # type: ignore[var-annotated] @property def id(self): @@ -2685,9 +2690,9 @@ def default_options(self): async def run(self, messages=None, *, stream: bool = False, session=None, **kwargs): if stream: - return ResponseStream( + return ResponseStream( # type: ignore[call-arg, misc] self._run_stream(messages=messages, **kwargs), - lambda x: AgentResponse.from_updates(x), + finalizer=lambda x: AgentResponse.from_updates(updates=x), ) return AgentResponse(messages=[]) @@ -2696,7 +2701,7 @@ async def _run_stream(self, messages=None, *, session=None, **kwargs): yield AgentResponseUpdate(contents=[Content.from_text("test")], role="assistant") - class TestAgent(AgentTelemetryLayer, _TestAgent): + class TestAgent(AgentTelemetryLayer, _TestAgent): # type: ignore[misc] # pyrefly: ignore[inconsistent-inheritance] pass agent = TestAgent() @@ -2720,7 +2725,7 @@ def __init__(self): self._id = "test" self._name = "Test" self._description = "Test" - self._default_options = {} + self._default_options = {} # type: ignore[var-annotated] @property def id(self): @@ -2749,7 +2754,7 @@ async def _run(self, messages=None, *, session=None, **kwargs): async def _run_stream(self, messages=None, *, session=None, **kwargs): yield AgentResponseUpdate(contents=[Content.from_text("test")], role="assistant") - class TestAgent(AgentTelemetryLayer, _TestAgent): + class TestAgent(AgentTelemetryLayer, _TestAgent): # type: ignore[misc] # pyrefly: ignore[inconsistent-inheritance] pass agent = TestAgent() @@ -2852,7 +2857,7 @@ def test_get_span_creates_span(span_exporter: InMemorySpanExporter): OtelAttr.TOOL_NAME: "test_tool", } - with _get_span(attributes=attributes, span_name_attribute=OtelAttr.TOOL_NAME): + with _get_span(attributes=attributes, span_name_attribute=OtelAttr.TOOL_NAME): # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] pass spans = span_exporter.get_finished_spans() @@ -2929,8 +2934,8 @@ def test_capture_response(span_exporter: InMemorySpanExporter): spans = span_exporter.get_finished_spans() assert len(spans) == 1 # Verify attributes were set on the span - assert spans[0].attributes.get(OtelAttr.INPUT_TOKENS) == 100 - assert spans[0].attributes.get(OtelAttr.OUTPUT_TOKENS) == 50 + assert spans[0].attributes.get(OtelAttr.INPUT_TOKENS) == 100 # type: ignore[union-attr] # ty: ignore[unresolved-attribute] + assert spans[0].attributes.get(OtelAttr.OUTPUT_TOKENS) == 50 # type: ignore[union-attr] # ty: ignore[unresolved-attribute] async def test_layer_ordering_span_sequence_with_function_calling(span_exporter: InMemorySpanExporter): @@ -2971,8 +2976,13 @@ def __init__(self): def service_url(self): return "https://test.example.com" - def _inner_get_response( - self, *, messages: MutableSequence[Message], stream: bool, options: dict[str, Any], **kwargs: Any + def _inner_get_response( # pyrefly: ignore[bad-override] + self, + *, + messages: Sequence[Message], + stream: bool, + options: Mapping[str, Any], + **kwargs: Any, # type: ignore[override] ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]: async def _get() -> ChatResponse: self.call_count += 1 @@ -3022,8 +3032,8 @@ async def _get() -> ChatResponse: assert sorted_spans[1].name.startswith("execute_tool"), ( f"Second span should be 'execute_tool', got '{sorted_spans[1].name}'" ) - assert sorted_spans[1].attributes.get(OtelAttr.TOOL_NAME) == "get_weather" - assert sorted_spans[1].attributes.get(OtelAttr.OPERATION.value) == OtelAttr.TOOL_EXECUTION_OPERATION + assert sorted_spans[1].attributes.get(OtelAttr.TOOL_NAME) == "get_weather" # type: ignore[union-attr] # ty: ignore[unresolved-attribute] + assert sorted_spans[1].attributes.get(OtelAttr.OPERATION.value) == OtelAttr.TOOL_EXECUTION_OPERATION # type: ignore[union-attr] # ty: ignore[unresolved-attribute] # Third span: second chat (LLM call with function result) assert sorted_spans[2].name.startswith("chat"), f"Third span should be 'chat', got '{sorted_spans[2].name}'" @@ -3039,8 +3049,13 @@ class NestedTelemetryChatClient(ChatTelemetryLayer, BaseChatClient[Any]): def service_url(self): return "https://test.example.com" - def _inner_get_response( - self, *, messages: MutableSequence[Message], stream: bool, options: dict[str, Any], **kwargs: Any + def _inner_get_response( # pyrefly: ignore[bad-override] + self, + *, + messages: Sequence[Message], + stream: bool, + options: Mapping[str, Any], + **kwargs: Any, # type: ignore[override] ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]: if stream: @@ -3069,11 +3084,11 @@ async def _get() -> ChatResponse: return _get() agent = Agent( - client=NestedTelemetryChatClient(), + client=NestedTelemetryChatClient(), # ty: ignore[invalid-argument-type] id="nested_agent_id", name="nested_agent", description="Nested telemetry agent", - default_options={"model": "NestedModel"}, + default_options={"model": "NestedModel"}, # pyrefly: ignore[bad-argument-type] ) span_exporter.clear() @@ -3091,18 +3106,18 @@ async def _get() -> ChatResponse: spans = span_exporter.get_finished_spans() assert len(spans) == 2 - span_by_operation = {span.attributes[OtelAttr.OPERATION.value]: span for span in spans} + span_by_operation = {span.attributes[OtelAttr.OPERATION.value]: span for span in spans} # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] agent_span = span_by_operation[OtelAttr.AGENT_INVOKE_OPERATION] chat_span = span_by_operation[OtelAttr.CHAT_COMPLETION_OPERATION] - assert chat_span.attributes[OtelAttr.RESPONSE_ID] == "nested_resp_123" - assert chat_span.attributes[OtelAttr.INPUT_TOKENS] == 11 - assert chat_span.attributes[OtelAttr.OUTPUT_TOKENS] == 22 + assert chat_span.attributes[OtelAttr.RESPONSE_ID] == "nested_resp_123" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert chat_span.attributes[OtelAttr.INPUT_TOKENS] == 11 # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert chat_span.attributes[OtelAttr.OUTPUT_TOKENS] == 22 # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] - assert OtelAttr.RESPONSE_ID not in agent_span.attributes + assert OtelAttr.RESPONSE_ID not in agent_span.attributes # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] # The agent span carries the aggregated usage from all inner chat completions - assert agent_span.attributes[OtelAttr.INPUT_TOKENS] == 11 - assert agent_span.attributes[OtelAttr.OUTPUT_TOKENS] == 22 + assert agent_span.attributes[OtelAttr.INPUT_TOKENS] == 11 # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert agent_span.attributes[OtelAttr.OUTPUT_TOKENS] == 22 # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] # region Test non-ASCII character handling in JSON serialization @@ -3134,20 +3149,20 @@ async def _inner_get_response(self, *, messages, options, **kwargs): span = spans[0] # Verify input messages preserve Japanese characters - input_messages_json = span.attributes[OtelAttr.INPUT_MESSAGES] - assert japanese_text in input_messages_json + input_messages_json = span.attributes[OtelAttr.INPUT_MESSAGES] # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert japanese_text in input_messages_json # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] # Ensure it's not escaped to Unicode - assert "\\u" not in input_messages_json + assert "\\u" not in input_messages_json # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] # Verify output messages preserve Japanese characters - output_messages_json = span.attributes[OtelAttr.OUTPUT_MESSAGES] - assert japanese_text in output_messages_json - assert "\\u" not in output_messages_json + output_messages_json = span.attributes[OtelAttr.OUTPUT_MESSAGES] # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert japanese_text in output_messages_json # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] + assert "\\u" not in output_messages_json # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] # Verify JSON is valid and contains the text - input_messages = json.loads(input_messages_json) + input_messages = json.loads(input_messages_json) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] assert input_messages[0]["parts"][0]["content"] == japanese_text - output_messages = json.loads(output_messages_json) + output_messages = json.loads(output_messages_json) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] assert output_messages[0]["parts"][0]["content"] == japanese_text @@ -3173,10 +3188,11 @@ async def test_system_instructions_preserves_non_ascii_characters(span_exporter: spans = span_exporter.get_finished_spans() assert len(spans) == 1 - span = spans[0] + span = spans[0] # type: ignore[assignment] # Verify system instructions preserve Chinese characters - system_instructions_json = span.attributes[OtelAttr.SYSTEM_INSTRUCTIONS] + system_instructions_json = span.attributes[OtelAttr.SYSTEM_INSTRUCTIONS] # type: ignore[attr-defined] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert isinstance(system_instructions_json, str) assert chinese_text in system_instructions_json assert "\\u" not in system_instructions_json @@ -3184,7 +3200,7 @@ async def test_system_instructions_preserves_non_ascii_characters(span_exporter: system_instructions = json.loads(system_instructions_json) assert system_instructions[0]["content"] == chinese_text - input_messages = json.loads(span.attributes[OtelAttr.INPUT_MESSAGES]) + input_messages = json.loads(span.attributes[OtelAttr.INPUT_MESSAGES]) # type: ignore[attr-defined] # pyrefly: ignore[bad-argument-type, unsupported-operation] # ty: ignore[invalid-argument-type, not-subscriptable] assert [msg.get("role") for msg in input_messages] == ["user"] @@ -3222,8 +3238,8 @@ class HandoffRequest: _capture_messages(span=span, provider_name="test_provider", messages=[msg]) spans = span_exporter.get_finished_spans() - span = spans[0] - input_messages = json.loads(span.attributes[OtelAttr.INPUT_MESSAGES]) + span = spans[0] # type: ignore[assignment] + input_messages = json.loads(span.attributes[OtelAttr.INPUT_MESSAGES]) # type: ignore[attr-defined] # pyrefly: ignore[bad-argument-type, unsupported-operation] # ty: ignore[invalid-argument-type, not-subscriptable] tool_part = input_messages[0]["parts"][0] assert tool_part["type"] == "tool_call" assert tool_part["arguments"]["data"] == {"target_agent": "helper", "reason": "overflow"} @@ -3253,7 +3269,7 @@ def test_capture_messages_keeps_framework_instructions_out_of_logs_and_span_mess spans = span_exporter.get_finished_spans() assert len(spans) == 1 - input_messages = json.loads(spans[0].attributes[OtelAttr.INPUT_MESSAGES]) + input_messages = json.loads(spans[0].attributes[OtelAttr.INPUT_MESSAGES]) # type: ignore[arg-type, index] # pyrefly: ignore[bad-argument-type, unsupported-operation] # ty: ignore[invalid-argument-type, not-subscriptable] assert [msg.get("role") for msg in input_messages] == ["user"] assert mock_logger_info.call_count == 1, f"Expected 1 log call, got {mock_logger_info.call_count}" @@ -3291,7 +3307,7 @@ def test_capture_messages_logs_only_chat_history_when_framework_instructions_are spans = span_exporter.get_finished_spans() assert len(spans) == 1 - input_messages = json.loads(spans[0].attributes[OtelAttr.INPUT_MESSAGES]) + input_messages = json.loads(spans[0].attributes[OtelAttr.INPUT_MESSAGES]) # type: ignore[arg-type, index] # pyrefly: ignore[bad-argument-type, unsupported-operation] # ty: ignore[invalid-argument-type, not-subscriptable] assert [msg.get("role") for msg in input_messages] == ["system", "user"] assert mock_logger_info.call_count == 2, f"Expected 2 log calls, got {mock_logger_info.call_count}" @@ -3321,12 +3337,12 @@ def greet(message: str) -> str: span = spans[0] # Verify tool arguments preserve Korean characters - tool_arguments_json = span.attributes[OtelAttr.TOOL_ARGUMENTS] - assert korean_text in tool_arguments_json - assert "\\u" not in tool_arguments_json + tool_arguments_json = span.attributes[OtelAttr.TOOL_ARGUMENTS] # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert korean_text in tool_arguments_json # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] + assert "\\u" not in tool_arguments_json # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] # Verify JSON is valid and contains the text - tool_arguments = json.loads(tool_arguments_json) + tool_arguments = json.loads(tool_arguments_json) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] assert tool_arguments["message"] == korean_text @@ -3350,8 +3366,8 @@ def echo(text: str) -> str: span = spans[0] # Verify tool result preserves Arabic characters - tool_result = span.attributes[OtelAttr.TOOL_RESULT] - assert arabic_text in tool_result + tool_result = span.attributes[OtelAttr.TOOL_RESULT] # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert arabic_text in tool_result # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] @pytest.mark.parametrize("enable_sensitive_data", [True], indirect=True) @@ -3379,19 +3395,19 @@ def greet_with_model(greeting: Greeting) -> str: span_exporter.clear() # Use the tool's input_model to properly pass the Pydantic model argument input_model = greet_with_model.input_model - await greet_with_model.invoke(arguments=input_model(greeting=Greeting(message=japanese_text))) + await greet_with_model.invoke(arguments=input_model(greeting=Greeting(message=japanese_text))) # type: ignore[misc, operator] # pyrefly: ignore[not-callable] # ty: ignore[call-non-callable] spans = span_exporter.get_finished_spans() assert len(spans) == 1 span = spans[0] # Verify tool arguments preserve Japanese characters - tool_arguments_json = span.attributes[OtelAttr.TOOL_ARGUMENTS] - assert japanese_text in tool_arguments_json - assert "\\u" not in tool_arguments_json + tool_arguments_json = span.attributes[OtelAttr.TOOL_ARGUMENTS] # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert japanese_text in tool_arguments_json # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] + assert "\\u" not in tool_arguments_json # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] # Verify JSON is valid and contains the text - tool_arguments = json.loads(tool_arguments_json) + tool_arguments = json.loads(tool_arguments_json) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] assert tool_arguments["greeting"]["message"] == japanese_text @@ -3418,12 +3434,12 @@ async def test_agent_instructions_from_default_options( span = spans[0] # Instructions from default_options should be captured - assert OtelAttr.SYSTEM_INSTRUCTIONS in span.attributes - system_instructions = json.loads(span.attributes[OtelAttr.SYSTEM_INSTRUCTIONS]) + assert OtelAttr.SYSTEM_INSTRUCTIONS in span.attributes # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] + system_instructions = json.loads(span.attributes[OtelAttr.SYSTEM_INSTRUCTIONS]) # type: ignore[arg-type, index] # pyrefly: ignore[bad-argument-type, unsupported-operation] # ty: ignore[invalid-argument-type, not-subscriptable] assert len(system_instructions) == 1 assert system_instructions[0]["content"] == "Default system instructions." - input_messages = json.loads(span.attributes[OtelAttr.INPUT_MESSAGES]) + input_messages = json.loads(span.attributes[OtelAttr.INPUT_MESSAGES]) # type: ignore[arg-type, index] # pyrefly: ignore[bad-argument-type, unsupported-operation] # ty: ignore[invalid-argument-type, not-subscriptable] assert [msg.get("role") for msg in input_messages] == ["user"] @@ -3449,10 +3465,10 @@ async def test_agent_instructions_preserve_system_messages_in_history( assert len(spans) == 1 span = spans[0] - system_instructions = json.loads(span.attributes[OtelAttr.SYSTEM_INSTRUCTIONS]) + system_instructions = json.loads(span.attributes[OtelAttr.SYSTEM_INSTRUCTIONS]) # type: ignore[arg-type, index] # pyrefly: ignore[bad-argument-type, unsupported-operation] # ty: ignore[invalid-argument-type, not-subscriptable] assert system_instructions == [{"type": "text", "content": "Default system instructions."}] - input_messages = json.loads(span.attributes[OtelAttr.INPUT_MESSAGES]) + input_messages = json.loads(span.attributes[OtelAttr.INPUT_MESSAGES]) # type: ignore[arg-type, index] # pyrefly: ignore[bad-argument-type, unsupported-operation] # ty: ignore[invalid-argument-type, not-subscriptable] assert [msg.get("role") for msg in input_messages] == ["system", "user"] assert input_messages[0]["parts"][0]["content"] == "Original system message" assert input_messages[1]["parts"][0]["content"] == "Test message" @@ -3477,8 +3493,8 @@ async def test_agent_instructions_from_options_override( assert len(spans) == 1 span = spans[0] - assert OtelAttr.SYSTEM_INSTRUCTIONS in span.attributes - system_instructions = json.loads(span.attributes[OtelAttr.SYSTEM_INSTRUCTIONS]) + assert OtelAttr.SYSTEM_INSTRUCTIONS in span.attributes # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] + system_instructions = json.loads(span.attributes[OtelAttr.SYSTEM_INSTRUCTIONS]) # type: ignore[arg-type, index] # pyrefly: ignore[bad-argument-type, unsupported-operation] # ty: ignore[invalid-argument-type, not-subscriptable] assert len(system_instructions) == 1 assert system_instructions[0]["content"] == "Override instructions." @@ -3503,8 +3519,8 @@ async def test_agent_instructions_merged_from_default_and_options( span = spans[0] # Merged instructions should contain both default and override, concatenated with newline - assert OtelAttr.SYSTEM_INSTRUCTIONS in span.attributes - system_instructions = json.loads(span.attributes[OtelAttr.SYSTEM_INSTRUCTIONS]) + assert OtelAttr.SYSTEM_INSTRUCTIONS in span.attributes # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] + system_instructions = json.loads(span.attributes[OtelAttr.SYSTEM_INSTRUCTIONS]) # type: ignore[arg-type, index] # pyrefly: ignore[bad-argument-type, unsupported-operation] # ty: ignore[invalid-argument-type, not-subscriptable] assert len(system_instructions) == 1 assert "Default instructions." in system_instructions[0]["content"] assert "Additional instructions." in system_instructions[0]["content"] @@ -3533,8 +3549,8 @@ async def test_agent_streaming_instructions_from_default_options( assert len(spans) == 1 span = spans[0] - assert OtelAttr.SYSTEM_INSTRUCTIONS in span.attributes - system_instructions = json.loads(span.attributes[OtelAttr.SYSTEM_INSTRUCTIONS]) + assert OtelAttr.SYSTEM_INSTRUCTIONS in span.attributes # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] + system_instructions = json.loads(span.attributes[OtelAttr.SYSTEM_INSTRUCTIONS]) # type: ignore[arg-type, index] # pyrefly: ignore[bad-argument-type, unsupported-operation] # ty: ignore[invalid-argument-type, not-subscriptable] assert len(system_instructions) == 1 assert system_instructions[0]["content"] == "Default streaming instructions." @@ -3562,8 +3578,8 @@ async def test_agent_streaming_instructions_merged_from_default_and_options( assert len(spans) == 1 span = spans[0] - assert OtelAttr.SYSTEM_INSTRUCTIONS in span.attributes - system_instructions = json.loads(span.attributes[OtelAttr.SYSTEM_INSTRUCTIONS]) + assert OtelAttr.SYSTEM_INSTRUCTIONS in span.attributes # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] + system_instructions = json.loads(span.attributes[OtelAttr.SYSTEM_INSTRUCTIONS]) # type: ignore[arg-type, index] # pyrefly: ignore[bad-argument-type, unsupported-operation] # ty: ignore[invalid-argument-type, not-subscriptable] assert len(system_instructions) == 1 assert "Default instructions." in system_instructions[0]["content"] assert "Stream override." in system_instructions[0]["content"] @@ -3586,7 +3602,7 @@ async def test_agent_no_instructions_in_default_or_options( assert len(spans) == 1 span = spans[0] - assert OtelAttr.SYSTEM_INSTRUCTIONS not in span.attributes + assert OtelAttr.SYSTEM_INSTRUCTIONS not in span.attributes # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] # region Additional coverage tests @@ -3656,7 +3672,7 @@ def test_capture_response_with_error_type(span_exporter: InMemorySpanExporter): spans = span_exporter.get_finished_spans() assert len(spans) == 1 - assert spans[0].attributes.get(OtelAttr.ERROR_TYPE) == "ValueError" + assert spans[0].attributes.get(OtelAttr.ERROR_TYPE) == "ValueError" # type: ignore[union-attr] # ty: ignore[unresolved-attribute] def test_backfill_request_model_when_unknown(span_exporter: InMemorySpanExporter): @@ -3731,8 +3747,13 @@ class BackfillingChatClient(ChatTelemetryLayer, BaseChatClient[Any]): def service_url(self): return "https://test.example.com" - def _inner_get_response( - self, *, messages: MutableSequence[Message], stream: bool, options: dict[str, Any], **kwargs: Any + def _inner_get_response( # pyrefly: ignore[bad-override] + self, + *, + messages: Sequence[Message], + stream: bool, + options: Mapping[str, Any], + **kwargs: Any, # type: ignore[override] ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]: async def _get() -> ChatResponse: return ChatResponse( @@ -3751,8 +3772,8 @@ async def _get() -> ChatResponse: assert len(spans) == 1 span = spans[0] assert span.name == "chat resolved-model" - assert span.attributes[OtelAttr.REQUEST_MODEL] == "resolved-model" - assert span.attributes[OtelAttr.RESPONSE_MODEL] == "resolved-model" + assert span.attributes[OtelAttr.REQUEST_MODEL] == "resolved-model" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.RESPONSE_MODEL] == "resolved-model" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] async def test_chat_client_streaming_backfills_request_model_from_response( @@ -3764,8 +3785,13 @@ class BackfillingStreamingChatClient(ChatTelemetryLayer, BaseChatClient[Any]): def service_url(self): return "https://test.example.com" - def _inner_get_response( - self, *, messages: MutableSequence[Message], stream: bool, options: dict[str, Any], **kwargs: Any + def _inner_get_response( # pyrefly: ignore[bad-override] + self, + *, + messages: Sequence[Message], + stream: bool, + options: Mapping[str, Any], + **kwargs: Any, # type: ignore[override] ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]: async def _stream() -> AsyncIterable[ChatResponseUpdate]: yield ChatResponseUpdate(contents=[Content.from_text("Hello")], role="assistant") @@ -3789,8 +3815,8 @@ def _finalize(updates: Sequence[ChatResponseUpdate]) -> ChatResponse: assert len(spans) == 1 span = spans[0] assert span.name == "chat resolved-stream-model" - assert span.attributes[OtelAttr.REQUEST_MODEL] == "resolved-stream-model" - assert span.attributes[OtelAttr.RESPONSE_MODEL] == "resolved-stream-model" + assert span.attributes[OtelAttr.REQUEST_MODEL] == "resolved-stream-model" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.RESPONSE_MODEL] == "resolved-stream-model" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] def test_configure_otel_providers_with_env_file_path(monkeypatch, tmp_path): @@ -3955,22 +3981,22 @@ class _InstrumentedAgent(AgentTelemetryLayer, RawAgent): spans = span_exporter.get_finished_spans() - invoke_spans = [s for s in spans if s.attributes.get(OtelAttr.OPERATION.value) == OtelAttr.AGENT_INVOKE_OPERATION] + invoke_spans = [s for s in spans if s.attributes.get(OtelAttr.OPERATION.value) == OtelAttr.AGENT_INVOKE_OPERATION] # type: ignore[union-attr] # ty: ignore[unresolved-attribute] assert len(invoke_spans) == 1 agent_span = invoke_spans[0] - chat_spans = [s for s in spans if s.attributes.get(OtelAttr.OPERATION.value) == OtelAttr.CHAT_COMPLETION_OPERATION] + chat_spans = [s for s in spans if s.attributes.get(OtelAttr.OPERATION.value) == OtelAttr.CHAT_COMPLETION_OPERATION] # type: ignore[union-attr] # ty: ignore[unresolved-attribute] assert len(chat_spans) == 2 # Individual chat spans retain their own usage - assert chat_spans[0].attributes.get(OtelAttr.INPUT_TOKENS) == 2239 - assert chat_spans[0].attributes.get(OtelAttr.OUTPUT_TOKENS) == 192 - assert chat_spans[1].attributes.get(OtelAttr.INPUT_TOKENS) == 2569 - assert chat_spans[1].attributes.get(OtelAttr.OUTPUT_TOKENS) == 99 + assert chat_spans[0].attributes.get(OtelAttr.INPUT_TOKENS) == 2239 # type: ignore[union-attr] # ty: ignore[unresolved-attribute] + assert chat_spans[0].attributes.get(OtelAttr.OUTPUT_TOKENS) == 192 # type: ignore[union-attr] # ty: ignore[unresolved-attribute] + assert chat_spans[1].attributes.get(OtelAttr.INPUT_TOKENS) == 2569 # type: ignore[union-attr] # ty: ignore[unresolved-attribute] + assert chat_spans[1].attributes.get(OtelAttr.OUTPUT_TOKENS) == 99 # type: ignore[union-attr] # ty: ignore[unresolved-attribute] # The invoke_agent span must report the aggregate across all LLM round-trips - assert agent_span.attributes.get(OtelAttr.INPUT_TOKENS) == 2239 + 2569 - assert agent_span.attributes.get(OtelAttr.OUTPUT_TOKENS) == 192 + 99 + assert agent_span.attributes.get(OtelAttr.INPUT_TOKENS) == 2239 + 2569 # type: ignore[union-attr] # ty: ignore[unresolved-attribute] + assert agent_span.attributes.get(OtelAttr.OUTPUT_TOKENS) == 192 + 99 # type: ignore[union-attr] # ty: ignore[unresolved-attribute] @pytest.mark.parametrize("enable_sensitive_data", [False], indirect=True) @@ -3995,11 +4021,11 @@ class _InstrumentedAgent(AgentTelemetryLayer, RawAgent): await agent.run(messages="Hi") spans = span_exporter.get_finished_spans() - invoke_spans = [s for s in spans if s.attributes.get(OtelAttr.OPERATION.value) == OtelAttr.AGENT_INVOKE_OPERATION] + invoke_spans = [s for s in spans if s.attributes.get(OtelAttr.OPERATION.value) == OtelAttr.AGENT_INVOKE_OPERATION] # type: ignore[union-attr] # ty: ignore[unresolved-attribute] assert len(invoke_spans) == 1 - assert invoke_spans[0].attributes.get(OtelAttr.INPUT_TOKENS) == 100 - assert invoke_spans[0].attributes.get(OtelAttr.OUTPUT_TOKENS) == 50 + assert invoke_spans[0].attributes.get(OtelAttr.INPUT_TOKENS) == 100 # type: ignore[union-attr] # ty: ignore[unresolved-attribute] + assert invoke_spans[0].attributes.get(OtelAttr.OUTPUT_TOKENS) == 50 # type: ignore[union-attr] # ty: ignore[unresolved-attribute] @pytest.mark.parametrize("enable_sensitive_data", [False], indirect=True) @@ -4042,13 +4068,13 @@ class _InstrumentedAgent(AgentTelemetryLayer, RawAgent): spans = span_exporter.get_finished_spans() - invoke_spans = [s for s in spans if s.attributes.get(OtelAttr.OPERATION.value) == OtelAttr.AGENT_INVOKE_OPERATION] + invoke_spans = [s for s in spans if s.attributes.get(OtelAttr.OPERATION.value) == OtelAttr.AGENT_INVOKE_OPERATION] # type: ignore[union-attr] # ty: ignore[unresolved-attribute] assert len(invoke_spans) == 1 agent_span = invoke_spans[0] # The invoke_agent span must aggregate usage from the in-loop call and the final exhaustion call - assert agent_span.attributes.get(OtelAttr.INPUT_TOKENS) == 500 - assert agent_span.attributes.get(OtelAttr.OUTPUT_TOKENS) == 100 + assert agent_span.attributes.get(OtelAttr.INPUT_TOKENS) == 500 # type: ignore[union-attr] # ty: ignore[unresolved-attribute] + assert agent_span.attributes.get(OtelAttr.OUTPUT_TOKENS) == 100 # type: ignore[union-attr] # ty: ignore[unresolved-attribute] # region Test span nesting (parent-child relationships) @@ -4062,8 +4088,13 @@ class NestedChatClient(ChatTelemetryLayer, BaseChatClient[Any]): def service_url(self): return "https://test.example.com" - def _inner_get_response( - self, *, messages: MutableSequence[Message], stream: bool, options: dict[str, Any], **kwargs: Any + def _inner_get_response( # pyrefly: ignore[bad-override] + self, + *, + messages: Sequence[Message], + stream: bool, + options: Mapping[str, Any], + **kwargs: Any, # type: ignore[override] ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]: if stream: @@ -4094,10 +4125,10 @@ async def _get() -> ChatResponse: return _get() agent = Agent( - client=NestedChatClient(), + client=NestedChatClient(), # ty: ignore[invalid-argument-type] id="nested_agent_id", name="nested_agent", - default_options={"model": "NestedModel"}, + default_options={"model": "NestedModel"}, # pyrefly: ignore[bad-argument-type] ) span_exporter.clear() @@ -4112,7 +4143,7 @@ async def _get() -> ChatResponse: spans = span_exporter.get_finished_spans() assert len(spans) == 2 - span_by_op = {s.attributes[OtelAttr.OPERATION.value]: s for s in spans} + span_by_op = {s.attributes[OtelAttr.OPERATION.value]: s for s in spans} # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] agent_span = span_by_op[OtelAttr.AGENT_INVOKE_OPERATION] chat_span = span_by_op[OtelAttr.CHAT_COMPLETION_OPERATION] @@ -4120,12 +4151,17 @@ async def _get() -> ChatResponse: assert agent_span.parent is None # Chat span's parent must be the agent span - assert chat_span.parent is not None - assert chat_span.parent.span_id == agent_span.context.span_id - assert chat_span.parent.trace_id == agent_span.context.trace_id + chat_parent = chat_span.parent + agent_context = agent_span.context + chat_context = chat_span.context + assert chat_parent is not None + assert agent_context is not None + assert chat_context is not None + assert chat_parent.span_id == agent_context.span_id + assert chat_parent.trace_id == agent_context.trace_id # Both spans must share the same trace - assert chat_span.context.trace_id == agent_span.context.trace_id + assert chat_context.trace_id == agent_context.trace_id @pytest.mark.parametrize("stream", [False, True]) @@ -4146,8 +4182,13 @@ def __init__(self) -> None: def service_url(self): return "https://test.example.com" - def _inner_get_response( - self, *, messages: MutableSequence[Message], stream: bool, options: dict[str, Any], **kwargs: Any + def _inner_get_response( # pyrefly: ignore[bad-override] + self, + *, + messages: Sequence[Message], + stream: bool, + options: Mapping[str, Any], + **kwargs: Any, # type: ignore[override] ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]: self.call_count += 1 is_first = self.call_count == 1 @@ -4202,10 +4243,10 @@ async def _get() -> ChatResponse: return _get() agent = Agent( - client=NestedToolChatClient(), + client=NestedToolChatClient(), # ty: ignore[invalid-argument-type] id="tool_agent_id", name="tool_agent", - default_options={"model": "ToolModel", "tools": [get_weather], "tool_choice": "auto"}, + default_options={"model": "ToolModel", "tools": [get_weather], "tool_choice": "auto"}, # pyrefly: ignore[bad-argument-type] ) span_exporter.clear() @@ -4219,9 +4260,9 @@ async def _get() -> ChatResponse: spans = span_exporter.get_finished_spans() - invoke_spans = [s for s in spans if s.attributes.get(OtelAttr.OPERATION.value) == OtelAttr.AGENT_INVOKE_OPERATION] - chat_spans = [s for s in spans if s.attributes.get(OtelAttr.OPERATION.value) == OtelAttr.CHAT_COMPLETION_OPERATION] - tool_spans = [s for s in spans if s.attributes.get(OtelAttr.OPERATION.value) == OtelAttr.TOOL_EXECUTION_OPERATION] + invoke_spans = [s for s in spans if s.attributes.get(OtelAttr.OPERATION.value) == OtelAttr.AGENT_INVOKE_OPERATION] # type: ignore[union-attr] # ty: ignore[unresolved-attribute] + chat_spans = [s for s in spans if s.attributes.get(OtelAttr.OPERATION.value) == OtelAttr.CHAT_COMPLETION_OPERATION] # type: ignore[union-attr] # ty: ignore[unresolved-attribute] + tool_spans = [s for s in spans if s.attributes.get(OtelAttr.OPERATION.value) == OtelAttr.TOOL_EXECUTION_OPERATION] # type: ignore[union-attr] # ty: ignore[unresolved-attribute] assert len(invoke_spans) == 1, f"Expected 1 invoke_agent span, got {len(invoke_spans)}" assert len(chat_spans) == 2, f"Expected 2 chat spans, got {len(chat_spans)}" @@ -4231,12 +4272,17 @@ async def _get() -> ChatResponse: assert agent_span.parent is None # All inner spans must be parented under the agent invoke span + agent_context = agent_span.context + assert agent_context is not None for inner in (*chat_spans, *tool_spans): - assert inner.parent is not None, f"Span {inner.name} has no parent" - assert inner.parent.span_id == agent_span.context.span_id, ( - f"Span {inner.name} parent={inner.parent.span_id} != agent={agent_span.context.span_id}" + inner_parent = inner.parent + inner_context = inner.context + assert inner_parent is not None, f"Span {inner.name} has no parent" + assert inner_context is not None + assert inner_parent.span_id == agent_context.span_id, ( + f"Span {inner.name} parent={inner_parent.span_id} != agent={agent_context.span_id}" ) - assert inner.context.trace_id == agent_span.context.trace_id + assert inner_context.trace_id == agent_context.trace_id @pytest.mark.parametrize("stream", [False, True]) @@ -4263,13 +4309,16 @@ async def test_chat_span_nested_under_explicit_outer_span( await client.get_response(messages=[Message(role="user", contents=["Test"])], options={"model": "Test"}) spans = span_exporter.get_finished_spans() - chat_spans = [s for s in spans if s.attributes.get(OtelAttr.OPERATION.value) == OtelAttr.CHAT_COMPLETION_OPERATION] + chat_spans = [s for s in spans if s.attributes.get(OtelAttr.OPERATION.value) == OtelAttr.CHAT_COMPLETION_OPERATION] # type: ignore[union-attr] # ty: ignore[unresolved-attribute] assert len(chat_spans) == 1 chat_span = chat_spans[0] - assert chat_span.parent is not None - assert chat_span.parent.span_id == outer_ctx.span_id - assert chat_span.context.trace_id == outer_ctx.trace_id + chat_parent = chat_span.parent + chat_context = chat_span.context + assert chat_parent is not None + assert chat_context is not None + assert chat_parent.span_id == outer_ctx.span_id + assert chat_context.trace_id == outer_ctx.trace_id @pytest.mark.parametrize("stream", [False, True]) @@ -4288,8 +4337,13 @@ class HttpEmittingClient(ChatTelemetryLayer, BaseChatClient[Any]): def service_url(self): return "https://test.example.com" - def _inner_get_response( - self, *, messages: MutableSequence[Message], stream: bool, options: dict[str, Any], **kwargs: Any + def _inner_get_response( # pyrefly: ignore[bad-override] + self, + *, + messages: Sequence[Message], + stream: bool, + options: Mapping[str, Any], + **kwargs: Any, # type: ignore[override] ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]: if stream: @@ -4328,7 +4382,7 @@ async def _get() -> ChatResponse: await client.get_response(messages=[Message(role="user", contents=["Test"])], options={"model": "Test"}) spans = span_exporter.get_finished_spans() - chat_spans = [s for s in spans if s.attributes.get(OtelAttr.OPERATION.value) == OtelAttr.CHAT_COMPLETION_OPERATION] + chat_spans = [s for s in spans if s.attributes.get(OtelAttr.OPERATION.value) == OtelAttr.CHAT_COMPLETION_OPERATION] # type: ignore[union-attr] # ty: ignore[unresolved-attribute] http_spans = [s for s in spans if s.name == "HTTP POST"] assert len(chat_spans) == 1 assert len(http_spans) == 1 @@ -4336,9 +4390,14 @@ async def _get() -> ChatResponse: chat_span = chat_spans[0] http_span = http_spans[0] - assert http_span.parent is not None - assert http_span.parent.span_id == chat_span.context.span_id - assert http_span.context.trace_id == chat_span.context.trace_id + http_parent = http_span.parent + http_context = http_span.context + chat_context = chat_span.context + assert http_parent is not None + assert http_context is not None + assert chat_context is not None + assert http_parent.span_id == chat_context.span_id + assert http_context.trace_id == chat_context.trace_id # region Test ResponseStream.with_pull_context_manager @@ -4451,8 +4510,13 @@ class FailingClient(ChatTelemetryLayer, BaseChatClient[Any]): def service_url(self): return "https://test.example.com" - def _inner_get_response( - self, *, messages: MutableSequence[Message], stream: bool, options: dict[str, Any], **kwargs: Any + def _inner_get_response( # pyrefly: ignore[bad-override] + self, + *, + messages: Sequence[Message], + stream: bool, + options: Mapping[str, Any], + **kwargs: Any, # type: ignore[override] ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]: raise RuntimeError("inner failed") @@ -4462,7 +4526,7 @@ def _inner_get_response( client.get_response(stream=True, messages=[Message(role="user", contents=["Test"])], options={"model": "Test"}) spans = span_exporter.get_finished_spans() - chat_spans = [s for s in spans if s.attributes.get(OtelAttr.OPERATION.value) == OtelAttr.CHAT_COMPLETION_OPERATION] + chat_spans = [s for s in spans if s.attributes.get(OtelAttr.OPERATION.value) == OtelAttr.CHAT_COMPLETION_OPERATION] # type: ignore[union-attr] # ty: ignore[unresolved-attribute] assert len(chat_spans) == 1 assert chat_spans[0].status.status_code == StatusCode.ERROR @@ -4508,14 +4572,14 @@ def run(self, messages=None, *, stream: bool = False, session=None, **kwargs): raise RuntimeError("execute failed") raise NotImplementedError - class FailingExecuteAgent(AgentTelemetryLayer, _FailingExecuteAgent): + class FailingExecuteAgent(AgentTelemetryLayer, _FailingExecuteAgent): # type: ignore[misc] pass # Sentinel values to detect that contextvars were reset to their pre-call state. sentinel_fields: set[str] = set() sentinel_usage: dict[str, Any] = {} fields_token = INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS.set(sentinel_fields) - usage_token = INNER_ACCUMULATED_USAGE.set(sentinel_usage) + usage_token = INNER_ACCUMULATED_USAGE.set(sentinel_usage) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] try: agent = FailingExecuteAgent() span_exporter.clear() @@ -4530,7 +4594,7 @@ class FailingExecuteAgent(AgentTelemetryLayer, _FailingExecuteAgent): INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS.reset(fields_token) spans = span_exporter.get_finished_spans() - agent_spans = [s for s in spans if s.attributes.get(OtelAttr.OPERATION.value) == OtelAttr.AGENT_INVOKE_OPERATION] + agent_spans = [s for s in spans if s.attributes.get(OtelAttr.OPERATION.value) == OtelAttr.AGENT_INVOKE_OPERATION] # type: ignore[union-attr] # ty: ignore[unresolved-attribute] assert len(agent_spans) == 1 assert agent_spans[0].status.status_code == StatusCode.ERROR diff --git a/python/packages/core/tests/core/test_optional_dependencies.py b/python/packages/core/tests/core/test_optional_dependencies.py index 7c424b3454c..0e6e8cc3429 100644 --- a/python/packages/core/tests/core/test_optional_dependencies.py +++ b/python/packages/core/tests/core/test_optional_dependencies.py @@ -128,7 +128,7 @@ def _import_without_mcp( monkeypatch.setattr(builtins, "__import__", _import_without_mcp) - agent = Agent(client=client) + agent = Agent(client=client) # type: ignore[arg-type] with pytest.raises(ModuleNotFoundError, match=r"Please install `mcp`\.$"): agent.as_mcp_server() diff --git a/python/packages/core/tests/core/test_sessions.py b/python/packages/core/tests/core/test_sessions.py index 7c78dba2091..a15f4793ee4 100644 --- a/python/packages/core/tests/core/test_sessions.py +++ b/python/packages/core/tests/core/test_sessions.py @@ -174,7 +174,7 @@ def test_get_messages_filter_sources(self) -> None: ctx = SessionContext(input_messages=[]) ctx.extend_messages("a", [Message(role="user", contents=["a"])]) ctx.extend_messages("b", [Message(role="user", contents=["b"])]) - result = ctx.get_messages(sources=["a"]) + result = ctx.get_messages(sources=["a"]) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] assert len(result) == 1 assert result[0].text == "a" @@ -182,7 +182,7 @@ def test_get_messages_exclude_sources(self) -> None: ctx = SessionContext(input_messages=[]) ctx.extend_messages("a", [Message(role="user", contents=["a"])]) ctx.extend_messages("b", [Message(role="user", contents=["b"])]) - result = ctx.get_messages(exclude_sources=["a"]) + result = ctx.get_messages(exclude_sources=["a"]) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] assert len(result) == 1 assert result[0].text == "b" @@ -233,13 +233,13 @@ async def test_before_run_is_noop(self) -> None: session = AgentSession() ctx = SessionContext(input_messages=[]) # Should not raise - await provider.before_run(agent=None, session=session, context=ctx, state={}) # type: ignore[arg-type] + await provider.before_run(agent=None, session=session, context=ctx, state={}) # type: ignore[arg-type] # ty: ignore[invalid-argument-type] async def test_after_run_is_noop(self) -> None: provider = ContextProvider(source_id="test") session = AgentSession() ctx = SessionContext(input_messages=[]) - await provider.after_run(agent=None, session=session, context=ctx, state={}) # type: ignore[arg-type] + await provider.after_run(agent=None, session=session, context=ctx, state={}) # type: ignore[arg-type] # ty: ignore[invalid-argument-type] # --------------------------------------------------------------------------- @@ -289,7 +289,7 @@ async def test_before_run_loads_messages(self) -> None: provider = ConcreteHistoryProvider("mem", stored_messages=msgs) session = AgentSession() ctx = SessionContext(session_id="s1", input_messages=[]) - await provider.before_run(agent=None, session=session, context=ctx, state={}) # type: ignore[arg-type] + await provider.before_run(agent=None, session=session, context=ctx, state={}) # type: ignore[arg-type] # ty: ignore[invalid-argument-type] assert len(ctx.context_messages["mem"]) == 1 assert ctx.context_messages["mem"][0].text == "history" @@ -302,7 +302,7 @@ async def test_after_run_stores_inputs_and_responses(self) -> None: resp_msg = Message(role="assistant", contents=["hi"]) ctx = SessionContext(session_id="s1", input_messages=[input_msg]) ctx._response = AgentResponse(messages=[resp_msg]) - await provider.after_run(agent=None, session=session, context=ctx, state={}) # type: ignore[arg-type] + await provider.after_run(agent=None, session=session, context=ctx, state={}) # type: ignore[arg-type] # ty: ignore[invalid-argument-type] assert len(provider.stored) == 2 assert provider.stored[0].text == "hello" assert provider.stored[1].text == "hi" @@ -352,7 +352,7 @@ async def test_after_run_stores_coalesced_code_interpreter_chunks(self) -> None: ctx = SessionContext(session_id="s1", input_messages=[Message(role="user", contents=["make a sheet"])]) ctx._response = AgentResponse.from_updates(updates) - await provider.after_run(agent=None, session=AgentSession(), context=ctx, state={}) # type: ignore[arg-type] + await provider.after_run(agent=None, session=AgentSession(), context=ctx, state={}) # type: ignore[arg-type] # ty: ignore[invalid-argument-type] assert len(provider.stored) == 1 stored_contents = provider.stored[0].contents @@ -370,7 +370,7 @@ async def test_after_run_skips_inputs_when_disabled(self) -> None: provider = ConcreteHistoryProvider("mem", store_inputs=False) ctx = SessionContext(session_id="s1", input_messages=[Message(role="user", contents=["hello"])]) ctx._response = AgentResponse(messages=[Message(role="assistant", contents=["hi"])]) - await provider.after_run(agent=None, session=AgentSession(), context=ctx, state={}) # type: ignore[arg-type] + await provider.after_run(agent=None, session=AgentSession(), context=ctx, state={}) # type: ignore[arg-type] # ty: ignore[invalid-argument-type] assert len(provider.stored) == 1 assert provider.stored[0].text == "hi" @@ -380,7 +380,7 @@ async def test_after_run_skips_responses_when_disabled(self) -> None: provider = ConcreteHistoryProvider("mem", store_outputs=False) ctx = SessionContext(session_id="s1", input_messages=[Message(role="user", contents=["hello"])]) ctx._response = AgentResponse(messages=[Message(role="assistant", contents=["hi"])]) - await provider.after_run(agent=None, session=AgentSession(), context=ctx, state={}) # type: ignore[arg-type] + await provider.after_run(agent=None, session=AgentSession(), context=ctx, state={}) # type: ignore[arg-type] # ty: ignore[invalid-argument-type] assert len(provider.stored) == 1 assert provider.stored[0].text == "hello" @@ -391,7 +391,7 @@ async def test_after_run_stores_context_messages(self) -> None: ctx = SessionContext(session_id="s1", input_messages=[Message(role="user", contents=["hello"])]) ctx.extend_messages("rag", [Message(role="system", contents=["context"])]) ctx._response = AgentResponse(messages=[Message(role="assistant", contents=["hi"])]) - await provider.after_run(agent=None, session=AgentSession(), context=ctx, state={}) # type: ignore[arg-type] + await provider.after_run(agent=None, session=AgentSession(), context=ctx, state={}) # type: ignore[arg-type] # ty: ignore[invalid-argument-type] # Should store: context from rag + input + response texts = [m.text for m in provider.stored] assert "context" in texts @@ -408,7 +408,7 @@ async def test_after_run_stores_context_from_specific_sources(self) -> None: ctx.extend_messages("rag", [Message(role="system", contents=["rag-context"])]) ctx.extend_messages("other", [Message(role="system", contents=["other-context"])]) ctx._response = AgentResponse(messages=[]) - await provider.after_run(agent=None, session=AgentSession(), context=ctx, state={}) # type: ignore[arg-type] + await provider.after_run(agent=None, session=AgentSession(), context=ctx, state={}) # type: ignore[arg-type] # ty: ignore[invalid-argument-type] texts = [m.text for m in provider.stored] assert "rag-context" in texts assert "other-context" not in texts @@ -483,7 +483,7 @@ async def test_empty_state_returns_no_messages(self) -> None: session = AgentSession() ctx = SessionContext(session_id="s1", input_messages=[]) await provider.before_run( # type: ignore[arg-type] - agent=None, + agent=None, # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}), @@ -501,14 +501,14 @@ async def test_stores_and_loads_messages(self) -> None: resp_msg = Message(role="assistant", contents=["hi there"]) ctx1 = SessionContext(session_id="s1", input_messages=[input_msg]) await provider.before_run( # type: ignore[arg-type] - agent=None, + agent=None, # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] session=session, context=ctx1, state=session.state.setdefault(provider.source_id, {}), ) ctx1._response = AgentResponse(messages=[resp_msg]) await provider.after_run( # type: ignore[arg-type] - agent=None, + agent=None, # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] session=session, context=ctx1, state=session.state.setdefault(provider.source_id, {}), @@ -517,7 +517,7 @@ async def test_stores_and_loads_messages(self) -> None: # Second run: should load previous messages ctx2 = SessionContext(session_id="s1", input_messages=[Message(role="user", contents=["again"])]) await provider.before_run( # type: ignore[arg-type] - agent=None, + agent=None, # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] session=session, context=ctx2, state=session.state.setdefault(provider.source_id, {}), @@ -536,14 +536,14 @@ async def test_state_is_serializable(self) -> None: input_msg = Message(role="user", contents=["test"]) ctx = SessionContext(session_id="s1", input_messages=[input_msg]) await provider.before_run( # type: ignore[arg-type] - agent=None, + agent=None, # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}), ) ctx._response = AgentResponse(messages=[Message(role="assistant", contents=["reply"])]) await provider.after_run( # type: ignore[arg-type] - agent=None, + agent=None, # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}), @@ -573,8 +573,8 @@ async def test_source_id_attribution(self) -> None: class TestFileHistoryProvider: def test_is_marked_experimental(self) -> None: - assert FileHistoryProvider.__feature_stage__ == "experimental" - assert FileHistoryProvider.__feature_id__ == ExperimentalFeature.FILE_HISTORY.value + assert FileHistoryProvider.__feature_stage__ == "experimental" # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert FileHistoryProvider.__feature_id__ == ExperimentalFeature.FILE_HISTORY.value # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert FileHistoryProvider.__doc__ is not None assert ".. warning:: Experimental" in FileHistoryProvider.__doc__ @@ -589,14 +589,14 @@ async def test_stores_and_loads_messages(self, tmp_path: Path) -> None: first_context = SessionContext(session_id=session.session_id, input_messages=[input_message]) await provider.before_run( # type: ignore[arg-type] - agent=None, + agent=None, # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] session=session, context=first_context, state={}, ) first_context._response = AgentResponse(messages=[response_message]) await provider.after_run( # type: ignore[arg-type] - agent=None, + agent=None, # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] session=session, context=first_context, state={}, @@ -615,7 +615,7 @@ async def test_stores_and_loads_messages(self, tmp_path: Path) -> None: session_id=session.session_id, input_messages=[Message(role="user", contents=["again"])] ) await provider.before_run( # type: ignore[arg-type] - agent=None, + agent=None, # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] session=session, context=second_context, state={}, @@ -742,7 +742,7 @@ class _TrackingFile: def __init__(self, wrapped: Any) -> None: self._wrapped = wrapped - def __enter__(self) -> "_TrackingFile": + def __enter__(self) -> "_TrackingFile": # type: ignore[name-defined] self._wrapped.__enter__() return self diff --git a/python/packages/core/tests/core/test_skills.py b/python/packages/core/tests/core/test_skills.py index 31f679c367f..12d636ef414 100644 --- a/python/packages/core/tests/core/test_skills.py +++ b/python/packages/core/tests/core/test_skills.py @@ -950,9 +950,9 @@ def test_feature_metadata_is_set(self) -> None: assert len(set(feature_ids)) == 1 assert getattr(SkillScriptRunner, "__feature_stage__", None) is None assert getattr(SkillScriptRunner, "__feature_id__", None) is None - assert SkillScript.parameters_schema.fget is not None - assert not hasattr(SkillScript.parameters_schema.fget, "__feature_stage__") - assert not hasattr(SkillScript.parameters_schema.fget, "__feature_id__") + assert SkillScript.parameters_schema.fget is not None # type: ignore[attr-defined] + assert not hasattr(SkillScript.parameters_schema.fget, "__feature_stage__") # type: ignore[attr-defined] + assert not hasattr(SkillScript.parameters_schema.fget, "__feature_id__") # type: ignore[attr-defined] class TestSkillResource: @@ -1447,7 +1447,7 @@ async def test_custom_resource_extensions(self, tmp_path: Path) -> None: provider = SkillsProvider.from_paths(str(tmp_path), resource_extensions=(".json",)) await _init_provider(provider) skill = _ctx(provider)[0]["my-skill"] - resource_names = [r.name for r in skill._resources] + resource_names = [r.name for r in skill._resources] # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert "references/data.json" in resource_names assert "references/notes.txt" not in resource_names @@ -1709,7 +1709,7 @@ async def test_custom_resource_directories(self, tmp_path: Path) -> None: source = FileSkillsSource(str(tmp_path), resource_directories=["docs"]) skills = await source.get_skills() - resource_names = [r.name for r in skills[0]._resources] + resource_names = [r.name for r in skills[0]._resources] # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert "docs/guide.md" in resource_names assert "references/ref.md" not in resource_names @@ -1728,7 +1728,7 @@ async def test_custom_script_directories(self, tmp_path: Path) -> None: source = FileSkillsSource(str(tmp_path), script_directories=["tools"]) skills = await source.get_skills() - script_names = [s.name for s in skills[0]._scripts] + script_names = [s.name for s in skills[0]._scripts] # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert "tools/run.py" in script_names async def test_root_indicator_discovers_root_files(self, tmp_path: Path) -> None: @@ -1743,7 +1743,7 @@ async def test_root_indicator_discovers_root_files(self, tmp_path: Path) -> None source = FileSkillsSource(str(tmp_path), resource_directories=[".", "references"]) skills = await source.get_skills() - resource_names = [r.name for r in skills[0]._resources] + resource_names = [r.name for r in skills[0]._resources] # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert "data.json" in resource_names async def test_from_paths_passes_directories(self, tmp_path: Path) -> None: @@ -1763,7 +1763,7 @@ async def test_from_paths_passes_directories(self, tmp_path: Path) -> None: ) await _init_provider(provider) skill = _ctx(provider)[0]["my-skill"] - resource_names = [r.name for r in skill._resources] + resource_names = [r.name for r in skill._resources] # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert "docs/guide.md" in resource_names @@ -3029,7 +3029,7 @@ def runner(skill: Skill, script: SkillScript, args: dict[str, Any] | None = None captured["args"] = args return "runner_result" - script = FileSkillScript(name="run.py", full_path=f"{_ABS}/test/run.py", runner=runner) + script = FileSkillScript(name="run.py", full_path=f"{_ABS}/test/run.py", runner=runner) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] skill = FileSkill( frontmatter=SkillFrontmatter(name="my-skill", description="d"), content="c", path=f"{_ABS}/test" ) @@ -3043,7 +3043,7 @@ async def test_run_file_based_with_async_runner(self) -> None: async def runner(skill: Skill, script: SkillScript, args: dict[str, Any] | None = None) -> str: return "async_runner" - script = FileSkillScript(name="run.py", full_path=f"{_ABS}/test/run.py", runner=runner) + script = FileSkillScript(name="run.py", full_path=f"{_ABS}/test/run.py", runner=runner) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] skill = FileSkill(frontmatter=SkillFrontmatter(name="s", description="d"), content="c", path=f"{_ABS}/test") result = await script.run(skill, args=None) assert result == "async_runner" @@ -3189,7 +3189,7 @@ async def my_runner(skill, script, args=None): script = FileSkillScript(name="my-script", full_path=f"{_ABS}/test/scripts/run.py") skill._scripts.append(script) - result = await my_runner(skill, script, args={"key": "val"}) + result = await my_runner(skill, script, args={"key": "val"}) # pyrefly: ignore[bad-argument-type] assert result == "executed" assert len(results) == 1 @@ -3207,7 +3207,7 @@ async def __call__(self, skill, script, args=None): script = InlineSkillScript(name="my-script", function=lambda: None) skill._scripts.append(script) - result = await runner(skill, script, args={"key": "val"}) + result = await runner(skill, script, args={"key": "val"}) # type: ignore[arg-type] assert result == "custom result" async def test_runner_returns_none(self) -> None: @@ -3243,7 +3243,7 @@ def my_runner(skill, script, args=None): script = FileSkillScript(name="my-script", full_path=f"{_ABS}/test/scripts/run.py") skill._scripts.append(script) - result = my_runner(skill, script, args={"key": "val"}) + result = my_runner(skill, script, args={"key": "val"}) # pyrefly: ignore[bad-argument-type] assert result == "executed" assert len(results) == 1 @@ -3261,7 +3261,7 @@ def __call__(self, skill, script, args=None): script = InlineSkillScript(name="my-script", function=lambda: None) skill._scripts.append(script) - result = runner(skill, script, args={"key": "val"}) + result = runner(skill, script, args={"key": "val"}) # type: ignore[arg-type] assert result == "sync result" def test_sync_runner_returns_none(self) -> None: @@ -3676,8 +3676,8 @@ async def test_instructions_include_script_runner_hints(self) -> None: provider = SkillsProvider([skill]) await _init_provider(provider) - assert "run_skill_script" in _ctx(provider)[1] - assert "not as top-level tool parameters" in _ctx(provider)[1] + assert "run_skill_script" in _ctx(provider)[1] # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] + assert "not as top-level tool parameters" in _ctx(provider)[1] # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] async def test_no_scripts_no_runner_no_script_instructions(self) -> None: skill = InlineSkill(frontmatter=SkillFrontmatter(name="my-skill", description="test"), instructions="body") @@ -3810,10 +3810,10 @@ async def test_discovered_script_has_absolute_full_path(self, tmp_path: Path) -> skills = await _discover_file_skills_for_test(str(tmp_path)) script = skills["my-skill"]._scripts[0] - assert script.full_path is not None - assert os.path.isabs(script.full_path) + assert script.full_path is not None # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert os.path.isabs(script.full_path) # type: ignore[attr-defined] # pyrefly: ignore[bad-argument-type] # ty: ignore[unresolved-attribute] expected = str(Path(str(skill_dir), "scripts", "generate.py")) - assert script.full_path == expected + assert script.full_path == expected # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] async def test_scripts_not_discovered_recursively(self, tmp_path: Path) -> None: """Scripts inside subdirectories of scripts/ are NOT discovered (non-recursive).""" @@ -3895,7 +3895,7 @@ async def test_custom_script_extensions_via_provider(self, tmp_path: Path) -> No ) await _init_provider(provider) skill = _ctx(provider)[0]["my-skill"] - script_names = [s.name for s in skill._scripts] + script_names = [s.name for s in skill._scripts] # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert "scripts/run.sh" in script_names assert "scripts/analyze.py" not in script_names @@ -3919,7 +3919,7 @@ async def test_multiple_script_extensions(self, tmp_path: Path) -> None: ) await _init_provider(provider) skill = _ctx(provider)[0]["my-skill"] - script_names = [s.name for s in skill._scripts] + script_names = [s.name for s in skill._scripts] # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert "scripts/analyze.py" in script_names assert "scripts/run.sh" in script_names assert "scripts/notes.txt" not in script_names @@ -4498,7 +4498,7 @@ def __init__(self) -> None: def instructions(self) -> str: return "x" - @ClassSkill.resource(name="oops") # wrong: should be below @property + @ClassSkill.resource(name="oops") # type: ignore[prop-decorator] # wrong: should be below @property @property def bad_prop(self) -> str: return "x" @@ -4515,7 +4515,7 @@ def __init__(self) -> None: def instructions(self) -> str: return "x" - @ClassSkill.script(name="oops") + @ClassSkill.script(name="oops") # type: ignore[prop-decorator] @property def bad_prop(self) -> str: return "x" @@ -5058,7 +5058,7 @@ async def test_file_skill_takes_precedence_over_code_skill(self, tmp_path: Path) result = await source.get_skills() skills_by_name = {s.frontmatter.name: s for s in result} assert "my-skill" in skills_by_name - assert skills_by_name["my-skill"].path is not None # file-based skill has path set + assert skills_by_name["my-skill"].path is not None # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] # file-based skill has path set # --------------------------------------------------------------------------- @@ -5082,7 +5082,7 @@ async def test_file_skills_source_discovers_skills(self, tmp_path: Path) -> None skills = await source.get_skills() assert len(skills) == 1 assert skills[0].frontmatter.name == "my-skill" - assert skills[0].path is not None + assert skills[0].path is not None # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] async def test_file_skills_source_with_extensions(self, tmp_path: Path) -> None: """FileSkillsSource resource_extensions controls extension filtering.""" @@ -5100,7 +5100,7 @@ async def test_file_skills_source_with_extensions(self, tmp_path: Path) -> None: source = FileSkillsSource(str(tmp_path), resource_extensions=(".json",)) skills = await source.get_skills() assert len(skills) == 1 - resource_names = [r.name for r in skills[0]._resources] + resource_names = [r.name for r in skills[0]._resources] # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert "references/data.json" in resource_names assert "references/data.csv" not in resource_names @@ -5568,12 +5568,12 @@ async def test_template_missing_skills_placeholder_raises(self) -> None: def test_string_source_rejected_with_helpful_error(self) -> None: """Passing a string (path) to SkillsProvider raises TypeError.""" with pytest.raises(TypeError, match="from_paths"): - SkillsProvider("./skills") # type: ignore[arg-type] + SkillsProvider("./skills") # type: ignore[arg-type] # ty: ignore[invalid-argument-type] def test_path_source_rejected_with_helpful_error(self) -> None: """Passing a Path to SkillsProvider raises TypeError.""" with pytest.raises(TypeError, match="from_paths"): - SkillsProvider(Path("./skills")) # type: ignore[arg-type] + SkillsProvider(Path("./skills")) # type: ignore[arg-type] # ty: ignore[invalid-argument-type] # --------------------------------------------------------------------------- diff --git a/python/packages/core/tests/core/test_tools.py b/python/packages/core/tests/core/test_tools.py index f44cbc267ac..023d293cdfa 100644 --- a/python/packages/core/tests/core/test_tools.py +++ b/python/packages/core/tests/core/test_tools.py @@ -550,12 +550,12 @@ def telemetry_test_tool(x: int, y: int) -> int: span = spans[0] assert OtelAttr.TOOL_EXECUTION_OPERATION.value in span.name assert "telemetry_test_tool" in span.name - assert span.attributes[OtelAttr.TOOL_NAME] == "telemetry_test_tool" - assert span.attributes[OtelAttr.TOOL_CALL_ID] == "test_call_id" - assert span.attributes[OtelAttr.TOOL_TYPE] == "function" - assert span.attributes[OtelAttr.TOOL_DESCRIPTION] == "A test tool for telemetry" - assert span.attributes[OtelAttr.TOOL_ARGUMENTS] == '{"x": 1, "y": 2}' - assert span.attributes[OtelAttr.TOOL_RESULT] == "3" + assert span.attributes[OtelAttr.TOOL_NAME] == "telemetry_test_tool" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.TOOL_CALL_ID] == "test_call_id" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.TOOL_TYPE] == "function" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.TOOL_DESCRIPTION] == "A test tool for telemetry" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.TOOL_ARGUMENTS] == '{"x": 1, "y": 2}' # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.TOOL_RESULT] == "3" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] # Verify histogram was called with correct attributes mock_histogram.record.assert_called_once() @@ -595,12 +595,12 @@ def telemetry_test_tool(x: int, y: int) -> int: span = spans[0] assert OtelAttr.TOOL_EXECUTION_OPERATION.value in span.name assert "telemetry_test_tool" in span.name - assert span.attributes[OtelAttr.TOOL_NAME] == "telemetry_test_tool" - assert span.attributes[OtelAttr.TOOL_CALL_ID] == "test_call_id" - assert span.attributes[OtelAttr.TOOL_TYPE] == "function" - assert span.attributes[OtelAttr.TOOL_DESCRIPTION] == "A test tool for telemetry" - assert OtelAttr.TOOL_ARGUMENTS not in span.attributes - assert OtelAttr.TOOL_RESULT not in span.attributes + assert span.attributes[OtelAttr.TOOL_NAME] == "telemetry_test_tool" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.TOOL_CALL_ID] == "test_call_id" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.TOOL_TYPE] == "function" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.TOOL_DESCRIPTION] == "A test tool for telemetry" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert OtelAttr.TOOL_ARGUMENTS not in span.attributes # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] + assert OtelAttr.TOOL_RESULT not in span.attributes # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] # Verify histogram was called with correct attributes mock_histogram.record.assert_called_once() @@ -619,7 +619,7 @@ async def simple_tool(message: str) -> str: """Echo tool.""" return message.upper() - args = simple_tool.input_model(message="hello world") + args = simple_tool.input_model(message="hello world") # type: ignore[misc, operator] # pyrefly: ignore[not-callable] # ty: ignore[call-non-callable] with pytest.raises(TypeError, match="Unexpected keyword argument"): await simple_tool.invoke( @@ -641,7 +641,7 @@ def pydantic_test_tool(x: int, y: int) -> int: return x + y # Create arguments as Pydantic model instance - args_model = pydantic_test_tool.input_model(x=5, y=10) + args_model = pydantic_test_tool.input_model(x=5, y=10) # type: ignore[misc, operator] # pyrefly: ignore[not-callable] # ty: ignore[call-non-callable] mock_histogram = Mock() pydantic_test_tool._invocation_duration_histogram = mock_histogram @@ -657,11 +657,11 @@ def pydantic_test_tool(x: int, y: int) -> int: span = spans[0] assert OtelAttr.TOOL_EXECUTION_OPERATION.value in span.name assert "pydantic_test_tool" in span.name - assert span.attributes[OtelAttr.TOOL_NAME] == "pydantic_test_tool" - assert span.attributes[OtelAttr.TOOL_CALL_ID] == "pydantic_call" - assert span.attributes[OtelAttr.TOOL_TYPE] == "function" - assert span.attributes[OtelAttr.TOOL_DESCRIPTION] == "A test tool with Pydantic args" - assert span.attributes[OtelAttr.TOOL_ARGUMENTS] == '{"x": 5, "y": 10}' + assert span.attributes[OtelAttr.TOOL_NAME] == "pydantic_test_tool" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.TOOL_CALL_ID] == "pydantic_call" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.TOOL_TYPE] == "function" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.TOOL_DESCRIPTION] == "A test tool with Pydantic args" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.TOOL_ARGUMENTS] == '{"x": 5, "y": 10}' # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] async def test_tool_invoke_telemetry_with_exception(span_exporter: InMemorySpanExporter): @@ -686,12 +686,12 @@ def exception_test_tool(x: int, y: int) -> int: span = spans[0] assert OtelAttr.TOOL_EXECUTION_OPERATION.value in span.name assert "exception_test_tool" in span.name - assert span.attributes[OtelAttr.TOOL_NAME] == "exception_test_tool" - assert span.attributes[OtelAttr.TOOL_CALL_ID] == "exception_call" - assert span.attributes[OtelAttr.TOOL_TYPE] == "function" - assert span.attributes[OtelAttr.TOOL_DESCRIPTION] == "A test tool that raises an exception" - assert span.attributes[OtelAttr.TOOL_ARGUMENTS] == '{"x": 1, "y": 2}' - assert span.attributes[OtelAttr.ERROR_TYPE] == ValueError.__name__ + assert span.attributes[OtelAttr.TOOL_NAME] == "exception_test_tool" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.TOOL_CALL_ID] == "exception_call" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.TOOL_TYPE] == "function" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.TOOL_DESCRIPTION] == "A test tool that raises an exception" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.TOOL_ARGUMENTS] == '{"x": 1, "y": 2}' # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.ERROR_TYPE] == ValueError.__name__ # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] assert span.status.status_code == trace.StatusCode.ERROR # Verify histogram was called with error attributes @@ -726,11 +726,11 @@ async def async_telemetry_test(x: int, y: int) -> int: span = spans[0] assert OtelAttr.TOOL_EXECUTION_OPERATION.value in span.name assert "async_telemetry_test" in span.name - assert span.attributes[OtelAttr.TOOL_NAME] == "async_telemetry_test" - assert span.attributes[OtelAttr.TOOL_CALL_ID] == "async_call" - assert span.attributes[OtelAttr.TOOL_TYPE] == "function" - assert span.attributes[OtelAttr.TOOL_DESCRIPTION] == "An async test tool for telemetry" - assert span.attributes[OtelAttr.TOOL_ARGUMENTS] == '{"x": 3, "y": 4}' + assert span.attributes[OtelAttr.TOOL_NAME] == "async_telemetry_test" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.TOOL_CALL_ID] == "async_call" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.TOOL_TYPE] == "function" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.TOOL_DESCRIPTION] == "An async test tool for telemetry" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.TOOL_ARGUMENTS] == '{"x": 3, "y": 4}' # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] # Verify histogram recording mock_histogram.record.assert_called_once() @@ -822,7 +822,7 @@ def test_parse_inputs_list_of_strings(): """Test _parse_inputs with list of strings.""" inputs = ["http://example.com", "https://test.org"] - result = _parse_inputs(inputs) + result = _parse_inputs(inputs) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] assert len(result) == 2 assert all(item.type == "uri" for item in result) @@ -900,7 +900,7 @@ def test_parse_inputs_mixed_list(): Content.from_text(text="Hello"), # Content instance ] - result = _parse_inputs(inputs) + result = _parse_inputs(inputs) # type: ignore[arg-type] assert len(result) == 4 assert result[0].type == "uri" @@ -925,7 +925,7 @@ def test_parse_inputs_unsupported_dict(): def test_parse_inputs_unsupported_type(): """Test _parse_inputs with unsupported input type.""" with pytest.raises(TypeError, match="Unsupported input type: int"): - _parse_inputs(123) + _parse_inputs(123) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] # endregion @@ -953,13 +953,13 @@ def tool_with_kwargs(x: int, **kwargs: Any) -> str: with pytest.raises(TypeError, match="Unexpected keyword argument"): await tool_with_kwargs.invoke( - arguments=tool_with_kwargs.input_model(x=5), + arguments=tool_with_kwargs.input_model(x=5), # type: ignore[misc, operator] # pyrefly: ignore[not-callable] # ty: ignore[call-non-callable] user_id="user2", ) # Verify invoke works without injected args (uses default) result_default = await tool_with_kwargs.invoke( - arguments=tool_with_kwargs.input_model(x=10), + arguments=tool_with_kwargs.input_model(x=10), # type: ignore[misc, operator] # pyrefly: ignore[not-callable] # ty: ignore[call-non-callable] ) assert isinstance(result_default, list) assert result_default[0].text == "x=10, user=unknown" @@ -983,7 +983,7 @@ def tool_with_context(x: int, ctx: FunctionInvocationContext) -> str: context = FunctionInvocationContext( function=tool_with_context, - arguments=tool_with_context.input_model(x=7), + arguments=tool_with_context.input_model(x=7), # type: ignore[misc, operator] # pyrefly: ignore[not-callable] # ty: ignore[call-non-callable] kwargs={"user_id": "ctx-user"}, ) @@ -1010,7 +1010,7 @@ def tool_with_runtime_context(x: int, runtime: FunctionInvocationContext) -> str context = FunctionInvocationContext( function=tool_with_runtime_context, - arguments=tool_with_runtime_context.input_model(x=8), + arguments=tool_with_runtime_context.input_model(x=8), # type: ignore[misc, operator] # pyrefly: ignore[not-callable] # ty: ignore[call-non-callable] kwargs={"user_id": "runtime-user"}, ) @@ -1132,7 +1132,9 @@ def test_parse_annotation_with_annotated_and_literal(): """Test that Annotated[Literal[...], description] works correctly.""" # When Literal is inside Annotated, it should still be preserved - annotated_literal = Annotated[Literal["A", "B", "C"], "The category"] + category_literal: Any = Literal["A", "B", "C"] + annotated: Any = Annotated + annotated_literal = annotated[category_literal, "The category"] result = _parse_annotation(annotated_literal) # The Annotated type should be preserved @@ -1468,9 +1470,9 @@ def returns_dict(x: int) -> dict[str, int]: spans = span_exporter.get_finished_spans() assert len(spans) == 1 span = spans[0] - assert span.attributes[OtelAttr.TOOL_NAME] == "raw_tool" - assert span.attributes[OtelAttr.TOOL_CALL_ID] == "raw_call" - assert span.attributes[OtelAttr.TOOL_RESULT] == "{'value': 5}" + assert span.attributes[OtelAttr.TOOL_NAME] == "raw_tool" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.TOOL_CALL_ID] == "raw_call" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert span.attributes[OtelAttr.TOOL_RESULT] == "{'value': 5}" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] async def test_invoke_default_path_records_parsed_telemetry( @@ -1492,7 +1494,7 @@ def returns_int() -> int: assert parsed[0].text == "parsed:7" spans = span_exporter.get_finished_spans() assert len(spans) == 1 - assert spans[0].attributes[OtelAttr.TOOL_RESULT] == "parsed:7" + assert spans[0].attributes[OtelAttr.TOOL_RESULT] == "parsed:7" # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] def test_skip_parsing_is_singleton() -> None: diff --git a/python/packages/core/tests/core/test_tools_future_annotations.py b/python/packages/core/tests/core/test_tools_future_annotations.py index 1c9649dcb9c..26e5ea65c17 100644 --- a/python/packages/core/tests/core/test_tools_future_annotations.py +++ b/python/packages/core/tests/core/test_tools_future_annotations.py @@ -127,7 +127,7 @@ def get_weather(location: str, ctx: FunctionInvocationContext) -> str: context = FunctionInvocationContext( function=get_weather, - arguments=get_weather.input_model(location="Seattle"), + arguments=get_weather.input_model(location="Seattle"), # type: ignore[misc, operator] # pyrefly: ignore[not-callable] # ty: ignore[call-non-callable] kwargs={"user": "test_user"}, ) result = await get_weather.invoke(context=context) diff --git a/python/packages/core/tests/core/test_types.py b/python/packages/core/tests/core/test_types.py index 62e88fab020..b71aedb0587 100644 --- a/python/packages/core/tests/core/test_types.py +++ b/python/packages/core/tests/core/test_types.py @@ -5,7 +5,7 @@ from collections.abc import AsyncIterable, Sequence from dataclasses import dataclass from datetime import datetime, timezone -from typing import Any, Literal +from typing import Any, Literal, cast import pytest from pydantic import BaseModel, Field, ValidationError @@ -125,8 +125,8 @@ def test_data_content_bytes(): # Check the type and content assert content.type == "data" assert content.uri == "data:application/octet-stream;base64,dGVzdA==" - assert content.media_type.startswith("application/") is True - assert content.media_type.startswith("image/") is False + assert content.media_type.startswith("application/") is True # type: ignore[union-attr] # ty: ignore[unresolved-attribute] + assert content.media_type.startswith("image/") is False # type: ignore[union-attr] # ty: ignore[unresolved-attribute] assert content.additional_properties["version"] == 1 # Ensure the instance is of type BaseContent @@ -212,7 +212,7 @@ def test_data_content_create_data_uri_from_base64(): """Test the create_data_uri_from_base64 class method.""" # Test with PNG data png_data = b"\x89PNG\r\n\x1a\n" + b"fake_data" - content = Content.from_data(png_data, media_type=detect_media_type_from_base64(data_bytes=png_data)) + content = Content.from_data(png_data, media_type=detect_media_type_from_base64(data_bytes=png_data)) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] assert content.uri == f"data:image/png;base64,{base64.b64encode(png_data).decode()}" assert content.media_type == "image/png" @@ -220,7 +220,7 @@ def test_data_content_create_data_uri_from_base64(): # Test with different format jpeg_data = b"\xff\xd8\xff\xe0" + b"fake_data" jpeg_base64 = base64.b64encode(jpeg_data).decode() - content = Content.from_data(jpeg_data, media_type=detect_media_type_from_base64(data_bytes=jpeg_data)) + content = Content.from_data(jpeg_data, media_type=detect_media_type_from_base64(data_bytes=jpeg_data)) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] assert content.uri == f"data:image/jpeg;base64,{jpeg_base64}" assert content.media_type == "image/jpeg" @@ -568,39 +568,39 @@ def test_usage_details(): def test_usage_details_addition(): - usage1 = UsageDetails( + usage1 = UsageDetails( # type: ignore[typeddict-unknown-key] input_token_count=5, output_token_count=10, total_token_count=15, - test1=10, - test2=20, + test1=10, # ty: ignore[invalid-key] + test2=20, # ty: ignore[invalid-key] ) - usage2 = UsageDetails( + usage2 = UsageDetails( # type: ignore[typeddict-unknown-key] input_token_count=3, output_token_count=6, total_token_count=9, - test1=10, - test3=30, + test1=10, # ty: ignore[invalid-key] + test3=30, # ty: ignore[invalid-key] ) combined_usage = add_usage_details(usage1, usage2) assert combined_usage["input_token_count"] == 8 assert combined_usage["output_token_count"] == 16 assert combined_usage["total_token_count"] == 24 - assert combined_usage["test1"] == 20 - assert combined_usage["test2"] == 20 - assert combined_usage["test3"] == 30 + assert combined_usage["test1"] == 20 # type: ignore[typeddict-item] # ty: ignore[invalid-key] + assert combined_usage["test2"] == 20 # type: ignore[typeddict-item] # ty: ignore[invalid-key] + assert combined_usage["test3"] == 30 # type: ignore[typeddict-item] # ty: ignore[invalid-key] def test_usage_details_fail(): # TypedDict doesn't validate types at runtime, so this test no longer applies # Creating UsageDetails with wrong types won't raise ValueError - usage = UsageDetails(input_token_count=5, output_token_count=10, total_token_count=15, wrong_type="42.923") - assert usage["wrong_type"] == "42.923" + usage = UsageDetails(input_token_count=5, output_token_count=10, total_token_count=15, wrong_type="42.923") # type: ignore[typeddict-item, typeddict-unknown-key] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-key] + assert usage["wrong_type"] == "42.923" # type: ignore[typeddict-item] # ty: ignore[invalid-key] def test_usage_details_additional_counts(): - usage = UsageDetails(input_token_count=5, output_token_count=10, total_token_count=15, **{"test": 1}) + usage = UsageDetails(input_token_count=5, output_token_count=10, total_token_count=15, **{"test": 1}) # type: ignore[call-arg, typeddict-unknown-key] # ty: ignore[invalid-key] assert usage.get("test") == 1 @@ -616,8 +616,8 @@ def test_usage_details_add_with_none_and_type_errors(): def test_usage_details_add_skips_non_int(): - u1 = UsageDetails(input_token_count=10, other="test") - u2 = UsageDetails(input_token_count=10, another="test") + u1 = UsageDetails(input_token_count=10, other="test") # type: ignore[typeddict-item, typeddict-unknown-key] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-key] + u2 = UsageDetails(input_token_count=10, another="test") # type: ignore[typeddict-item, typeddict-unknown-key] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-key] u3 = add_usage_details(u1, u2) assert len(u3.keys()) == 1 assert "input_token_count" in u3 @@ -656,9 +656,9 @@ def test_function_approval_serialization_roundtrip(): # Test that the basic properties match assert loaded.id == req.id assert loaded.additional_properties == req.additional_properties - assert loaded.function_call.call_id == req.function_call.call_id - assert loaded.function_call.name == req.function_call.name - assert loaded.function_call.arguments == req.function_call.arguments + assert loaded.function_call.call_id == req.function_call.call_id # type: ignore[union-attr] # ty: ignore[unresolved-attribute] + assert loaded.function_call.name == req.function_call.name # type: ignore[union-attr] # ty: ignore[unresolved-attribute] + assert loaded.function_call.arguments == req.function_call.arguments # type: ignore[union-attr] # ty: ignore[unresolved-attribute] # Skip the BaseModel validation test since we're no longer using Pydantic # The Content union will need to be handled differently when we fully migrate @@ -1152,9 +1152,9 @@ def test_chat_options_tool_choice_validation(): assert validate_tool_mode(None) is None with raises(ContentError): - validate_tool_mode("invalid_mode") + validate_tool_mode("invalid_mode") # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] with raises(ContentError): - validate_tool_mode({"mode": "invalid_mode"}) + validate_tool_mode({"mode": "invalid_mode"}) # type: ignore[arg-type, typeddict-item] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] with raises(ContentError): validate_tool_mode({"mode": "auto", "required_function_name": "should_not_be_here"}) @@ -1180,11 +1180,11 @@ def test_chat_options_tool_choice_validation(): # allowed_tools must be a non-string sequence of strings with raises(ContentError): - validate_tool_mode({"mode": "auto", "allowed_tools": "get_weather"}) + validate_tool_mode({"mode": "auto", "allowed_tools": "get_weather"}) # type: ignore[arg-type, typeddict-item] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] with raises(ContentError): - validate_tool_mode({"mode": "auto", "allowed_tools": 123}) + validate_tool_mode({"mode": "auto", "allowed_tools": 123}) # type: ignore[arg-type, typeddict-item] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] with raises(ContentError): - validate_tool_mode({"mode": "auto", "allowed_tools": ["get_weather", 123]}) + validate_tool_mode({"mode": "auto", "allowed_tools": ["get_weather", 123]}) # type: ignore[arg-type, list-item] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] # Empty list is valid (caller explicitly allows no tools) assert validate_tool_mode({"mode": "auto", "allowed_tools": []}) == { @@ -1193,7 +1193,7 @@ def test_chat_options_tool_choice_validation(): } # Tuple is normalized to list - result = validate_tool_mode({"mode": "auto", "allowed_tools": ("get_weather",)}) + result = validate_tool_mode({"mode": "auto", "allowed_tools": ("get_weather",)}) # type: ignore[arg-type, typeddict-item] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] assert result is not None assert result["allowed_tools"] == ["get_weather"] @@ -1210,7 +1210,7 @@ def test_chat_options_merge(tool_tool, ai_tool) -> None: assert options1 != options2 # Merge options - override takes precedence for non-collection fields - options3 = merge_chat_options(options1, options2) + options3 = merge_chat_options(options1, options2) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] assert options3.get("model") == "gpt-4.1" assert options3.get("tools") == [tool_tool, ai_tool] # tools are combined @@ -1225,7 +1225,7 @@ def test_chat_options_and_tool_choice_override() -> None: # Run-level specifies "required" run_options: ChatOptions = {"tool_choice": "required"} - merged = merge_chat_options(agent_options, run_options) + merged = merge_chat_options(agent_options, run_options) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] # Run-level should override agent-level assert merged.get("tool_choice") == "required" @@ -1237,7 +1237,7 @@ def test_chat_options_and_tool_choice_none_in_other_uses_self() -> None: agent_options: ChatOptions = {"tool_choice": "auto"} run_options: ChatOptions = {"model": "gpt-4.1"} # tool_choice is None - merged = merge_chat_options(agent_options, run_options) + merged = merge_chat_options(agent_options, run_options) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] # Should keep agent-level tool_choice since run-level is None assert merged.get("tool_choice") == "auto" @@ -1249,7 +1249,7 @@ def test_chat_options_and_tool_choice_with_tool_mode() -> None: agent_options: ChatOptions = {"tool_choice": "auto"} run_options: ChatOptions = {"tool_choice": "required"} - merged = merge_chat_options(agent_options, run_options) + merged = merge_chat_options(agent_options, run_options) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] assert merged.get("tool_choice") == "required" assert merged.get("tool_choice") == "required" @@ -1260,11 +1260,12 @@ def test_chat_options_and_tool_choice_required_specific_function() -> None: agent_options: ChatOptions = {"tool_choice": "auto"} run_options: ChatOptions = {"tool_choice": {"mode": "required", "required_function_name": "get_weather"}} - merged = merge_chat_options(agent_options, run_options) + merged = merge_chat_options(agent_options, run_options) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] tool_choice = merged.get("tool_choice") + assert isinstance(tool_choice, dict) assert tool_choice == {"mode": "required", "required_function_name": "get_weather"} - assert tool_choice["required_function_name"] == "get_weather" + assert tool_choice["required_function_name"] == "get_weather" # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] # region Agent Response Fixtures @@ -1381,7 +1382,8 @@ def test_agent_run_response_update_created_at() -> None: created_at=formatted_utc, ) assert update_with_now.created_at == formatted_utc - assert update_with_now.created_at.endswith("Z") + assert update_with_now.created_at is not None + assert update_with_now.created_at.endswith("Z") # ty: ignore[unresolved-attribute] def test_agent_run_response_created_at() -> None: @@ -1403,7 +1405,8 @@ def test_agent_run_response_created_at() -> None: created_at=formatted_utc, ) assert response_with_now.created_at == formatted_utc - assert response_with_now.created_at.endswith("Z") + assert response_with_now.created_at is not None + assert response_with_now.created_at.endswith("Z") # ty: ignore[unresolved-attribute] # region ErrorContent @@ -1702,7 +1705,7 @@ def test_text_reasoning_content_add_conflicting_ids_raises(): t2 = Content.from_text_reasoning(id="rs_xyz789", text=" part 2") with pytest.raises(AdditionItemMismatch, match="different ids"): - t1 + t2 + _ = t1 + t2 def test_text_reasoning_content_add_neither_has_id(): @@ -1755,9 +1758,9 @@ def test_comprehensive_to_dict_exclude_options(): assert "text" in text_dict_exclude # Test UsageDetails - it's a TypedDict now, not a class with to_dict - usage = UsageDetails(input_token_count=5, custom_count=10) + usage = UsageDetails(input_token_count=5, custom_count=10) # type: ignore[typeddict-unknown-key] # ty: ignore[invalid-key] assert usage["input_token_count"] == 5 - assert usage["custom_count"] == 10 + assert usage["custom_count"] == 10 # type: ignore[typeddict-item] # ty: ignore[invalid-key] # Test UsageDetails exclude_none behavior isn't applicable to TypedDict # TypedDict doesn't have a to_dict method @@ -1766,8 +1769,8 @@ def test_comprehensive_to_dict_exclude_options(): def test_usage_details_iadd_edge_cases(): """Test UsageDetails addition with edge cases for better coverage.""" # Test with None values - u1 = UsageDetails(input_token_count=None, output_token_count=5, custom1=10) - u2 = UsageDetails(input_token_count=3, output_token_count=None, custom2=20) + u1 = UsageDetails(input_token_count=None, output_token_count=5, custom1=10) # type: ignore[typeddict-unknown-key] # ty: ignore[invalid-key] + u2 = UsageDetails(input_token_count=3, output_token_count=None, custom2=20) # type: ignore[typeddict-unknown-key] # ty: ignore[invalid-key] result = add_usage_details(u1, u2) assert result["input_token_count"] == 3 @@ -1776,8 +1779,8 @@ def test_usage_details_iadd_edge_cases(): assert result.get("custom2") == 20 # Test merging additional counts - u3 = UsageDetails(input_token_count=1, shared_count=5) - u4 = UsageDetails(input_token_count=2, shared_count=15) + u3 = UsageDetails(input_token_count=1, shared_count=5) # type: ignore[typeddict-unknown-key] # ty: ignore[invalid-key] + u4 = UsageDetails(input_token_count=2, shared_count=15) # type: ignore[typeddict-unknown-key] # ty: ignore[invalid-key] result2 = add_usage_details(u3, u4) assert result2["input_token_count"] == 3 @@ -1812,7 +1815,7 @@ def test_text_content_add_type_error(): t1 = Content.from_text("Hello") with raises(TypeError, match="Incompatible type"): - t1 + "not a TextContent" + t1 + "not a TextContent" # type: ignore[operator] # pyrefly: ignore[unsupported-operation] # ty: ignore[unsupported-operator] def test_comprehensive_serialization_methods(): @@ -2007,10 +2010,10 @@ def test_usage_content_serialization_with_details(): "custom_count": 5, }, } - usage_content = Content(**usage_data) + usage_content = Content(**usage_data) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] assert isinstance(usage_content.usage_details, dict) assert usage_content.usage_details["input_token_count"] == 10 - assert usage_content.usage_details["custom_count"] == 5 # Custom fields go directly in UsageDetails + assert usage_content.usage_details["custom_count"] == 5 # type: ignore[typeddict-item] # ty: ignore[invalid-argument-type, invalid-key] # Custom fields go directly in UsageDetails # Test to_dict with UsageDetails object usage_dict = usage_content.to_dict() @@ -2034,8 +2037,8 @@ def test_function_approval_response_content_serialization(): }, } response_content = Content.from_dict(response_data) - assert response_content.function_call.type == "function_call" - assert response_content.function_call.call_id == "call123" + assert response_content.function_call.type == "function_call" # type: ignore[union-attr] # ty: ignore[unresolved-attribute] + assert response_content.function_call.call_id == "call123" # type: ignore[union-attr] # ty: ignore[unresolved-attribute] # Test to_dict with FunctionCallContent object response_dict = response_content.to_dict() @@ -2618,7 +2621,7 @@ def test_content_roundtrip_serialization(content_class: type[Content], init_kwar assert hasattr(reconstructed, "media_type") assert reconstructed.media_type == init_kwargs.get("media_type") # Verify the uri contains the encoded data - assert reconstructed.uri.startswith(f"data:{init_kwargs.get('media_type')};base64,") + assert reconstructed.uri.startswith(f"data:{init_kwargs.get('media_type')};base64,") # type: ignore[union-attr] # ty: ignore[unresolved-attribute] continue reconstructed_value = getattr(reconstructed, key) @@ -2634,16 +2637,16 @@ def test_content_roundtrip_serialization(content_class: type[Content], init_kwar # Compare each item by serializing the reconstructed object assert len(reconstructed_value) == len(value) for orig_dict, recon_obj in zip(value, reconstructed_value): - recon_dict = recon_obj.to_dict() + recon_dict = recon_obj.to_dict() # ty: ignore[unresolved-attribute] # Compare all keys from original dict (reconstructed may have extra default fields) - for k, v in orig_dict.items(): + for k, v in orig_dict.items(): # ty: ignore[unresolved-attribute] assert k in recon_dict, f"Key '{k}' missing from reconstructed dict" # For nested lists, recursively compare if isinstance(v, list) and v and isinstance(v[0], dict): assert len(recon_dict[k]) == len(v) for orig_item, recon_item in zip(v, recon_dict[k]): # Compare essential keys, ignoring fields like additional_properties - for item_key, item_val in orig_item.items(): + for item_key, item_val in orig_item.items(): # ty: ignore[unresolved-attribute] assert item_key in recon_item assert recon_item[item_key] == item_val else: @@ -2688,12 +2691,12 @@ def test_text_content_with_annotations_serialization(): reconstructed = Content.from_dict(content_dict) # Verify reconstruction - assert len(reconstructed.annotations) == 2 + assert len(reconstructed.annotations) == 2 # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] # Annotation are TypedDicts (dicts at runtime) - assert all(isinstance(ann, dict) for ann in reconstructed.annotations) - assert reconstructed.annotations[0]["title"] == "Citation 1" - assert reconstructed.annotations[1]["title"] == "Citation 2" - assert all(isinstance(ann["annotated_regions"][0], dict) for ann in reconstructed.annotations) + assert all(isinstance(ann, dict) for ann in reconstructed.annotations) # type: ignore[union-attr] # pyrefly: ignore[not-iterable] + assert reconstructed.annotations[0]["title"] == "Citation 1" # type: ignore[index] # pyrefly: ignore[unsupported-operation] + assert reconstructed.annotations[1]["title"] == "Citation 2" # type: ignore[index] # pyrefly: ignore[unsupported-operation] + assert all(isinstance(ann["annotated_regions"][0], dict) for ann in reconstructed.annotations) # type: ignore[union-attr] # pyrefly: ignore[not-iterable] # region FunctionTool.parse_result with Pydantic models @@ -2721,8 +2724,8 @@ def test_parse_result_pydantic_model(): assert isinstance(parsed, list) assert len(parsed) == 1 assert parsed[0].type == "text" - assert '"temperature": 22.5' in parsed[0].text or '"temperature":22.5' in parsed[0].text - assert '"condition": "sunny"' in parsed[0].text or '"condition":"sunny"' in parsed[0].text + assert '"temperature": 22.5' in parsed[0].text or '"temperature":22.5' in parsed[0].text # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] + assert '"condition": "sunny"' in parsed[0].text or '"condition":"sunny"' in parsed[0].text # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] def test_parse_result_pydantic_model_in_list(): @@ -2736,9 +2739,9 @@ def test_parse_result_pydantic_model_in_list(): assert isinstance(parsed, list) assert len(parsed) == 1 assert parsed[0].type == "text" - assert parsed[0].text.startswith("[") - assert "cloudy" in parsed[0].text - assert "sunny" in parsed[0].text + assert parsed[0].text.startswith("[") # type: ignore[union-attr] # ty: ignore[unresolved-attribute] + assert "cloudy" in parsed[0].text # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] + assert "sunny" in parsed[0].text # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] def test_parse_result_pydantic_model_in_dict(): @@ -2752,10 +2755,10 @@ def test_parse_result_pydantic_model_in_dict(): assert isinstance(parsed, list) assert len(parsed) == 1 assert parsed[0].type == "text" - assert "current" in parsed[0].text - assert "forecast" in parsed[0].text - assert "partly cloudy" in parsed[0].text - assert "sunny" in parsed[0].text + assert "current" in parsed[0].text # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] + assert "forecast" in parsed[0].text # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] + assert "partly cloudy" in parsed[0].text # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] + assert "sunny" in parsed[0].text # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] def test_parse_result_nested_pydantic_model(): @@ -2766,9 +2769,9 @@ def test_parse_result_nested_pydantic_model(): assert isinstance(parsed, list) assert len(parsed) == 1 assert parsed[0].type == "text" - assert "Seattle" in parsed[0].text - assert "rainy" in parsed[0].text - assert "18.0" in parsed[0].text or "18" in parsed[0].text + assert "Seattle" in parsed[0].text # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] + assert "rainy" in parsed[0].text # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] + assert "18.0" in parsed[0].text or "18" in parsed[0].text # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] # region FunctionTool.parse_result with MCP TextContent-like objects @@ -3006,8 +3009,8 @@ def test_content_add_usage_content(): result = usage1 + usage2 assert result.type == "usage" - assert result.usage_details["input_token_count"] == 300 - assert result.usage_details["output_token_count"] == 150 + assert result.usage_details["input_token_count"] == 300 # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] + assert result.usage_details["output_token_count"] == 150 # type: ignore[index] # pyrefly: ignore[unsupported-operation] # ty: ignore[not-subscriptable] # Raw representations should be combined assert isinstance(result.raw_representation, list) assert "raw1" in result.raw_representation @@ -3036,19 +3039,19 @@ def test_content_add_usage_content_non_integer_values(): """Test adding usage content with non-integer values.""" usage1 = Content( type="usage", - usage_details={"model": "gpt-4", "count": 10}, + usage_details={"model": "gpt-4", "count": 10}, # type: ignore[arg-type, typeddict-item] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type, invalid-key] ) usage2 = Content( type="usage", - usage_details={"model": "gpt-3.5", "count": 20}, + usage_details={"model": "gpt-3.5", "count": 20}, # type: ignore[arg-type, typeddict-item] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type, invalid-key] ) result = usage1 + usage2 # Non-integer "model" should take first non-None value - assert "model" not in result.usage_details + assert "model" not in result.usage_details # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] # Integer "count" should be summed - assert result.usage_details["count"] == 30 + assert result.usage_details["count"] == 30 # type: ignore[index, typeddict-item] # pyrefly: ignore[unsupported-operation] # ty: ignore[invalid-key, not-subscriptable] # endregion @@ -3062,7 +3065,7 @@ def test_content_has_top_level_media_type(): image = Content(type="uri", uri="https://example.com/image.png", media_type="image/png") assert image.has_top_level_media_type("image") is True - assert image.has_top_level_media_type("IMAGE") is True # Case insensitive + assert image.has_top_level_media_type("IMAGE") is True # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] # Case insensitive assert image.has_top_level_media_type("audio") is False @@ -3333,7 +3336,7 @@ def tracking_hook(response: ChatResponse) -> ChatResponse: stream = ResponseStream( _generate_updates(2), finalizer=_combine_updates, - result_hooks=[tracking_hook], + result_hooks=[tracking_hook], # ty: ignore[invalid-argument-type] ) async for _ in stream: @@ -3351,7 +3354,7 @@ def counting_finalizer(updates: list[ChatResponseUpdate]) -> ChatResponse: call_count["value"] += 1 return _combine_updates(updates) - stream = ResponseStream(_generate_updates(2), finalizer=counting_finalizer) + stream = ResponseStream(_generate_updates(2), finalizer=counting_finalizer) # type: ignore[arg-type, var-annotated] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] async for _ in stream: pass @@ -3376,7 +3379,7 @@ def counting_hook(update: ChatResponseUpdate) -> None: stream = ResponseStream( _generate_updates(3), finalizer=_combine_updates, - transform_hooks=[counting_hook], + transform_hooks=[counting_hook], # ty: ignore[invalid-argument-type] ) await stream.get_final_response() @@ -3389,18 +3392,18 @@ async def test_transform_hook_can_modify_update(self) -> None: def uppercase_hook(update: ChatResponseUpdate) -> ChatResponseUpdate: return ChatResponseUpdate( contents=[Content.from_text((update.text or "").upper())], - role=update.role, + role=cast(Any, update.role), ) stream = ResponseStream( _generate_updates(2), finalizer=_combine_updates, - transform_hooks=[uppercase_hook], + transform_hooks=[uppercase_hook], # ty: ignore[invalid-argument-type] ) collected: list[str] = [] async for update in stream: - collected.append(update.text or "") + collected.append(update.text or "") # ty: ignore[unresolved-attribute] assert collected == ["UPDATE_0", "UPDATE_1"] @@ -3419,7 +3422,7 @@ def hook_b(update: ChatResponseUpdate) -> ChatResponseUpdate: stream = ResponseStream( _generate_updates(2), finalizer=_combine_updates, - transform_hooks=[hook_a, hook_b], + transform_hooks=[hook_a, hook_b], # ty: ignore[invalid-argument-type] ) async for _ in stream: @@ -3436,12 +3439,12 @@ def none_hook(update: ChatResponseUpdate) -> None: stream = ResponseStream( _generate_updates(2), finalizer=_combine_updates, - transform_hooks=[none_hook], + transform_hooks=[none_hook], # ty: ignore[invalid-argument-type] ) collected: list[str] = [] async for update in stream: - collected.append(update.text or "") + collected.append(update.text or "") # ty: ignore[unresolved-attribute] assert collected == ["update_0", "update_1"] @@ -3466,18 +3469,18 @@ async def test_async_transform_hook(self) -> None: async def async_hook(update: ChatResponseUpdate) -> ChatResponseUpdate: return ChatResponseUpdate( contents=[Content.from_text(f"async_{update.text}")], - role=update.role, + role=cast(Any, update.role), ) stream = ResponseStream( _generate_updates(2), finalizer=_combine_updates, - transform_hooks=[async_hook], + transform_hooks=[async_hook], # ty: ignore[invalid-argument-type] ) collected: list[str] = [] async for update in stream: - collected.append(update.text or "") + collected.append(update.text or "") # ty: ignore[unresolved-attribute] assert collected == ["async_update_0", "async_update_1"] @@ -3589,12 +3592,12 @@ def add_metadata(response: ChatResponse) -> ChatResponse: stream = ResponseStream( _generate_updates(2), finalizer=_combine_updates, - result_hooks=[add_metadata], + result_hooks=[add_metadata], # ty: ignore[invalid-argument-type] ) final = await stream.get_final_response() - assert final.additional_properties["processed"] is True + assert final.additional_properties["processed"] is True # ty: ignore[unresolved-attribute] async def test_result_hook_can_transform_result(self) -> None: """Result hook can transform the final result.""" @@ -3605,12 +3608,12 @@ def wrap_text(response: ChatResponse) -> ChatResponse: stream = ResponseStream( _generate_updates(2), finalizer=_combine_updates, - result_hooks=[wrap_text], + result_hooks=[wrap_text], # ty: ignore[invalid-argument-type] ) final = await stream.get_final_response() - assert final.text == "[update_0update_1]" + assert final.text == "[update_0update_1]" # ty: ignore[unresolved-attribute] async def test_multiple_result_hooks_chained(self) -> None: """Multiple result hooks are called in order.""" @@ -3624,12 +3627,12 @@ def add_suffix(response: ChatResponse) -> ChatResponse: stream = ResponseStream( _generate_updates(1), finalizer=_combine_updates, - result_hooks=[add_prefix, add_suffix], + result_hooks=[add_prefix, add_suffix], # ty: ignore[invalid-argument-type] ) final = await stream.get_final_response() - assert final.text == "prefix_update_0_suffix" + assert final.text == "prefix_update_0_suffix" # ty: ignore[unresolved-attribute] async def test_result_hook_returning_none_keeps_previous(self) -> None: """Result hook returning None keeps the previous value.""" @@ -3642,7 +3645,7 @@ def none_hook(response: ChatResponse) -> None: stream = ResponseStream( _generate_updates(2), finalizer=_combine_updates, - result_hooks=[none_hook], + result_hooks=[none_hook], # ty: ignore[invalid-argument-type] ) final = await stream.get_final_response() @@ -3672,12 +3675,12 @@ async def async_hook(response: ChatResponse) -> ChatResponse: stream = ResponseStream( _generate_updates(2), finalizer=_combine_updates, - result_hooks=[async_hook], + result_hooks=[async_hook], # ty: ignore[invalid-argument-type] ) final = await stream.get_final_response() - assert final.text == "async_update_0update_1" + assert final.text == "async_update_0update_1" # ty: ignore[unresolved-attribute] class TestResponseStreamFinalizer: @@ -3691,7 +3694,7 @@ def capturing_finalizer(updates: list[ChatResponseUpdate]) -> ChatResponse: received_updates.extend(updates) return ChatResponse(messages=Message("assistant", ["done"])) - stream = ResponseStream(_generate_updates(3), finalizer=capturing_finalizer) + stream = ResponseStream(_generate_updates(3), finalizer=capturing_finalizer) # type: ignore[arg-type, var-annotated] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] await stream.get_final_response() @@ -3716,11 +3719,11 @@ async def async_finalizer(updates: list[ChatResponseUpdate]) -> ChatResponse: text = "".join(u.text or "" for u in updates) return ChatResponse(messages=Message("assistant", [f"async_{text}"])) - stream = ResponseStream(_generate_updates(2), finalizer=async_finalizer) + stream = ResponseStream(_generate_updates(2), finalizer=async_finalizer) # type: ignore[arg-type, var-annotated] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] final = await stream.get_final_response() - assert final.text == "async_update_0update_1" + assert final.text == "async_update_0update_1" # ty: ignore[unresolved-attribute] async def test_finalized_only_once(self) -> None: """Finalizer is only called once even with multiple get_final_response calls.""" @@ -3730,7 +3733,7 @@ def counting_finalizer(updates: list[ChatResponseUpdate]) -> ChatResponse: call_count["value"] += 1 return ChatResponse(messages=Message("assistant", ["done"])) - stream = ResponseStream(_generate_updates(2), finalizer=counting_finalizer) + stream = ResponseStream(_generate_updates(2), finalizer=counting_finalizer) # type: ignore[arg-type, var-annotated] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] await stream.get_final_response() await stream.get_final_response() @@ -3761,7 +3764,7 @@ async def test_map_transforms_updates(self) -> None: def add_prefix(update: ChatResponseUpdate) -> ChatResponseUpdate: return ChatResponseUpdate( contents=[Content.from_text(f"mapped_{update.text}")], - role=update.role, + role=cast(Any, update.role), ) outer = inner.map(add_prefix, _combine_updates) @@ -3793,7 +3796,7 @@ def inner_result_hook(response: ChatResponse) -> ChatResponse: inner = ResponseStream( _generate_updates(2), finalizer=_combine_updates, - result_hooks=[inner_result_hook], + result_hooks=[inner_result_hook], # ty: ignore[invalid-argument-type] ) outer = inner.map(lambda u: u, _combine_updates) @@ -3843,7 +3846,7 @@ async def test_map_with_finalizer(self) -> None: def add_prefix(update: ChatResponseUpdate) -> ChatResponseUpdate: return ChatResponseUpdate( contents=[Content.from_text(f"mapped_{update.text}")], - role=update.role, + role=cast(Any, update.role), ) outer = inner.map(add_prefix, _combine_updates) @@ -3873,7 +3876,7 @@ def outer_hook(update: ChatResponseUpdate) -> ChatResponseUpdate: inner = ResponseStream( _generate_updates(2), finalizer=_combine_updates, - transform_hooks=[inner_hook], + transform_hooks=[inner_hook], # ty: ignore[invalid-argument-type] ) outer = inner.map(lambda u: u, _combine_updates).with_transform_hook(outer_hook) @@ -3908,7 +3911,7 @@ async def test_async_map_transform(self) -> None: async def async_map(update: ChatResponseUpdate) -> ChatResponseUpdate: return ChatResponseUpdate( contents=[Content.from_text(f"async_{update.text}")], - role=update.role, + role=cast(Any, update.role), ) outer = inner.map(async_map, _combine_updates) @@ -3959,12 +3962,12 @@ def result_hook(response: ChatResponse) -> ChatResponse: order.append("result") return response - stream = ResponseStream( + stream = ResponseStream( # type: ignore[var-annotated] _generate_updates(2), - finalizer=finalizer, - transform_hooks=[transform_hook], + finalizer=finalizer, # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] + transform_hooks=[transform_hook], # ty: ignore[invalid-argument-type] cleanup_hooks=[cleanup_hook], - result_hooks=[result_hook], + result_hooks=[result_hook], # type: ignore[arg-type] # ty: ignore[invalid-argument-type] ) async for _ in stream: @@ -3990,9 +3993,9 @@ def finalizer(updates: list[ChatResponseUpdate]) -> ChatResponse: order.append("finalizer") return ChatResponse(messages=Message("assistant", ["done"])) - stream = ResponseStream( + stream = ResponseStream( # type: ignore[var-annotated] _generate_updates(2), - finalizer=finalizer, + finalizer=finalizer, # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] cleanup_hooks=[cleanup_hook], ) @@ -4065,7 +4068,7 @@ async def empty_gen() -> AsyncIterable[ChatResponseUpdate]: stream = ResponseStream( empty_gen(), finalizer=_combine_updates, - transform_hooks=[transform_hook], + transform_hooks=[transform_hook], # ty: ignore[invalid-argument-type] ) async for _ in stream: @@ -4114,12 +4117,12 @@ def result(r: ChatResponse) -> ChatResponse: events.append("result") return r - stream = ResponseStream( + stream = ResponseStream( # type: ignore[var-annotated] _generate_updates(1), - finalizer=finalizer, - transform_hooks=[transform], + finalizer=finalizer, # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] + transform_hooks=[transform], # ty: ignore[invalid-argument-type] cleanup_hooks=[cleanup], - result_hooks=[result], + result_hooks=[result], # type: ignore[arg-type] # ty: ignore[invalid-argument-type] ) await stream.get_final_response() diff --git a/python/packages/core/tests/test_security.py b/python/packages/core/tests/test_security.py index 0a638f5883c..a7680cb1adf 100644 --- a/python/packages/core/tests/test_security.py +++ b/python/packages/core/tests/test_security.py @@ -93,8 +93,8 @@ def test_security_classes_are_marked_experimental(self): ] for security_class in security_classes: - assert security_class.__feature_stage__ == "experimental" - assert security_class.__feature_id__ == ExperimentalFeature.FIDES.value + assert security_class.__feature_stage__ == "experimental" # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert security_class.__feature_id__ == ExperimentalFeature.FIDES.value # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] class TestCombineLabels: @@ -327,7 +327,7 @@ async def trusted_fn(arg: str) -> str: additional_properties={"source_integrity": "trusted"}, ) - args = trusted_function.args_schema(arg="test") + args = trusted_function.args_schema(arg="test") # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] context = FunctionInvocationContext(function=trusted_function, arguments=args) async def next_fn(): @@ -379,7 +379,7 @@ async def process_fn(data: dict) -> str: ) # Create argument that contains untrusted label - args = trusted_function.args_schema( + args = trusted_function.args_schema( # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] data={"content": "test", "security_label": {"integrity": "untrusted", "confidentiality": "public"}} ) @@ -422,7 +422,7 @@ async def process_fn(var_ref: dict) -> str: # Pass the VariableReferenceContent as an argument context = FunctionInvocationContext( function=trusted_function, - arguments=trusted_function.args_schema(var_ref={"test": "value"}), # Regular dict + arguments=trusted_function.args_schema(var_ref={"test": "value"}), # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] # Regular dict ) # But also pass the actual VariableReferenceContent in kwargs context.kwargs = {"var_ref_obj": var_ref} @@ -510,7 +510,7 @@ async def mock_fn(arg: str) -> str: fn=mock_fn, name="allowed_function", description="Allowed function", args_schema=MockArgs ) - args = allowed_function.args_schema(arg="test") + args = allowed_function.args_schema(arg="test") # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] context = FunctionInvocationContext(function=allowed_function, arguments=args) # Set untrusted context label (policy enforcement uses context_label) @@ -549,7 +549,7 @@ async def next_fn() -> None: assert context.result.type == "function_approval_request" assert context.result.additional_properties["policy_violation"] is True assert context.result.additional_properties["violation_type"] == "untrusted_context" - assert context.result.function_call.call_id == "call-untrusted" + assert context.result.function_call.call_id == "call-untrusted" # type: ignore[union-attr] async def test_confidentiality_violation_requests_policy_approval(self, mock_function): """Test confidentiality violations reuse the policy approval path.""" @@ -653,7 +653,7 @@ async def test_policy_violation_approval_preserves_type_through_auto_invoke(self label_tracker = LabelTrackingFunctionMiddleware(auto_hide_untrusted=False) # Taint the context label so the policy enforcer sees UNTRUSTED label_tracker._context_label = ContentLabel(integrity=IntegrityLabel.UNTRUSTED) - label_tracker._initialized = True + label_tracker._initialized = True # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] policy = PolicyEnforcementFunctionMiddleware(approval_on_violation=True) pipeline = FunctionMiddlewarePipeline(label_tracker, policy) @@ -728,7 +728,7 @@ async def next_fn(): item = context.result[0] assert isinstance(item, Content) assert item.additional_properties.get("_variable_reference") is True - parsed = json.loads(item.text) + parsed = json.loads(item.text) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] assert parsed.get("type") == "variable_reference" assert parsed["variable_id"].startswith("var_") @@ -756,7 +756,7 @@ async def trusted_fn(value: str = "default") -> str: additional_properties={"source_integrity": "trusted"}, ) - args = trusted_function.args_schema() + args = trusted_function.args_schema() # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] context = FunctionInvocationContext(function=trusted_function, arguments=args) async def next_fn(): @@ -917,9 +917,9 @@ def test_create_config_with_options(self): label_tracker = middleware[0] policy_enforcer = middleware[1] - assert label_tracker.auto_hide_untrusted is True - assert "fetch_data" in policy_enforcer.allow_untrusted_tools - assert "search" in policy_enforcer.allow_untrusted_tools + assert label_tracker.auto_hide_untrusted is True # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert "fetch_data" in policy_enforcer.allow_untrusted_tools # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert "search" in policy_enforcer.allow_untrusted_tools # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] def test_get_tools_returns_security_tools(self): """Test that get_tools returns quarantined_llm and inspect_variable.""" @@ -952,7 +952,7 @@ def test_inspect_variable_uses_generic_approval_mode(self): inspect_variable = next(tool for tool in get_security_tools() if tool.name == "inspect_variable") assert inspect_variable.approval_mode == "never_require" - assert "requires_approval" not in inspect_variable.additional_properties + assert "requires_approval" not in inspect_variable.additional_properties # type: ignore[operator] # pyrefly: ignore[not-iterable] # ty: ignore[unsupported-operator] class TestGetSecurityTools: @@ -1253,10 +1253,10 @@ async def untrusted_fn(value: str = "default") -> str: current_context = None async def next_fn(): - current_context.result = [Content.from_text("result")] + current_context.result = [Content.from_text("result")] # type: ignore[attr-defined, union-attr] # ty: ignore[invalid-assignment] # First call: trusted function (TRUSTED) - context1 = FunctionInvocationContext(function=trusted_function, arguments=trusted_function.args_schema()) + context1 = FunctionInvocationContext(function=trusted_function, arguments=trusted_function.args_schema()) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] current_context = context1 await middleware.process(context1, next_fn) @@ -1265,7 +1265,7 @@ async def next_fn(): assert middleware.get_context_label().integrity == IntegrityLabel.TRUSTED # Second call: untrusted function (UNTRUSTED) - context2 = FunctionInvocationContext(function=untrusted_function, arguments=untrusted_function.args_schema()) + context2 = FunctionInvocationContext(function=untrusted_function, arguments=untrusted_function.args_schema()) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] current_context = context2 await middleware.process(context2, next_fn) @@ -1274,7 +1274,7 @@ async def next_fn(): assert middleware.get_context_label().integrity == IntegrityLabel.UNTRUSTED # Third call: trusted function again - context3 = FunctionInvocationContext(function=trusted_function, arguments=trusted_function.args_schema()) + context3 = FunctionInvocationContext(function=trusted_function, arguments=trusted_function.args_schema()) # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] current_context = context3 await middleware.process(context3, next_fn) @@ -1351,7 +1351,7 @@ async def mock_fn(arg: str = "default") -> str: args_schema=MockArgs, ) - args = allowed_function.args_schema() + args = allowed_function.args_schema() # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] context = FunctionInvocationContext(function=allowed_function, arguments=args) context.metadata["context_label"] = label_middleware.get_context_label() @@ -1551,8 +1551,8 @@ def test_quarantined_llm_declares_source_integrity(self): from agent_framework.security import get_security_tools q_llm = next(tool for tool in get_security_tools() if tool.name == "quarantined_llm") - assert q_llm.additional_properties.get("source_integrity") == "untrusted" - assert q_llm.additional_properties.get("accepts_untrusted") is True + assert q_llm.additional_properties.get("source_integrity") == "untrusted" # type: ignore[union-attr] # ty: ignore[unresolved-attribute] + assert q_llm.additional_properties.get("accepts_untrusted") is True # type: ignore[union-attr] # ty: ignore[unresolved-attribute] class TestQuarantineClient: @@ -1573,7 +1573,7 @@ async def get_response(self, messages, **kwargs): pass mock_client = MockClient() - set_quarantine_client(mock_client) + set_quarantine_client(mock_client) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] assert get_quarantine_client() is mock_client @@ -1596,7 +1596,7 @@ async def get_response(self, messages, **kwargs): mock_client = MockClient() # Create config with quarantine client - config = SecureAgentConfig(quarantine_chat_client=mock_client) + config = SecureAgentConfig(quarantine_chat_client=mock_client) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] # Should have set the global client assert get_quarantine_client() is mock_client @@ -1864,7 +1864,7 @@ async def next_fn(): # First item should be visible (trusted) item0 = context.result[0] assert isinstance(item0, Content) - data0 = json.loads(item0.text) + data0 = json.loads(item0.text) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] assert data0["id"] == 1 assert data0["content"] == "trusted content" @@ -1872,7 +1872,7 @@ async def next_fn(): item1 = context.result[1] assert isinstance(item1, Content) assert item1.additional_properties.get("_variable_reference") is True - parsed1 = json.loads(item1.text) + parsed1 = json.loads(item1.text) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] assert parsed1.get("type") == "variable_reference" assert parsed1["security_label"]["integrity"] == "untrusted" @@ -1935,7 +1935,7 @@ async def next_fn(): for item in context.result: assert isinstance(item, Content) assert item.additional_properties.get("_variable_reference") is True - parsed = json.loads(item.text) + parsed = json.loads(item.text) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] assert parsed.get("type") == "variable_reference" @pytest.mark.asyncio @@ -1957,7 +1957,7 @@ async def untrusted_fn() -> list: # No source_integrity = defaults to UNTRUSTED ) - args = untrusted_function.args_schema() + args = untrusted_function.args_schema() # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] context = FunctionInvocationContext(function=untrusted_function, arguments=args) async def next_fn(): @@ -1976,13 +1976,13 @@ async def next_fn(): for item in context.result: assert isinstance(item, Content) assert item.additional_properties.get("_variable_reference") is True - parsed = json.loads(item.text) + parsed = json.loads(item.text) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] assert parsed.get("type") == "variable_reference" assert parsed["security_label"]["integrity"] == "untrusted" # The call/result label should be UNTRUSTED label = context.metadata.get("result_label") - assert label.integrity == IntegrityLabel.UNTRUSTED + assert label.integrity == IntegrityLabel.UNTRUSTED # type: ignore[union-attr] # ty: ignore[unresolved-attribute] @pytest.mark.asyncio async def test_nested_json_in_content_item(self, middleware, mock_function): @@ -2014,7 +2014,7 @@ async def next_fn(): item = context.result[0] assert isinstance(item, Content) assert item.additional_properties.get("_variable_reference") is True - parsed = json.loads(item.text) + parsed = json.loads(item.text) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] assert parsed.get("type") == "variable_reference" @pytest.mark.asyncio @@ -2040,8 +2040,8 @@ async def next_fn(): # Combined label should be UNTRUSTED (most restrictive integrity) # and PRIVATE (most restrictive confidentiality) label = context.metadata.get("result_label") - assert label.integrity == IntegrityLabel.UNTRUSTED - assert label.confidentiality == ConfidentialityLabel.PRIVATE + assert label.integrity == IntegrityLabel.UNTRUSTED # type: ignore[union-attr] # ty: ignore[unresolved-attribute] + assert label.confidentiality == ConfidentialityLabel.PRIVATE # type: ignore[union-attr] # ty: ignore[unresolved-attribute] @pytest.mark.asyncio async def test_hidden_items_stored_in_variable_store(self, middleware, mock_function): @@ -2064,7 +2064,7 @@ async def next_fn(): item = context.result[0] assert isinstance(item, Content) assert item.additional_properties.get("_variable_reference") is True - var_ref = json.loads(item.text) + var_ref = json.loads(item.text) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] assert var_ref.get("type") == "variable_reference" # Retrieve from store @@ -2100,7 +2100,7 @@ async def next_fn(): assert len(context.result) == 1 item = context.result[0] assert isinstance(item, Content) - data = json.loads(item.text) + data = json.loads(item.text) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] assert data["data"] == "untrusted but visible" @@ -2143,7 +2143,7 @@ async def fn(data: dict) -> str: ) # Input has an untrusted label embedded in the argument - args = function.args_schema( + args = function.args_schema( # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] data={"content": "test", "security_label": {"integrity": "untrusted", "confidentiality": "public"}} ) context = FunctionInvocationContext(function=function, arguments=args) @@ -2179,7 +2179,7 @@ async def fn() -> list: additional_properties={"source_integrity": "trusted"}, ) - args = function.args_schema() + args = function.args_schema() # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] context = FunctionInvocationContext(function=function, arguments=args) async def next_fn(): @@ -2219,7 +2219,7 @@ async def fn(data: dict) -> str: ) # Input has an untrusted label - args = function.args_schema( + args = function.args_schema( # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] data={"content": "test", "security_label": {"integrity": "untrusted", "confidentiality": "public"}} ) context = FunctionInvocationContext(function=function, arguments=args) @@ -2236,7 +2236,7 @@ async def next_fn(): item = context.result[0] assert isinstance(item, Content) assert item.additional_properties.get("_variable_reference") is True - parsed = json.loads(item.text) + parsed = json.loads(item.text) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] assert parsed.get("type") == "variable_reference" @pytest.mark.asyncio @@ -2260,7 +2260,7 @@ async def fn(arg: str = "default") -> str: args_schema=Args, ) - args = function.args_schema() + args = function.args_schema() # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] context = FunctionInvocationContext(function=function, arguments=args) async def next_fn(): @@ -2437,7 +2437,7 @@ async def mock_fn(arg: str = "default") -> str: }, ) - args = function.args_schema() + args = function.args_schema() # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] context = FunctionInvocationContext(function=function, arguments=args) context.metadata["context_label"] = label_middleware.get_context_label() diff --git a/python/packages/core/tests/workflow/test_agent_executor.py b/python/packages/core/tests/workflow/test_agent_executor.py index c9004f234b1..ccb1e9425bf 100644 --- a/python/packages/core/tests/workflow/test_agent_executor.py +++ b/python/packages/core/tests/workflow/test_agent_executor.py @@ -243,7 +243,7 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None: assert resumed_output is not None # Verify the restored executor's session state was restored - restored_session_obj = restored_exec_a._session # type: ignore[reportPrivateUsage] + restored_session_obj = restored_exec_a._session # pyright: ignore[reportPrivateUsage] assert restored_session_obj is not None assert restored_session_obj.session_id == initial_session.session_id @@ -269,7 +269,7 @@ async def test_agent_executor_save_and_restore_state_directly() -> None: Message(role="user", contents=["Cached user message"]), Message(role="assistant", contents=["Cached assistant response"]), ] - executor._cache = list(cache_messages) # type: ignore[reportPrivateUsage] + executor._cache = list(cache_messages) # pyright: ignore[reportPrivateUsage] # Snapshot the state state = await executor.on_checkpoint_save() @@ -289,20 +289,20 @@ async def test_agent_executor_save_and_restore_state_directly() -> None: new_executor = AgentExecutor(new_agent, session=new_session) # Verify new executor starts empty - assert len(new_executor._cache) == 0 # type: ignore[reportPrivateUsage] + assert len(new_executor._cache) == 0 # pyright: ignore[reportPrivateUsage] assert len(new_session.state) == 0 # Restore state await new_executor.on_checkpoint_restore(state) # Verify cache is restored - restored_cache = new_executor._cache # type: ignore[reportPrivateUsage] + restored_cache = new_executor._cache # pyright: ignore[reportPrivateUsage] assert len(restored_cache) == len(cache_messages) assert restored_cache[0].text == "Cached user message" assert restored_cache[1].text == "Cached assistant response" # Verify session was restored with correct session_id - restored_session = new_executor._session # type: ignore[reportPrivateUsage] + restored_session = new_executor._session # pyright: ignore[reportPrivateUsage] assert restored_session.session_id == session.session_id @@ -368,7 +368,7 @@ async def test_agent_executor_workflow_with_non_copyable_raw_representation() -> agent_a = _AgentWithRawRepr(raw=raw, id="a", name="AgentA") agent_b = _CountingAgent(id="b", name="AgentB") - exec_a = AgentExecutor(agent_a, id="exec_a") + exec_a = AgentExecutor(agent_a, id="exec_a") # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] exec_b = AgentExecutor(agent_b, id="exec_b") workflow = WorkflowBuilder(start_executor=exec_a).add_edge(exec_a, exec_b).build() @@ -429,7 +429,7 @@ def run( ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: captured: list[Message] = [] if messages: - for m in messages: # type: ignore[union-attr] + for m in messages: # type: ignore[union-attr] # ty: ignore[not-iterable] if isinstance(m, Message): captured.append(m) elif isinstance(m, str): @@ -696,7 +696,7 @@ async def test_resolve_executor_kwargs_empty_per_executor_does_not_fallback_to_g # Per-executor entry for exec_a is empty, but global has values. # The empty dict should be honoured (no fallback to global). - resolved = {"exec_a": {}, GLOBAL_KWARGS_KEY: {"global_key": "global_val"}} + resolved = {"exec_a": {}, GLOBAL_KWARGS_KEY: {"global_key": "global_val"}} # type: ignore[var-annotated] result = executor._resolve_executor_kwargs(resolved) # pyright: ignore[reportPrivateUsage] assert result == {} diff --git a/python/packages/core/tests/workflow/test_agent_executor_tool_calls.py b/python/packages/core/tests/workflow/test_agent_executor_tool_calls.py index 9f17af9e4eb..17b9332a28f 100644 --- a/python/packages/core/tests/workflow/test_agent_executor_tool_calls.py +++ b/python/packages/core/tests/workflow/test_agent_executor_tool_calls.py @@ -259,7 +259,7 @@ async def _stream_response(self) -> AsyncIterable[ChatResponseUpdate]: @executor(id="test_executor") -async def test_executor(agent_executor_response: AgentExecutorResponse, ctx: WorkflowContext[Never, str]) -> None: +async def test_executor(agent_executor_response: AgentExecutorResponse, ctx: WorkflowContext[Never, str]) -> None: # type: ignore[valid-type] await ctx.yield_output(agent_executor_response.agent_response.text) diff --git a/python/packages/core/tests/workflow/test_agent_utils.py b/python/packages/core/tests/workflow/test_agent_utils.py index 633ba1072cd..8ac07448097 100644 --- a/python/packages/core/tests/workflow/test_agent_utils.py +++ b/python/packages/core/tests/workflow/test_agent_utils.py @@ -33,16 +33,16 @@ def run( session: AgentSession | None = ..., **kwargs: Any, ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... - def run( + def run( # type: ignore[empty-body] self, messages: AgentRunInputs | None = None, *, stream: bool = False, session: AgentSession | None = None, **kwargs: Any, - ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... # ty: ignore[empty-body] - def create_session(self, **kwargs: Any) -> AgentSession: + def create_session(self, **kwargs: Any) -> AgentSession: # type: ignore[empty-body] # ty: ignore[empty-body] """Creates a new conversation session for the agent.""" ... @@ -53,27 +53,27 @@ def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession def test_resolve_agent_id_with_name() -> None: """Test that resolve_agent_id returns name when agent has a name.""" agent = MockAgent(agent_id="agent-123", name="MyAgent") - result = resolve_agent_id(agent) + result = resolve_agent_id(agent) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] assert result == "MyAgent" def test_resolve_agent_id_without_name() -> None: """Test that resolve_agent_id returns id when agent has no name.""" agent = MockAgent(agent_id="agent-456", name=None) - result = resolve_agent_id(agent) + result = resolve_agent_id(agent) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] assert result == "agent-456" def test_resolve_agent_id_with_empty_name() -> None: """Test that resolve_agent_id returns id when agent has empty string name.""" agent = MockAgent(agent_id="agent-789", name="") - result = resolve_agent_id(agent) + result = resolve_agent_id(agent) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] assert result == "agent-789" def test_resolve_agent_id_prefers_name_over_id() -> None: """Test that resolve_agent_id prefers name over id when both are set.""" agent = MockAgent(agent_id="agent-abc", name="PreferredName") - result = resolve_agent_id(agent) + result = resolve_agent_id(agent) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] assert result == "PreferredName" assert result != "agent-abc" diff --git a/python/packages/core/tests/workflow/test_checkpoint.py b/python/packages/core/tests/workflow/test_checkpoint.py index e395655afaf..29fbc1554b6 100644 --- a/python/packages/core/tests/workflow/test_checkpoint.py +++ b/python/packages/core/tests/workflow/test_checkpoint.py @@ -79,8 +79,8 @@ def test_workflow_checkpoint_custom_values(): workflow_name="test-workflow-456", graph_signature_hash="test-hash-456", timestamp=custom_timestamp, - messages={"executor1": [{"data": "test"}]}, # type: ignore[arg-type] # raw dict for serialization test - pending_request_info_events={"req123": {"data": "test"}}, # type: ignore[arg-type] # raw dict for serialization test + messages={"executor1": [{"data": "test"}]}, # type: ignore[arg-type, list-item] # ty: ignore[invalid-argument-type] # raw dict for serialization test + pending_request_info_events={"req123": {"data": "test"}}, # type: ignore[arg-type, dict-item] # ty: ignore[invalid-argument-type] # raw dict for serialization test state={"key": "value"}, iteration_count=5, metadata={"test": True}, @@ -104,7 +104,7 @@ def test_workflow_checkpoint_to_dict(): checkpoint_id="test-id", workflow_name="test-workflow", graph_signature_hash="test-hash", - messages={"executor1": [{"data": "test"}]}, # type: ignore[arg-type] # raw dict for serialization test + messages={"executor1": [{"data": "test"}]}, # type: ignore[arg-type, list-item] # ty: ignore[invalid-argument-type] # raw dict for serialization test state={"key": "value"}, iteration_count=5, ) @@ -162,8 +162,8 @@ async def test_memory_checkpoint_storage_save_and_load(): checkpoint = WorkflowCheckpoint( workflow_name="test-workflow", graph_signature_hash="test-hash", - messages={"executor1": [{"data": "hello"}]}, # type: ignore[arg-type] # raw dict for serialization test - pending_request_info_events={"req123": {"data": "test"}}, # type: ignore[arg-type] # raw dict for serialization test + messages={"executor1": [{"data": "hello"}]}, # type: ignore[arg-type, list-item] # ty: ignore[invalid-argument-type] # raw dict for serialization test + pending_request_info_events={"req123": {"data": "test"}}, # type: ignore[arg-type, dict-item] # ty: ignore[invalid-argument-type] # raw dict for serialization test ) # Save checkpoint @@ -301,7 +301,7 @@ async def process(self, message: str, ctx: WorkflowContext[str]) -> None: class FinishExecutor(Executor): @handler - async def finish(self, message: str, ctx: WorkflowContext[Never, str]) -> None: + async def finish(self, message: str, ctx: WorkflowContext[Never, str]) -> None: # type: ignore[valid-type] await ctx.yield_output(message + "-done") storage = InMemoryCheckpointStorage() @@ -596,7 +596,7 @@ async def test_memory_checkpoint_storage_roundtrip_pending_request_info_events() checkpoint = WorkflowCheckpoint( workflow_name="test-workflow", graph_signature_hash="test-hash", - pending_request_info_events=pending_events, + pending_request_info_events=pending_events, # type: ignore[arg-type] ) await storage.save(checkpoint) @@ -777,9 +777,9 @@ async def test_file_checkpoint_storage_save_and_load(): checkpoint = WorkflowCheckpoint( workflow_name="test-workflow", graph_signature_hash="test-hash", - messages={"executor1": [{"data": "hello", "source_id": "test", "target_id": None}]}, # type: ignore[arg-type] # raw dict for serialization test + messages={"executor1": [{"data": "hello", "source_id": "test", "target_id": None}]}, # type: ignore[arg-type, list-item] # ty: ignore[invalid-argument-type] # raw dict for serialization test state={"key": "value"}, - pending_request_info_events={"req123": {"data": "test"}}, # type: ignore[arg-type] # raw dict for serialization test + pending_request_info_events={"req123": {"data": "test"}}, # type: ignore[arg-type, dict-item] # ty: ignore[invalid-argument-type] # raw dict for serialization test ) # Save checkpoint @@ -905,9 +905,9 @@ async def test_file_checkpoint_storage_json_serialization(): checkpoint = WorkflowCheckpoint( workflow_name="test-workflow", graph_signature_hash="test-hash", - messages={"executor1": [{"data": {"nested": {"value": 42}}, "source_id": "test", "target_id": None}]}, # type: ignore[arg-type] # raw dict for serialization test + messages={"executor1": [{"data": {"nested": {"value": 42}}, "source_id": "test", "target_id": None}]}, # type: ignore[arg-type, list-item] # ty: ignore[invalid-argument-type] # raw dict for serialization test state={"list": [1, 2, 3], "dict": {"a": "b", "c": {"d": "e"}}, "bool": True, "null": None}, - pending_request_info_events={"req123": {"data": "test"}}, # type: ignore[arg-type] # raw dict for serialization test + pending_request_info_events={"req123": {"data": "test"}}, # type: ignore[arg-type, dict-item] # ty: ignore[invalid-argument-type] # raw dict for serialization test ) # Save and load @@ -1272,7 +1272,7 @@ async def test_file_checkpoint_storage_roundtrip_pending_request_info_events(): checkpoint = WorkflowCheckpoint( workflow_name="test-workflow", graph_signature_hash="test-hash", - pending_request_info_events=pending_events, + pending_request_info_events=pending_events, # type: ignore[arg-type] ) await storage.save(checkpoint) diff --git a/python/packages/core/tests/workflow/test_checkpoint_validation.py b/python/packages/core/tests/workflow/test_checkpoint_validation.py index a9c748a3249..b0f941a8e9b 100644 --- a/python/packages/core/tests/workflow/test_checkpoint_validation.py +++ b/python/packages/core/tests/workflow/test_checkpoint_validation.py @@ -23,7 +23,7 @@ async def run(self, message: str, ctx: WorkflowContext[str]) -> None: class FinishExecutor(Executor): @handler - async def finish(self, message: str, ctx: WorkflowContext[Never, str]) -> None: + async def finish(self, message: str, ctx: WorkflowContext[Never, str]) -> None: # type: ignore[valid-type] await ctx.yield_output(message) @@ -95,7 +95,7 @@ async def run(self, message: str, ctx: WorkflowContext[str]) -> None: class SubFinishExecutor(Executor): @handler - async def finish(self, message: str, ctx: WorkflowContext[Never, str]) -> None: + async def finish(self, message: str, ctx: WorkflowContext[Never, str]) -> None: # type: ignore[valid-type] await ctx.yield_output(message) diff --git a/python/packages/core/tests/workflow/test_executor.py b/python/packages/core/tests/workflow/test_executor.py index df742b76f6a..d20fad32afd 100644 --- a/python/packages/core/tests/workflow/test_executor.py +++ b/python/packages/core/tests/workflow/test_executor.py @@ -214,7 +214,7 @@ async def test_executor_completed_event_includes_yielded_outputs(): class YieldOnlyExecutor(Executor): @handler - async def handle(self, text: str, ctx: WorkflowContext[Never, str]) -> None: + async def handle(self, text: str, ctx: WorkflowContext[Never, str]) -> None: # type: ignore[valid-type] await ctx.yield_output(text.upper()) executor = YieldOnlyExecutor(id="yielder") @@ -306,7 +306,7 @@ class SingleOutputExecutor(Executor): async def handle(self, text: str, ctx: WorkflowContext[int]) -> None: pass - executor = SingleOutputExecutor(id="single_output") + executor = SingleOutputExecutor(id="single_output") # type: ignore[assignment] assert int in executor.output_types assert len(executor.output_types) == 1 @@ -316,7 +316,7 @@ class UnionOutputExecutor(Executor): async def handle(self, text: str, ctx: WorkflowContext[int | str]) -> None: pass - executor = UnionOutputExecutor(id="union_output") + executor = UnionOutputExecutor(id="union_output") # type: ignore[assignment] assert int in executor.output_types assert str in executor.output_types assert len(executor.output_types) == 2 @@ -331,7 +331,7 @@ async def handle_string(self, text: str, ctx: WorkflowContext[int]) -> None: async def handle_number(self, num: int, ctx: WorkflowContext[bool]) -> None: pass - executor = MultiHandlerExecutor(id="multi_handler") + executor = MultiHandlerExecutor(id="multi_handler") # type: ignore[assignment] assert int in executor.output_types assert bool in executor.output_types assert len(executor.output_types) == 2 @@ -355,7 +355,7 @@ class WorkflowOutputExecutor(Executor): async def handle(self, text: str, ctx: WorkflowContext[int, str]) -> None: pass - executor = WorkflowOutputExecutor(id="workflow_output") + executor = WorkflowOutputExecutor(id="workflow_output") # type: ignore[assignment] assert str in executor.workflow_output_types assert len(executor.workflow_output_types) == 1 @@ -365,7 +365,7 @@ class UnionWorkflowOutputExecutor(Executor): async def handle(self, text: str, ctx: WorkflowContext[int, str | bool]) -> None: pass - executor = UnionWorkflowOutputExecutor(id="union_workflow_output") + executor = UnionWorkflowOutputExecutor(id="union_workflow_output") # type: ignore[assignment] assert str in executor.workflow_output_types assert bool in executor.workflow_output_types assert len(executor.workflow_output_types) == 2 @@ -380,7 +380,7 @@ async def handle_string(self, text: str, ctx: WorkflowContext[int, str]) -> None async def handle_number(self, num: int, ctx: WorkflowContext[bool, float]) -> None: pass - executor = MultiHandlerWorkflowExecutor(id="multi_workflow") + executor = MultiHandlerWorkflowExecutor(id="multi_workflow") # type: ignore[assignment] assert str in executor.workflow_output_types assert float in executor.workflow_output_types assert len(executor.workflow_output_types) == 2 @@ -388,10 +388,10 @@ async def handle_number(self, num: int, ctx: WorkflowContext[bool, float]) -> No # Test executor with Never for message output (only workflow output) class YieldOnlyExecutor(Executor): @handler - async def handle(self, text: str, ctx: WorkflowContext[Never, str]) -> None: + async def handle(self, text: str, ctx: WorkflowContext[Never, str]) -> None: # type: ignore[valid-type] pass - executor = YieldOnlyExecutor(id="yield_only") + executor = YieldOnlyExecutor(id="yield_only") # type: ignore[assignment] assert str in executor.workflow_output_types assert len(executor.workflow_output_types) == 1 # Should have no message output types @@ -604,7 +604,7 @@ async def handle(self, message: str, ctx: WorkflowContext[str]) -> None: # Handler spec should have int as output type (explicit) handler_func = exec_instance._handlers[str] # pyright: ignore[reportPrivateUsage] - assert handler_func._handler_spec["output_types"] == [int] # pyright: ignore[reportFunctionMemberAccess] + assert handler_func._handler_spec["output_types"] == [int] # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] # pyright: ignore[reportFunctionMemberAccess] # Executor output_types property should reflect explicit type assert int in exec_instance.output_types @@ -627,7 +627,7 @@ async def handle(self, message: Any, ctx: WorkflowContext) -> None: # Output type should be list (explicit) handler_func = exec_instance._handlers[dict] # pyright: ignore[reportPrivateUsage] - assert handler_func._handler_spec["output_types"] == [list] # pyright: ignore[reportFunctionMemberAccess] + assert handler_func._handler_spec["output_types"] == [list] # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] # pyright: ignore[reportFunctionMemberAccess] # Verify can_handle assert exec_instance.can_handle(WorkflowMessage(data={"key": "value"}, source_id="mock")) @@ -820,7 +820,7 @@ async def handle(self, message: str, ctx: WorkflowContext[int]) -> None: # Handler spec should have bool as workflow_output_type (explicit) handler_func = exec_instance._handlers[str] # pyright: ignore[reportPrivateUsage] - assert handler_func._handler_spec["workflow_output_types"] == [bool] # pyright: ignore[reportFunctionMemberAccess] + assert handler_func._handler_spec["workflow_output_types"] == [bool] # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] # pyright: ignore[reportFunctionMemberAccess] # Executor workflow_output_types property should reflect explicit type assert bool in exec_instance.workflow_output_types @@ -949,7 +949,7 @@ def test_handler_rejects_bounded_typevar_in_message_annotation(): class BoundedGenericExecutor(Executor, Generic[_BT]): @handler async def process(self, message: _BT, ctx: WorkflowContext) -> None: - await ctx.send_message(message) + await ctx.send_message(message) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] def test_handler_allows_concrete_types(): @@ -972,7 +972,7 @@ class GenericWithExplicit(Executor, Generic[_T]): async def echo(self, message, ctx: WorkflowContext) -> None: pass - exec_instance = GenericWithExplicit(id="explicit") + exec_instance = GenericWithExplicit(id="explicit") # type: ignore[var-annotated] assert str in exec_instance.input_types diff --git a/python/packages/core/tests/workflow/test_executor_future.py b/python/packages/core/tests/workflow/test_executor_future.py index cb0c5c9f58f..81818caed31 100644 --- a/python/packages/core/tests/workflow/test_executor_future.py +++ b/python/packages/core/tests/workflow/test_executor_future.py @@ -120,5 +120,5 @@ def test_handler_unresolvable_annotation_raises(self): class Bad(Executor): # pyright: ignore[reportUnusedClass] @handler # pyright: ignore[reportUnknownArgumentType] - async def example(self, input: NonExistentType, ctx: WorkflowContext[MyTypeA, MyTypeB]) -> None: # noqa: F821 # type: ignore[name-defined] + async def example(self, input: NonExistentType, ctx: WorkflowContext[MyTypeA, MyTypeB]) -> None: # type: ignore[name-defined] # ty: ignore[unresolved-reference] # noqa: F821 pass diff --git a/python/packages/core/tests/workflow/test_full_conversation.py b/python/packages/core/tests/workflow/test_full_conversation.py index 27b3dc4019f..d7f9c3b245c 100644 --- a/python/packages/core/tests/workflow/test_full_conversation.py +++ b/python/packages/core/tests/workflow/test_full_conversation.py @@ -158,7 +158,7 @@ class _CaptureFullConversation(Executor): """Captures AgentExecutorResponse.full_conversation and completes the workflow.""" @handler - async def capture(self, response: AgentExecutorResponse, ctx: WorkflowContext[Never, dict[str, Any]]) -> None: + async def capture(self, response: AgentExecutorResponse, ctx: WorkflowContext[Never, dict[str, Any]]) -> None: # type: ignore[valid-type] full = response.full_conversation # The AgentExecutor contract guarantees full_conversation is populated. assert full is not None @@ -232,7 +232,7 @@ def run( # Normalize and record messages for verification norm: list[Message] = [] if messages: - for m in messages: # type: ignore[iteration-over-optional] + for m in messages: # type: ignore[iteration-over-optional, union-attr] # ty: ignore[not-iterable] if isinstance(m, Message): norm.append(m) elif isinstance(m, str): diff --git a/python/packages/core/tests/workflow/test_function_executor.py b/python/packages/core/tests/workflow/test_function_executor.py index b45a6677227..8b0316aab56 100644 --- a/python/packages/core/tests/workflow/test_function_executor.py +++ b/python/packages/core/tests/workflow/test_function_executor.py @@ -119,7 +119,7 @@ async def multi_output(text: str, ctx: WorkflowContext[str | int]) -> None: # Test union types for workflow outputs too @executor - async def multi_workflow_output(data: str, ctx: WorkflowContext[Never, str | int | bool]) -> None: + async def multi_workflow_output(data: str, ctx: WorkflowContext[Never, str | int | bool]) -> None: # type: ignore[valid-type] if data.isdigit(): await ctx.yield_output(int(data)) elif data.lower() in ("true", "false"): @@ -414,7 +414,7 @@ def valid_sync_with_ctx(data: int, ctx: WorkflowContext[str]): assert int in func_exec2._handlers # pyright: ignore[reportPrivateUsage] # Sync function with missing type annotation should still fail - def no_annotation(data): # type: ignore # pyright: ignore[reportUnknownVariableType] + def no_annotation(data): # pyright: ignore[reportUnknownVariableType] # type: ignore return data # pyright: ignore[reportUnknownVariableType] with pytest.raises(ValueError, match="type annotation for the message"): @@ -485,7 +485,7 @@ async def test_sync_function_thread_execution(self): @executor def blocking_function(data: str): nonlocal execution_thread_id - execution_thread_id = threading.get_ident() + execution_thread_id = threading.get_ident() # type: ignore[assignment] # Simulate some CPU-bound work time.sleep(0.01) # Small sleep to verify thread execution return data.upper() @@ -522,7 +522,7 @@ def test_executor_rejects_classmethod(self): class Example: # pyright: ignore[reportUnusedClass] @executor @classmethod - async def bad_handler(cls, data: str) -> str: + async def bad_handler(cls, data: str) -> str: # type: ignore[operator] return data.upper() assert "cannot be used with @classmethod" in str(exc_info.value) @@ -675,7 +675,7 @@ def test_executor_partial_explicit_types(self): async def process_input(message: str, ctx: WorkflowContext[int]) -> None: pass - assert bytes in process_input._handlers # Explicit # pyright: ignore[reportPrivateUsage] + assert bytes in process_input._handlers # pyright: ignore[reportPrivateUsage] # Explicit assert int in process_input.output_types # Introspected # Only explicit output_type, introspect input_type @@ -683,7 +683,7 @@ async def process_input(message: str, ctx: WorkflowContext[int]) -> None: async def process_output(message: str, ctx: WorkflowContext[int]) -> None: pass - assert str in process_output._handlers # Introspected # pyright: ignore[reportPrivateUsage] + assert str in process_output._handlers # pyright: ignore[reportPrivateUsage] # Introspected assert float in process_output.output_types # Explicit assert int not in process_output.output_types # Not introspected when explicit provided @@ -748,7 +748,7 @@ def test_executor_explicit_union_types_via_typing_union(self): """Test that Union[] syntax also works for explicit types.""" from typing import Union - @executor(input=Union[str, int], output=Union[bool, float]) + @executor(input=Union[str, int], output=Union[bool, float]) # type: ignore[call-overload] async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] pass @@ -962,7 +962,7 @@ def test_function_executor_rejects_bounded_typevar_in_message_annotation(): """Test that FunctionExecutor raises ValueError for a bounded TypeVar in message annotation.""" async def process(message: _FBT, ctx: WorkflowContext) -> None: - await ctx.send_message(message) + await ctx.send_message(message) # type: ignore[arg-type] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] with pytest.raises(ValueError, match="unresolved TypeVar"): FunctionExecutor(process, id="bounded") diff --git a/python/packages/core/tests/workflow/test_function_executor_future.py b/python/packages/core/tests/workflow/test_function_executor_future.py index f98a1502727..8b6f9170dda 100644 --- a/python/packages/core/tests/workflow/test_function_executor_future.py +++ b/python/packages/core/tests/workflow/test_function_executor_future.py @@ -54,5 +54,5 @@ def test_handler_unresolvable_annotation_raises(self): ) -async def _func_with_bad_annotation(message: NonExistentType, ctx: WorkflowContext[int]) -> None: # noqa: F821 # type: ignore[name-defined] +async def _func_with_bad_annotation(message: NonExistentType, ctx: WorkflowContext[int]) -> None: # type: ignore[name-defined] # ty: ignore[unresolved-reference] # noqa: F821 pass diff --git a/python/packages/core/tests/workflow/test_functional_workflow.py b/python/packages/core/tests/workflow/test_functional_workflow.py index d52c5497f9f..4ae4069061c 100644 --- a/python/packages/core/tests/workflow/test_functional_workflow.py +++ b/python/packages/core/tests/workflow/test_functional_workflow.py @@ -261,8 +261,8 @@ async def test_untyped_ctx_parameter(self): """ctx is injected by parameter name even without a RunContext annotation.""" @workflow # pyright: ignore[reportUnknownArgumentType] - async def review_wf(doc: str, ctx) -> str: # pyright: ignore[reportUnknownParameterType,reportMissingParameterType] - feedback: str = await ctx.request_info({"draft": doc}, response_type=str, request_id="req1") # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] + async def review_wf(doc: str, ctx) -> str: # pyright: ignore[reportMissingParameterType, reportUnknownParameterType] + feedback: str = await ctx.request_info({"draft": doc}, response_type=str, request_id="req1") # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] return f"Final: {feedback}" result1 = await review_wf.run("my doc") @@ -412,7 +412,7 @@ async def test_streaming_context_reports_streaming(self): @workflow async def wf(x: int, ctx: RunContext) -> int: nonlocal streaming_flag - streaming_flag = ctx.is_streaming() + streaming_flag = ctx.is_streaming() # type: ignore[assignment] return x stream = wf.run(1, stream=True) @@ -984,7 +984,7 @@ async def wf(x: int) -> int: async def test_step_sync_function_raises(self): with pytest.raises(TypeError, match="async functions"): - @step # pyright: ignore[reportArgumentType] + @step # type: ignore[arg-type, call-overload] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type] # pyright: ignore[reportArgumentType] def not_async(x: int) -> int: # pyright: ignore[reportUnusedFunction] return x @@ -1185,7 +1185,7 @@ async def test_get_run_context_inside_workflow(self): @step async def capture_ctx(x: int) -> int: nonlocal captured_ctx - captured_ctx = get_run_context() + captured_ctx = get_run_context() # type: ignore[assignment] return x @workflow @@ -1719,7 +1719,7 @@ def test_public_symbols_are_marked_experimental(self) -> None: ] for symbol in symbols: - assert symbol.__feature_stage__ == "experimental" - assert symbol.__feature_id__ == ExperimentalFeature.FUNCTIONAL_WORKFLOWS.value + assert symbol.__feature_stage__ == "experimental" # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + assert symbol.__feature_id__ == ExperimentalFeature.FUNCTIONAL_WORKFLOWS.value # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert symbol.__doc__ is not None assert ".. warning:: Experimental" in symbol.__doc__ diff --git a/python/packages/core/tests/workflow/test_output_designation.py b/python/packages/core/tests/workflow/test_output_designation.py index cd4b23ee21e..8a6b62a1a56 100644 --- a/python/packages/core/tests/workflow/test_output_designation.py +++ b/python/packages/core/tests/workflow/test_output_designation.py @@ -61,7 +61,7 @@ def test_designation_is_frozen() -> None: designation = OutputDesignation(outputs=frozenset({"alpha"})) with pytest.raises(FrozenInstanceError): - designation.outputs = frozenset({"beta"}) # type: ignore[misc] + designation.outputs = frozenset({"beta"}) # type: ignore[misc] # ty: ignore[invalid-assignment] # --------------------------------------------------------------------------- @@ -70,12 +70,12 @@ def test_designation_is_frozen() -> None: @executor -async def _emit_one(messages: list[Message], ctx: WorkflowContext[Never, str]) -> None: +async def _emit_one(messages: list[Message], ctx: WorkflowContext[Never, str]) -> None: # type: ignore[valid-type] await ctx.yield_output("hello") @executor -async def _downstream(message: str, ctx: WorkflowContext[Never, str]) -> None: +async def _downstream(message: str, ctx: WorkflowContext[Never, str]) -> None: # type: ignore[valid-type] await ctx.yield_output("downstream") diff --git a/python/packages/core/tests/workflow/test_output_executors_contract.py b/python/packages/core/tests/workflow/test_output_executors_contract.py index 31f4f946b75..74df7d00cd5 100644 --- a/python/packages/core/tests/workflow/test_output_executors_contract.py +++ b/python/packages/core/tests/workflow/test_output_executors_contract.py @@ -20,7 +20,7 @@ @executor -async def _emit_one(messages: list[Message], ctx: WorkflowContext[Never, str]) -> None: +async def _emit_one(messages: list[Message], ctx: WorkflowContext[Never, str]) -> None: # type: ignore[valid-type] await ctx.yield_output("hello") @@ -31,7 +31,7 @@ async def _start(messages: list[Message], ctx: WorkflowContext[str, str]) -> Non @executor -async def _downstream(message: str, ctx: WorkflowContext[Never, str]) -> None: +async def _downstream(message: str, ctx: WorkflowContext[Never, str]) -> None: # type: ignore[valid-type] await ctx.yield_output("from-downstream") @@ -223,7 +223,7 @@ async def test_intermediate_output_from_all_routes_every_yield_to_intermediate() def test_output_from_all_other_is_rejected() -> None: """The all-other literal is only valid for intermediate output selection.""" with pytest.raises(ValueError, match="output_from.*all_other"): - WorkflowBuilder(start_executor=_emit_one, output_from="all_other") # type: ignore[arg-type] + WorkflowBuilder(start_executor=_emit_one, output_from="all_other") # type: ignore[arg-type] # ty: ignore[invalid-argument-type] @pytest.mark.parametrize( diff --git a/python/packages/core/tests/workflow/test_request_info_and_response.py b/python/packages/core/tests/workflow/test_request_info_and_response.py index 05a7ed1ec54..399a300bfe3 100644 --- a/python/packages/core/tests/workflow/test_request_info_and_response.py +++ b/python/packages/core/tests/workflow/test_request_info_and_response.py @@ -70,10 +70,10 @@ async def handle_approval_response( self.approval_received = True if approved: - self.final_result = f"Operation approved: {original_request.prompt}" + self.final_result = f"Operation approved: {original_request.prompt}" # type: ignore[assignment] await ctx.send_message(f"APPROVED: {original_request.context}") else: - self.final_result = "Operation denied by user" + self.final_result = "Operation denied by user" # type: ignore[assignment] await ctx.send_message("DENIED: Operation was not approved") diff --git a/python/packages/core/tests/workflow/test_request_info_mixin.py b/python/packages/core/tests/workflow/test_request_info_mixin.py index cfde71b4817..c53044d37a8 100644 --- a/python/packages/core/tests/workflow/test_request_info_mixin.py +++ b/python/packages/core/tests/workflow/test_request_info_mixin.py @@ -47,9 +47,9 @@ async def test_handler(self: Any, original_request: str, response: int, ctx: Wor # Check the spec attributes spec = test_handler._response_handler_spec # type: ignore[reportAttributeAccessIssue] - assert spec["name"] == "test_handler" - assert spec["response_type"] is int - assert spec["request_type"] is str + assert spec["name"] == "test_handler" # ty: ignore[not-subscriptable] + assert spec["response_type"] is int # ty: ignore[not-subscriptable] + assert spec["request_type"] is str # ty: ignore[not-subscriptable] def test_response_handler_with_workflow_context_types(self): """Test response handler with different WorkflowContext type parameters.""" @@ -60,7 +60,7 @@ async def handler_with_output_types( ) -> None: pass - spec = handler_with_output_types._response_handler_spec # type: ignore[reportAttributeAccessIssue] + spec = handler_with_output_types._response_handler_spec # type: ignore[attr-defined, reportAttributeAccessIssue] # ty: ignore[unresolved-attribute] assert "output_types" in spec assert "workflow_output_types" in spec @@ -173,8 +173,8 @@ async def dummy_handler(self, message: str, ctx: WorkflowContext) -> None: @response_handler async def handle_response(self, original_request: str, response: int, ctx: WorkflowContext[str]) -> None: - self.handled_request = original_request - self.handled_response = response + self.handled_request = original_request # type: ignore[assignment] + self.handled_response = response # type: ignore[assignment] executor = TestExecutor() @@ -182,7 +182,7 @@ async def handle_response(self, original_request: str, response: int, ctx: Workf response_handler_func = executor._response_handlers[(str, int)] # type: ignore[reportAttributeAccessIssue] # Create a mock context - we'll just use None since the handler doesn't use it - await response_handler_func("test_request", 42, None) # type: ignore[reportArgumentType] + await response_handler_func("test_request", 42, None) # type: ignore[arg-type, reportArgumentType] # ty: ignore[invalid-argument-type] assert executor.handled_request == "test_request" assert executor.handled_response == 42 @@ -542,13 +542,13 @@ async def handle_list_str_float( executor = TestExecutor() # Test that wrong combinations don't match - assert executor._find_response_handler("test", 3.14) is None # pyright: ignore[reportPrivateUsage] # str request, float response - no handler - assert executor._find_response_handler(["test"], 42) is None # pyright: ignore[reportPrivateUsage] # list request, int response - no handler - assert executor._find_response_handler(42, "test") is None # pyright: ignore[reportPrivateUsage] # int request, str response - no handler + assert executor._find_response_handler("test", 3.14) is None # pyright: ignore[reportPrivateUsage] # str request, float response - no handler + assert executor._find_response_handler(["test"], 42) is None # pyright: ignore[reportPrivateUsage] # list request, int response - no handler + assert executor._find_response_handler(42, "test") is None # pyright: ignore[reportPrivateUsage] # int request, str response - no handler # Test that correct combinations do match - assert executor._find_response_handler("test", 42) is not None # pyright: ignore[reportPrivateUsage] # str request, int response - has handler - assert executor._find_response_handler(["test"], 3.14) is not None # pyright: ignore[reportPrivateUsage] # list request, float response - has handler + assert executor._find_response_handler("test", 42) is not None # pyright: ignore[reportPrivateUsage] # str request, int response - has handler + assert executor._find_response_handler(["test"], 3.14) is not None # pyright: ignore[reportPrivateUsage] # list request, float response - has handler def test_is_request_supported_with_exact_matches(self): """Test is_request_supported with exact type matches.""" @@ -797,7 +797,7 @@ def test_response_handler_with_explicit_types(self): async def test_handler(self: Any, original_request: Any, response: Any, ctx: WorkflowContext) -> None: pass - spec = test_handler._response_handler_spec # type: ignore[reportAttributeAccessIssue] + spec = test_handler._response_handler_spec # type: ignore[attr-defined, reportAttributeAccessIssue] # ty: ignore[unresolved-attribute] assert spec["name"] == "test_handler" assert spec["request_type"] is str assert spec["response_type"] is int @@ -809,7 +809,7 @@ def test_response_handler_with_explicit_output_types(self): async def test_handler(self: Any, original_request: Any, response: Any, ctx: WorkflowContext) -> None: pass - spec = test_handler._response_handler_spec # type: ignore[reportAttributeAccessIssue] + spec = test_handler._response_handler_spec # type: ignore[attr-defined, reportAttributeAccessIssue] # ty: ignore[unresolved-attribute] assert spec["request_type"] is str assert spec["response_type"] is int assert bool in spec["output_types"] @@ -822,7 +822,7 @@ def test_response_handler_with_union_types(self): async def test_handler(self: Any, original_request: Any, response: Any, ctx: WorkflowContext) -> None: pass - spec = test_handler._response_handler_spec # type: ignore[reportAttributeAccessIssue] + spec = test_handler._response_handler_spec # type: ignore[attr-defined, reportAttributeAccessIssue] # ty: ignore[unresolved-attribute] assert spec["request_type"] == str | int assert spec["response_type"] == bool | float @@ -833,7 +833,7 @@ def test_response_handler_with_string_forward_references(self): async def test_handler(self: Any, original_request: Any, response: Any, ctx: WorkflowContext) -> None: pass - spec = test_handler._response_handler_spec # type: ignore[reportAttributeAccessIssue] + spec = test_handler._response_handler_spec # type: ignore[attr-defined, reportAttributeAccessIssue] # ty: ignore[unresolved-attribute] assert spec["request_type"] is str assert spec["response_type"] is int @@ -917,7 +917,7 @@ async def handle_response(self, original_request: Any, response: Any, ctx: Workf response_handler_func = executor._response_handlers[(str, int)] # type: ignore[reportAttributeAccessIssue] # Call the handler - asyncio.run(response_handler_func("test_request", 42, None)) # type: ignore[reportArgumentType] + asyncio.run(response_handler_func("test_request", 42, None)) # type: ignore[arg-type, reportArgumentType] # ty: ignore[invalid-argument-type] assert executor.handled_request == "test_request" assert executor.handled_response == 42 diff --git a/python/packages/core/tests/workflow/test_runner.py b/python/packages/core/tests/workflow/test_runner.py index 4fef26bd2d2..4b10f153db4 100644 --- a/python/packages/core/tests/workflow/test_runner.py +++ b/python/packages/core/tests/workflow/test_runner.py @@ -171,7 +171,7 @@ async def send_message( runner = Runner([], {}, state, ctx, "test_name", graph_signature_hash="test_hash") edge_runner = RecordingEdgeRunner() - runner._edge_runner_map = {"source": [edge_runner]} # type: ignore[assignment] + runner._edge_runner_map = {"source": [edge_runner]} # type: ignore[assignment, list-item] # ty: ignore[invalid-assignment] for index in range(5): await ctx.send_message(WorkflowMessage(data=MockMessage(data=index), source_id="source")) @@ -216,7 +216,7 @@ async def send_message( blocking_edge_runner = BlockingEdgeRunner() probe_edge_runner = ProbeEdgeRunner() - runner._edge_runner_map = {"source": [blocking_edge_runner, probe_edge_runner]} # type: ignore[assignment] + runner._edge_runner_map = {"source": [blocking_edge_runner, probe_edge_runner]} # type: ignore[assignment, list-item] # ty: ignore[invalid-assignment] await ctx.send_message(WorkflowMessage(data=MockMessage(data=1), source_id="source")) @@ -520,7 +520,7 @@ async def create_checkpoint( ) return await self._storage.save(checkpoint) - async def load_checkpoint(self, checkpoint_id: str) -> WorkflowCheckpoint | None: # pyright: ignore[reportIncompatibleMethodOverride] + async def load_checkpoint(self, checkpoint_id: str) -> WorkflowCheckpoint | None: # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] # pyright: ignore[reportIncompatibleMethodOverride] try: return await self._storage.load(checkpoint_id) except WorkflowCheckpointException: diff --git a/python/packages/core/tests/workflow/test_serialization.py b/python/packages/core/tests/workflow/test_serialization.py index ed6316ecbeb..98555f0febf 100644 --- a/python/packages/core/tests/workflow/test_serialization.py +++ b/python/packages/core/tests/workflow/test_serialization.py @@ -311,7 +311,7 @@ def test_switch_case_edge_group_serialization(self) -> None: SwitchCaseEdgeGroupCase(condition=lambda x: x > 0, target_id="positive"), SwitchCaseEdgeGroupDefault(target_id="default"), ] - edge_group = SwitchCaseEdgeGroup(source_id="source", cases=cases) + edge_group = SwitchCaseEdgeGroup(source_id="source", cases=cases) # type: ignore[arg-type] # Test to_dict data = edge_group.to_dict() @@ -515,7 +515,7 @@ def is_positive(x: int) -> bool: SwitchCaseEdgeGroupCase(condition=is_positive, target_id="positive"), SwitchCaseEdgeGroupDefault(target_id="default"), ] - edge_group = SwitchCaseEdgeGroup(source_id="source", cases=cases) + edge_group = SwitchCaseEdgeGroup(source_id="source", cases=cases) # type: ignore[arg-type] # Test to_dict data = edge_group.to_dict() diff --git a/python/packages/core/tests/workflow/test_strict_mode_event_labeling.py b/python/packages/core/tests/workflow/test_strict_mode_event_labeling.py index d1de5c3cb06..f5682bd1ea8 100644 --- a/python/packages/core/tests/workflow/test_strict_mode_event_labeling.py +++ b/python/packages/core/tests/workflow/test_strict_mode_event_labeling.py @@ -25,7 +25,7 @@ async def _start(messages: list[Message], ctx: WorkflowContext[str, str]) -> Non @executor -async def _downstream(message: str, ctx: WorkflowContext[Never, str]) -> None: +async def _downstream(message: str, ctx: WorkflowContext[Never, str]) -> None: # type: ignore[valid-type] await ctx.yield_output("from-downstream") diff --git a/python/packages/core/tests/workflow/test_sub_workflow.py b/python/packages/core/tests/workflow/test_sub_workflow.py index 7bf38a06f32..f5e2e8fa26f 100644 --- a/python/packages/core/tests/workflow/test_sub_workflow.py +++ b/python/packages/core/tests/workflow/test_sub_workflow.py @@ -148,7 +148,7 @@ async def handle_domain_response( self, original_request: DomainCheckRequest, is_approved: bool, - ctx: WorkflowContext[Never, ValidationResult], + ctx: WorkflowContext[Never, ValidationResult], # type: ignore[valid-type] ) -> None: """Handle domain check response with correlation.""" # Use the original email from the correlated response @@ -495,7 +495,7 @@ async def handle_response( self, original_request: CheckpointRequest, response: str, - ctx: WorkflowContext[Never, bool], + ctx: WorkflowContext[Never, bool], # type: ignore[valid-type] ) -> None: self._responses.append(response) if len(self._responses) == 1: @@ -643,7 +643,7 @@ def __init__(self) -> None: super().__init__(id="finalizer") @handler - async def run(self, message: str, ctx: WorkflowContext[Never, str]) -> None: + async def run(self, message: str, ctx: WorkflowContext[Never, str]) -> None: # type: ignore[valid-type] await ctx.yield_output(f"final: {message}") progress = _ProgressEmitter() @@ -666,7 +666,7 @@ def __init__(self) -> None: self.received: list[str] = [] @handler - async def run(self, message: str, ctx: WorkflowContext[Never, str]) -> None: + async def run(self, message: str, ctx: WorkflowContext[Never, str]) -> None: # type: ignore[valid-type] self.received.append(message) await ctx.yield_output(message) diff --git a/python/packages/core/tests/workflow/test_typing_utils.py b/python/packages/core/tests/workflow/test_typing_utils.py index f94bd9d52e8..cc3e7bf4ea3 100644 --- a/python/packages/core/tests/workflow/test_typing_utils.py +++ b/python/packages/core/tests/workflow/test_typing_utils.py @@ -45,17 +45,17 @@ def test_normalize_type_to_list_union_pipe_syntax() -> None: def test_normalize_type_to_list_union_typing_syntax() -> None: """Test normalize_type_to_list with Union[] from typing module.""" - result = normalize_type_to_list(Union[str, int]) # pyright: ignore[reportArgumentType] + result = normalize_type_to_list(Union[str, int]) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] assert set(result) == {str, int} - result = normalize_type_to_list(Union[str, int, bool]) # pyright: ignore[reportArgumentType] + result = normalize_type_to_list(Union[str, int, bool]) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] assert set(result) == {str, int, bool} def test_normalize_type_to_list_optional() -> None: """Test normalize_type_to_list with Optional types (Union[T, None]).""" # Optional[str] is Union[str, None] - result = normalize_type_to_list(Optional[str]) # pyright: ignore[reportArgumentType] + result = normalize_type_to_list(Optional[str]) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] assert str in result assert type(None) in result assert len(result) == 2 @@ -125,20 +125,20 @@ class MyCustomType: assert result is MyCustomType result = resolve_type_annotation("MyCustomType | str", {"MyCustomType": MyCustomType, "str": str}) - assert set(result.__args__) == {MyCustomType, str} # type: ignore[union-attr] + assert set(result.__args__) == {MyCustomType, str} # type: ignore[union-attr] # ty: ignore[unresolved-attribute] def test_resolve_type_annotation_string_typing_union() -> None: """Test resolve_type_annotation resolves Union[] syntax in strings.""" result = resolve_type_annotation("Union[str, int]", {"str": str, "int": int}) - assert set(result.__args__) == {str, int} # type: ignore[union-attr] + assert set(result.__args__) == {str, int} # type: ignore[union-attr] # ty: ignore[unresolved-attribute] def test_resolve_type_annotation_string_optional() -> None: """Test resolve_type_annotation resolves Optional[] syntax in strings.""" result = resolve_type_annotation("Optional[str]", {"str": str}) - assert str in result.__args__ # type: ignore[union-attr] - assert type(None) in result.__args__ # type: ignore[union-attr] + assert str in result.__args__ # type: ignore[union-attr] # ty: ignore[unresolved-attribute] + assert type(None) in result.__args__ # type: ignore[union-attr] # ty: ignore[unresolved-attribute] def test_resolve_type_annotation_unresolvable_raises() -> None: diff --git a/python/packages/core/tests/workflow/test_validation.py b/python/packages/core/tests/workflow/test_validation.py index a9f62f35a35..c2f85614c85 100644 --- a/python/packages/core/tests/workflow/test_validation.py +++ b/python/packages/core/tests/workflow/test_validation.py @@ -50,7 +50,7 @@ async def handle_any(self, message: Any, ctx: WorkflowContext[Any]) -> None: class NoOutputTypesExecutor(Executor): @handler async def handle_message(self, message: str, ctx: WorkflowContext) -> None: - await ctx.send_message("processed") # type: ignore[arg-type] + await ctx.send_message("processed") # type: ignore[arg-type] # ty: ignore[invalid-argument-type] class MultiTypeExecutor(Executor): @@ -187,7 +187,7 @@ def test_disconnected_start_executor_not_in_graph(): def test_missing_start_executor(): with pytest.raises(TypeError): - WorkflowBuilder() # type: ignore[call-arg] + WorkflowBuilder() # type: ignore[call-arg] # ty: ignore[missing-argument] def test_workflow_validation_error_base_class(): @@ -498,7 +498,7 @@ def test_handler_ctx_invalid_t_out_entries_raises() -> None: class BadExecutor(Executor): # pyright: ignore[reportUnusedClass] @handler # pyright: ignore[reportUnknownArgumentType] - async def handle(self, message: str, ctx: WorkflowContext[123]) -> None: # type: ignore[valid-type] + async def handle(self, message: str, ctx: WorkflowContext[123]) -> None: # type: ignore[valid-type] # ty: ignore[invalid-type-form] pass assert "invalid type entry" in str(exc.value) diff --git a/python/packages/core/tests/workflow/test_workflow.py b/python/packages/core/tests/workflow/test_workflow.py index 27f24d26f9c..f9791350a12 100644 --- a/python/packages/core/tests/workflow/test_workflow.py +++ b/python/packages/core/tests/workflow/test_workflow.py @@ -967,7 +967,7 @@ async def test_workflow_run_inflight_messages_guard(simple_executor: Executor) - test_message = WorkflowMessage(data="test", source_id="test", target_id=None) # Simulate an aborted prior run by leaving a message in the runner context. - workflow._runner.context._messages["test"] = [test_message] + workflow._runner.context._messages["test"] = [test_message] # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert await workflow._runner.context.has_messages() with pytest.raises(RuntimeError, match="in-flight executor messages"): @@ -1158,7 +1158,7 @@ async def test_output_executors_with_nonexistent_executor_id() -> None: # Designate a nonexistent executor so the workflow-level filter drops every yield. workflow._output_designation = OutputDesignation(outputs=frozenset({"nonexistent_executor"})) # type: ignore[attr-defined] - workflow._runner.context.set_yield_output_classifier(workflow._output_designation.classify) # type: ignore[attr-defined,reportPrivateUsage] + workflow._runner.context.set_yield_output_classifier(workflow._output_designation.classify) # type: ignore[attr-defined, reportPrivateUsage] result = await workflow.run(NumberMessage(data=0)) outputs = result.get_outputs() @@ -1255,7 +1255,7 @@ async def test_output_executors_filtering_with_run_responses_streaming() -> None from agent_framework._workflows._workflow import OutputDesignation workflow._output_designation = OutputDesignation(outputs=frozenset({"other_executor"})) # type: ignore[attr-defined] - workflow._runner.context.set_yield_output_classifier(workflow._output_designation.classify) # type: ignore[attr-defined,reportPrivateUsage] + workflow._runner.context.set_yield_output_classifier(workflow._output_designation.classify) # type: ignore[attr-defined, reportPrivateUsage] # Send approval response via streaming responses = {request_events[0].request_id: ApprovalMessage(approved=True)} diff --git a/python/packages/core/tests/workflow/test_workflow_agent.py b/python/packages/core/tests/workflow/test_workflow_agent.py index 67cea0ed38a..9b3b8799b97 100644 --- a/python/packages/core/tests/workflow/test_workflow_agent.py +++ b/python/packages/core/tests/workflow/test_workflow_agent.py @@ -273,7 +273,7 @@ async def test_end_to_end_request_info_handling(self): assert request_event.get("type") == "request_info" assert deserialize_type(request_event.get("response_type")) is str - deserialized_args = WorkflowAgent.RequestInfoFunctionArgs.from_dict(request_function_call.arguments) + deserialized_args = WorkflowAgent.RequestInfoFunctionArgs.from_dict(request_function_call.arguments) # ty: ignore[invalid-argument-type] assert deserialized_args.request_id == request_function_call.call_id assert isinstance(deserialized_args.request_event, WorkflowEvent) assert deserialized_args.request_event.type == "request_info" @@ -327,7 +327,7 @@ def test_request_info_dataclass_arguments_are_serialized_when_content_is_created assert deserialize_type(request_event.get("response_type")) is str assert request_event.get("data") == HandoffRequest(target_agent="helper", reason="overflow") - deserialized_args = WorkflowAgent.RequestInfoFunctionArgs.from_dict(request_function_call.arguments) + deserialized_args = WorkflowAgent.RequestInfoFunctionArgs.from_dict(request_function_call.arguments) # ty: ignore[invalid-argument-type] assert deserialized_args.request_id == "request_123" assert isinstance(deserialized_args.request_event, WorkflowEvent) assert deserialized_args.request_event.type == "request_info" @@ -469,12 +469,12 @@ async def handle_response( self, original_request: Content, response: Content, - ctx: WorkflowContext[Never, AgentResponse], + ctx: WorkflowContext[Never, AgentResponse], # type: ignore[valid-type] ) -> None: assert response.type == "function_approval_response" assert response.id == approval_id # type: ignore[attr-defined] approved = bool(response.approved) # type: ignore[attr-defined] - tool_name = original_request.function_call.name # type: ignore[attr-defined] + tool_name = original_request.function_call.name # type: ignore[attr-defined, union-attr] # ty: ignore[unresolved-attribute] await ctx.yield_output( AgentResponse( messages=[ @@ -543,12 +543,12 @@ async def handle_response( self, original_request: Content, response: Content, - ctx: WorkflowContext[Never, AgentResponse], + ctx: WorkflowContext[Never, AgentResponse], # type: ignore[valid-type] ) -> None: assert response.type == "function_approval_response" assert response.id == approval_id # type: ignore[attr-defined] approved = bool(response.approved) # type: ignore[attr-defined] - tool_name = original_request.function_call.name # type: ignore[attr-defined] + tool_name = original_request.function_call.name # type: ignore[attr-defined, union-attr] # ty: ignore[unresolved-attribute] await ctx.yield_output( AgentResponse( messages=[ @@ -608,7 +608,7 @@ async def handle_response( self, original_request: HandoffRequest, response: str, - ctx: WorkflowContext[Never, AgentResponse], + ctx: WorkflowContext[Never, AgentResponse], # type: ignore[valid-type] ) -> None: captured["original"] = original_request captured["response"] = response @@ -651,7 +651,7 @@ async def handle_response( assert request_payload.get("type") == "request_info" assert request_payload.get("data") == HandoffRequest(target_agent="helper", reason="overflow") - deserialized_args = WorkflowAgent.RequestInfoFunctionArgs.from_dict(function_call.arguments) + deserialized_args = WorkflowAgent.RequestInfoFunctionArgs.from_dict(function_call.arguments) # ty: ignore[invalid-argument-type] assert deserialized_args.request_id == request_id assert isinstance(deserialized_args.request_event, WorkflowEvent) assert deserialized_args.request_event.type == "request_info" @@ -752,7 +752,7 @@ async def test_workflow_as_agent_yield_output_surfaces_as_agent_response(self) - """ @executor - async def yielding_executor(messages: list[Message], ctx: WorkflowContext[Never, str]) -> None: + async def yielding_executor(messages: list[Message], ctx: WorkflowContext[Never, str]) -> None: # type: ignore[valid-type] # Extract text from input for demonstration input_text = messages[0].text if messages else "no input" await ctx.yield_output(f"processed: {input_text}") @@ -777,7 +777,7 @@ async def test_workflow_as_agent_yield_output_surfaces_in_run_stream(self) -> No """Test that ctx.yield_output() surfaces as AgentResponseUpdate when streaming.""" @executor - async def yielding_executor(messages: list[Message], ctx: WorkflowContext[Never, str]) -> None: + async def yielding_executor(messages: list[Message], ctx: WorkflowContext[Never, str]) -> None: # type: ignore[valid-type] await ctx.yield_output("first output") await ctx.yield_output("second output") @@ -797,7 +797,7 @@ async def test_workflow_as_agent_yield_output_with_content_types(self) -> None: """Test that yield_output preserves different content types (Content, Content, etc.).""" @executor - async def content_yielding_executor(messages: list[Message], ctx: WorkflowContext[Never, Content]) -> None: + async def content_yielding_executor(messages: list[Message], ctx: WorkflowContext[Never, Content]) -> None: # type: ignore[valid-type] # Yield different content types await ctx.yield_output(Content.from_text(text="text content")) await ctx.yield_output(Content.from_data(data=b"binary data", media_type="application/octet-stream")) @@ -825,7 +825,7 @@ async def test_workflow_as_agent_yield_output_with_chat_message(self) -> None: """Test that yield_output with Message preserves the message structure.""" @executor - async def chat_message_executor(messages: list[Message], ctx: WorkflowContext[Never, Message]) -> None: + async def chat_message_executor(messages: list[Message], ctx: WorkflowContext[Never, Message]) -> None: # type: ignore[valid-type] msg = Message( role="assistant", contents=[Content.from_text(text="response text")], @@ -856,7 +856,8 @@ def __str__(self) -> str: @executor async def raw_yielding_executor( - messages: list[Message], ctx: WorkflowContext[Never, Content | CustomData | str] + messages: list[Message], + ctx: WorkflowContext[Never, Content | CustomData | str], # type: ignore[valid-type] ) -> None: # Yield different types of data await ctx.yield_output("simple string") @@ -892,7 +893,7 @@ async def test_workflow_as_agent_yield_output_with_list_of_chat_messages(self) - """ @executor - async def list_yielding_executor(messages: list[Message], ctx: WorkflowContext[Never, list[Message]]) -> None: + async def list_yielding_executor(messages: list[Message], ctx: WorkflowContext[Never, list[Message]]) -> None: # type: ignore[valid-type] # Yield a list of Messages (as SequentialBuilder does) msg_list = [ Message(role="user", contents=["first message"]), @@ -1233,7 +1234,7 @@ def __init__(self, name: str, response_text: str) -> None: def create_session(self, **kwargs: Any) -> AgentSession: return AgentSession() - def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession: + def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession: # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] return AgentSession() @overload @@ -1343,7 +1344,7 @@ def __init__(self, name: str, response_text: str) -> None: def create_session(self, **kwargs: Any) -> AgentSession: return AgentSession() - def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession: + def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession: # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] return AgentSession() @overload @@ -1979,7 +1980,7 @@ def __init__( def create_session(self, **kwargs: Any) -> AgentSession: return AgentSession() - def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession: + def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession: # type: ignore[override] # pyrefly: ignore[bad-override] # ty: ignore[invalid-method-override] return AgentSession() def _next_request_id(self) -> str: diff --git a/python/packages/core/tests/workflow/test_workflow_agent_intermediate.py b/python/packages/core/tests/workflow/test_workflow_agent_intermediate.py index 4fc66135f5d..978abb9967f 100644 --- a/python/packages/core/tests/workflow/test_workflow_agent_intermediate.py +++ b/python/packages/core/tests/workflow/test_workflow_agent_intermediate.py @@ -40,7 +40,7 @@ async def emit(messages: list[Message], ctx: WorkflowContext[str, str]) -> None: await ctx.send_message("downstream") @executor - async def terminal(message: str, ctx: WorkflowContext[Never, str]) -> None: + async def terminal(message: str, ctx: WorkflowContext[Never, str]) -> None: # type: ignore[valid-type] await ctx.yield_output("FINAL") workflow = ( @@ -58,8 +58,8 @@ async def terminal(message: str, ctx: WorkflowContext[Never, str]) -> None: async for update in agent.run("hi", stream=True): updates.append(update) - text = " ".join(c.text for u in updates for c in u.contents if c.type == "text") - reasoning_text = " ".join(c.text for u in updates for c in u.contents if c.type == "text_reasoning") + text = " ".join(c.text for u in updates for c in u.contents if c.type == "text") # type: ignore[misc] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] + reasoning_text = " ".join(c.text for u in updates for c in u.contents if c.type == "text_reasoning") # type: ignore[misc] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] assert "intermediate progress" in text assert "FINAL" in text @@ -76,7 +76,7 @@ async def emit(messages: list[Message], ctx: WorkflowContext[str, str]) -> None: await ctx.send_message("forward") @executor - async def terminal(message: str, ctx: WorkflowContext[Never, str]) -> None: + async def terminal(message: str, ctx: WorkflowContext[Never, str]) -> None: # type: ignore[valid-type] await ctx.yield_output("the-answer") workflow = ( @@ -106,14 +106,14 @@ async def hidden(messages: list[Message], ctx: WorkflowContext[str, str]) -> Non await ctx.send_message("forward") @executor - async def terminal(message: str, ctx: WorkflowContext[Never, str]) -> None: + async def terminal(message: str, ctx: WorkflowContext[Never, str]) -> None: # type: ignore[valid-type] await ctx.yield_output("visible-answer") workflow = WorkflowBuilder(start_executor=hidden, output_from=[terminal]).add_edge(hidden, terminal).build() agent = workflow.as_agent("test") response = await agent.run("hi") - all_text = " ".join(c.text for m in response.messages for c in m.contents if hasattr(c, "text")) + all_text = " ".join(c.text for m in response.messages for c in m.contents if hasattr(c, "text")) # type: ignore[misc] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] assert response.text == "visible-answer" assert "hidden-progress" not in all_text @@ -129,7 +129,7 @@ async def hidden(messages: list[Message], ctx: WorkflowContext[str, str]) -> Non await ctx.send_message("forward") @executor - async def terminal(message: str, ctx: WorkflowContext[Never, str]) -> None: + async def terminal(message: str, ctx: WorkflowContext[Never, str]) -> None: # type: ignore[valid-type] await ctx.yield_output("visible-answer") workflow = WorkflowBuilder(start_executor=hidden, output_from=[terminal]).add_edge(hidden, terminal).build() @@ -139,7 +139,7 @@ async def terminal(message: str, ctx: WorkflowContext[Never, str]) -> None: async for update in agent.run("hi", stream=True): updates.append(update) - all_text = " ".join(c.text for u in updates for c in u.contents if hasattr(c, "text")) + all_text = " ".join(c.text for u in updates for c in u.contents if hasattr(c, "text")) # type: ignore[misc] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] assert "visible-answer" in all_text assert "hidden-progress" not in all_text @@ -150,7 +150,7 @@ async def test_workflow_agent_data_event_emit_factory_still_forwarded() -> None: """Even the deprecated WorkflowEvent.emit() / type='data' path is forwarded.""" @executor - async def emit_data_alias(messages: list[Message], ctx: WorkflowContext[Never, str]) -> None: + async def emit_data_alias(messages: list[Message], ctx: WorkflowContext[Never, str]) -> None: # type: ignore[valid-type] with warnings.catch_warnings(): warnings.simplefilter("ignore", DeprecationWarning) await ctx.add_event(WorkflowEvent.emit("emit_data_alias", "data-alias-payload")) @@ -163,7 +163,7 @@ async def emit_data_alias(messages: list[Message], ctx: WorkflowContext[Never, s async for update in agent.run("hi", stream=True): updates.append(update) - text = " ".join(c.text for u in updates for c in u.contents if c.type == "text") + text = " ".join(c.text for u in updates for c in u.contents if c.type == "text") # type: ignore[misc] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] assert "data-alias-payload" in text @@ -186,7 +186,7 @@ async def emit(messages: list[Message], ctx: WorkflowContext[str, AgentResponse] await ctx.send_message("forward") @executor - async def terminal(message: str, ctx: WorkflowContext[Never, str]) -> None: + async def terminal(message: str, ctx: WorkflowContext[Never, str]) -> None: # type: ignore[valid-type] await ctx.yield_output("done") workflow = ( @@ -211,7 +211,7 @@ async def test_workflow_agent_terminal_text_stays_text_not_reasoning() -> None: """A designated executor's text yield surfaces as Content.text.""" @executor - async def only(messages: list[Message], ctx: WorkflowContext[Never, str]) -> None: + async def only(messages: list[Message], ctx: WorkflowContext[Never, str]) -> None: # type: ignore[valid-type] await ctx.yield_output("the-answer") workflow = WorkflowBuilder(start_executor=only, output_from=[only]).build() @@ -228,7 +228,7 @@ async def test_workflow_agent_non_streaming_rejects_terminal_update() -> None: """A terminal event carrying AgentResponseUpdate is streaming-only and invalid in run().""" @executor - async def emit(messages: list[Message], ctx: WorkflowContext[Never, AgentResponseUpdate]) -> None: + async def emit(messages: list[Message], ctx: WorkflowContext[Never, AgentResponseUpdate]) -> None: # type: ignore[valid-type] await ctx.yield_output(AgentResponseUpdate(contents=[Content.from_text(text="partial")], role="assistant")) workflow = WorkflowBuilder(start_executor=emit, output_from=[emit]).build() @@ -248,7 +248,7 @@ async def emit(messages: list[Message], ctx: WorkflowContext[str, AgentResponseU await ctx.send_message("forward") @executor - async def terminal(message: str, ctx: WorkflowContext[Never, str]) -> None: + async def terminal(message: str, ctx: WorkflowContext[Never, str]) -> None: # type: ignore[valid-type] await ctx.yield_output("FINAL") workflow = ( @@ -278,7 +278,7 @@ async def emit(messages: list[Message], ctx: WorkflowContext[str, AgentResponseU await ctx.send_message("forward") @executor - async def terminal(message: str, ctx: WorkflowContext[Never, AgentResponseUpdate]) -> None: + async def terminal(message: str, ctx: WorkflowContext[Never, AgentResponseUpdate]) -> None: # type: ignore[valid-type] await ctx.yield_output( AgentResponseUpdate(contents=[Content.from_text(text="terminal-chunk")], role="assistant") ) @@ -298,8 +298,8 @@ async def terminal(message: str, ctx: WorkflowContext[Never, AgentResponseUpdate async for update in agent.run("hi", stream=True): updates.append(update) - text = " ".join(c.text for u in updates for c in u.contents if c.type == "text") - reasoning_text = " ".join(c.text for u in updates for c in u.contents if c.type == "text_reasoning") + text = " ".join(c.text for u in updates for c in u.contents if c.type == "text") # type: ignore[misc] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] + reasoning_text = " ".join(c.text for u in updates for c in u.contents if c.type == "text_reasoning") # type: ignore[misc] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] assert "intermediate-chunk" in text assert "terminal-chunk" in text @@ -313,7 +313,7 @@ async def test_workflow_agent_drops_orchestration_internal_events() -> None: be stringified by the generic fallback path and leak into response history.""" @executor - async def emit(messages: list[Message], ctx: WorkflowContext[Never, str]) -> None: + async def emit(messages: list[Message], ctx: WorkflowContext[Never, str]) -> None: # type: ignore[valid-type] # Construct typed orchestration-internal events directly to assert they get # dropped at the agent boundary regardless of payload. await ctx.add_event(WorkflowEvent("group_chat", data={"orchestrator": "details"})) # type: ignore[arg-type] @@ -325,7 +325,7 @@ async def emit(messages: list[Message], ctx: WorkflowContext[Never, str]) -> Non agent = workflow.as_agent("test") response = await agent.run("hi") - all_text = " ".join(c.text for m in response.messages for c in m.contents if hasattr(c, "text")) + all_text = " ".join(c.text for m in response.messages for c in m.contents if hasattr(c, "text")) # type: ignore[misc] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] assert "orchestrator" not in all_text assert "agent_b" not in all_text assert "plan" not in all_text @@ -337,7 +337,7 @@ async def test_workflow_agent_drops_orchestration_internal_events_streaming() -> """Streaming counterpart — orchestration-internal events stay inside the workflow.""" @executor - async def emit(messages: list[Message], ctx: WorkflowContext[Never, str]) -> None: + async def emit(messages: list[Message], ctx: WorkflowContext[Never, str]) -> None: # type: ignore[valid-type] await ctx.add_event(WorkflowEvent("group_chat", data={"orchestrator": "details"})) # type: ignore[arg-type] await ctx.yield_output("FINAL") @@ -348,6 +348,6 @@ async def emit(messages: list[Message], ctx: WorkflowContext[Never, str]) -> Non async for update in agent.run("hi", stream=True): updates.append(update) - all_text = " ".join(c.text for u in updates for c in u.contents if hasattr(c, "text")) + all_text = " ".join(c.text for u in updates for c in u.contents if hasattr(c, "text")) # type: ignore[misc] # pyrefly: ignore[no-matching-overload] # ty: ignore[no-matching-overload] assert "orchestrator" not in all_text assert "FINAL" in all_text diff --git a/python/packages/core/tests/workflow/test_workflow_builder.py b/python/packages/core/tests/workflow/test_workflow_builder.py index 873d6e7c73c..2f92aceb77e 100644 --- a/python/packages/core/tests/workflow/test_workflow_builder.py +++ b/python/packages/core/tests/workflow/test_workflow_builder.py @@ -61,7 +61,7 @@ def run( async def _run_impl(self, messages: AgentRunInputs | None = None) -> AgentResponse: norm: list[Message] = [] if messages: - for m in messages: # type: ignore[union-attr] + for m in messages: # type: ignore[union-attr] # ty: ignore[not-iterable] if isinstance(m, Message): norm.append(m) elif isinstance(m, str): @@ -112,7 +112,7 @@ async def mock_handler(self, messages: list[MockMessage], ctx: WorkflowContext[M def test_workflow_builder_without_start_executor_throws(): """Test creating a workflow builder without a start executor.""" with pytest.raises(TypeError): - WorkflowBuilder() # type: ignore[call-arg] + WorkflowBuilder() # type: ignore[call-arg] # ty: ignore[missing-argument] def test_workflow_builder_fluent_api(): diff --git a/python/packages/core/tests/workflow/test_workflow_context.py b/python/packages/core/tests/workflow/test_workflow_context.py index 58898924351..6bb032428ee 100644 --- a/python/packages/core/tests/workflow/test_workflow_context.py +++ b/python/packages/core/tests/workflow/test_workflow_context.py @@ -110,7 +110,7 @@ async def test_executor_emits_normal_event() -> None: class _TestEvent(WorkflowEvent): def __init__(self, data: Any = None) -> None: - super().__init__("test_event", data=data) # type: ignore[arg-type] + super().__init__("test_event", data=data) # type: ignore[arg-type] # ty: ignore[invalid-argument-type] async def test_workflow_context_type_annotations_no_parameter() -> None: @@ -193,7 +193,7 @@ async def func1(text: str, ctx: WorkflowContext[str]) -> None: await ctx.send_message("world") @executor(id="func2") - async def func2(text: str, ctx: WorkflowContext[Never, str]) -> None: + async def func2(text: str, ctx: WorkflowContext[Never, str]) -> None: # type: ignore[valid-type] await ctx.add_event(_TestEvent(data=text)) await ctx.yield_output(text) @@ -211,7 +211,7 @@ async def func1(self, text: str, ctx: WorkflowContext[str]) -> None: class _exec2(Executor): @handler - async def func2(self, text: str, ctx: WorkflowContext[Never, str]) -> None: + async def func2(self, text: str, ctx: WorkflowContext[Never, str]) -> None: # type: ignore[valid-type] await ctx.add_event(_TestEvent(data=text)) await ctx.yield_output(text) @@ -284,7 +284,7 @@ async def test_workflow_context_invalid_type_parameter_error() -> None: with pytest.raises(ValueError, match="invalid type entry"): @executor(id="bad_func") - async def bad_func(text: str, ctx: WorkflowContext[123]) -> None: # type: ignore[valid-type] + async def bad_func(text: str, ctx: WorkflowContext[123]) -> None: # type: ignore[valid-type] # ty: ignore[invalid-type-form] pass # Test class-based executor with invalid type parameter @@ -292,12 +292,12 @@ async def bad_func(text: str, ctx: WorkflowContext[123]) -> None: # type: ignor class _BadExecutor(Executor): # pyright: ignore[reportUnusedClass] @handler # pyright: ignore[reportUnknownArgumentType] - async def bad_handler(self, text: str, ctx: WorkflowContext[456]) -> None: # type: ignore[valid-type] + async def bad_handler(self, text: str, ctx: WorkflowContext[456]) -> None: # type: ignore[valid-type] # ty: ignore[invalid-type-form] pass # Test two-parameter WorkflowContext with invalid workflow output type with pytest.raises(ValueError, match="invalid type entry"): @executor(id="bad_func2") - async def bad_func2(text: str, ctx: WorkflowContext[str, 789]) -> None: # type: ignore[valid-type] + async def bad_func2(text: str, ctx: WorkflowContext[str, 789]) -> None: # type: ignore[valid-type] # ty: ignore[invalid-type-form] pass diff --git a/python/packages/core/tests/workflow/test_workflow_kwargs.py b/python/packages/core/tests/workflow/test_workflow_kwargs.py index 7bfa47a79f5..83275da8589 100644 --- a/python/packages/core/tests/workflow/test_workflow_kwargs.py +++ b/python/packages/core/tests/workflow/test_workflow_kwargs.py @@ -240,7 +240,7 @@ async def test_kwargs_stored_in_state() -> None: class _StateInspector(Executor): @handler - async def inspect(self, msgs: list[Message], ctx: WorkflowContext[Never, AgentResponse]) -> None: + async def inspect(self, msgs: list[Message], ctx: WorkflowContext[Never, AgentResponse]) -> None: # type: ignore[valid-type] nonlocal stored_kwargs stored_kwargs = ctx.get_state(WORKFLOW_RUN_KWARGS_KEY) await ctx.yield_output(AgentResponse(messages=msgs)) @@ -266,7 +266,7 @@ async def test_empty_kwargs_stored_as_empty_dict() -> None: class _StateChecker(Executor): @handler - async def check(self, msgs: list[Message], ctx: WorkflowContext[Never, AgentResponse]) -> None: + async def check(self, msgs: list[Message], ctx: WorkflowContext[Never, AgentResponse]) -> None: # type: ignore[valid-type] nonlocal stored_kwargs stored_kwargs = ctx.get_state(WORKFLOW_RUN_KWARGS_KEY) await ctx.yield_output(AgentResponse(messages=msgs)) @@ -432,8 +432,8 @@ async def test_handoff_kwargs_flow_to_agents() -> None: workflow = ( HandoffBuilder(termination_condition=lambda conv: len(conv) >= 4) - .participants([agent1, agent2]) # type: ignore[list-item] - .with_start_agent(agent1) # type: ignore[arg-type] + .participants([agent1, agent2]) # type: ignore[list-item] # ty: ignore[invalid-argument-type] + .with_start_agent(agent1) # type: ignore[arg-type] # ty: ignore[invalid-argument-type] .with_autonomous_mode() .build() ) @@ -710,7 +710,7 @@ class _StateReader(Executor): """Executor that reads kwargs from State for verification.""" @handler - async def read_kwargs(self, msgs: list[Message], ctx: WorkflowContext[Never, AgentResponse]) -> None: + async def read_kwargs(self, msgs: list[Message], ctx: WorkflowContext[Never, AgentResponse]) -> None: # type: ignore[valid-type] kwargs_from_state = ctx.get_state(WORKFLOW_RUN_KWARGS_KEY) captured_kwargs_from_state.append(kwargs_from_state or {}) await ctx.yield_output(AgentResponse(messages=msgs)) diff --git a/python/packages/core/tests/workflow/test_workflow_observability.py b/python/packages/core/tests/workflow/test_workflow_observability.py index b098fa27718..eb2772be124 100644 --- a/python/packages/core/tests/workflow/test_workflow_observability.py +++ b/python/packages/core/tests/workflow/test_workflow_observability.py @@ -143,29 +143,29 @@ async def test_span_creation_and_attributes(span_exporter: InMemorySpanExporter) assert len(spans) == 3 # Check workflow span - workflow_span = next(s for s in spans if s.name == "workflow.run") - assert workflow_span.kind == trace.SpanKind.INTERNAL - assert workflow_span.attributes is not None - assert workflow_span.attributes.get(OtelAttr.WORKFLOW_ID) == "test-workflow-123" - assert workflow_span.events is not None - event_names = [event.name for event in workflow_span.events] + workflow_span = next(s for s in spans if s.name == "workflow.run") # type: ignore[assignment, misc] + assert workflow_span.kind == trace.SpanKind.INTERNAL # type: ignore[attr-defined] + assert workflow_span.attributes is not None # type: ignore[attr-defined] + assert workflow_span.attributes.get(OtelAttr.WORKFLOW_ID) == "test-workflow-123" # type: ignore[attr-defined] + assert workflow_span.events is not None # type: ignore[attr-defined] + event_names = [event.name for event in workflow_span.events] # type: ignore[attr-defined] assert "workflow.started" in event_names # Check processing span - span name uses format "executor.process {executor_id}" - processing_span = next(s for s in spans if s.name == "executor.process executor-456") - assert processing_span.kind == trace.SpanKind.INTERNAL - assert processing_span.attributes is not None - assert processing_span.attributes.get("executor.id") == "executor-456" - assert processing_span.attributes.get("executor.type") == "TestExecutor" - assert processing_span.attributes.get("message.type") == str(MessageType.STANDARD) - assert processing_span.attributes.get("message.payload_type") == "TestMessage" + processing_span = next(s for s in spans if s.name == "executor.process executor-456") # type: ignore[assignment, misc] + assert processing_span.kind == trace.SpanKind.INTERNAL # type: ignore[attr-defined] + assert processing_span.attributes is not None # type: ignore[attr-defined] + assert processing_span.attributes.get("executor.id") == "executor-456" # type: ignore[attr-defined] + assert processing_span.attributes.get("executor.type") == "TestExecutor" # type: ignore[attr-defined] + assert processing_span.attributes.get("message.type") == str(MessageType.STANDARD) # type: ignore[attr-defined] + assert processing_span.attributes.get("message.payload_type") == "TestMessage" # type: ignore[attr-defined] # Check sending span - sending_span = next(s for s in spans if s.name == "message.send") - assert sending_span.kind == trace.SpanKind.PRODUCER - assert sending_span.attributes is not None - assert sending_span.attributes.get("message.type") == "ResponseMessage" - assert sending_span.attributes.get("message.destination_executor_id") == "target-789" + sending_span = next(s for s in spans if s.name == "message.send") # type: ignore[assignment, misc] + assert sending_span.kind == trace.SpanKind.PRODUCER # type: ignore[attr-defined] + assert sending_span.attributes is not None # type: ignore[attr-defined] + assert sending_span.attributes.get("message.type") == "ResponseMessage" # type: ignore[attr-defined] + assert sending_span.attributes.get("message.destination_executor_id") == "target-789" # type: ignore[attr-defined] async def test_trace_context_handling(span_exporter: InMemorySpanExporter) -> None: diff --git a/python/packages/core/tests/workflow/test_workflow_states.py b/python/packages/core/tests/workflow/test_workflow_states.py index bf2e277d107..4bf2507a1d5 100644 --- a/python/packages/core/tests/workflow/test_workflow_states.py +++ b/python/packages/core/tests/workflow/test_workflow_states.py @@ -153,7 +153,7 @@ class Completer(Executor): """Executor that completes immediately with provided data for testing.""" @handler - async def run(self, msg: str, ctx: WorkflowContext[Never, str]) -> None: # pragma: no cover + async def run(self, msg: str, ctx: WorkflowContext[Never, str]) -> None: # pragma: no cover # type: ignore[valid-type] await ctx.yield_output(msg) diff --git a/python/packages/declarative/agent_framework_declarative/_loader.py b/python/packages/declarative/agent_framework_declarative/_loader.py index 4507d0112d8..bbd9f4ec9be 100644 --- a/python/packages/declarative/agent_framework_declarative/_loader.py +++ b/python/packages/declarative/agent_framework_declarative/_loader.py @@ -15,7 +15,7 @@ from agent_framework import ( FunctionTool as AFFunctionTool, ) -from agent_framework._feature_stage import ( # type: ignore[reportPrivateUsage] +from agent_framework._feature_stage import ( ExperimentalFeature, experimental, ) @@ -42,9 +42,9 @@ ) if sys.version_info >= (3, 11): - from typing import TypedDict # type: ignore # pragma: no cover + from typing import TypedDict # pragma: no cover else: - from typing_extensions import TypedDict # type: ignore # pragma: no cover + from typing_extensions import TypedDict # pragma: no cover @experimental(feature_id=ExperimentalFeature.DECLARATIVE_AGENTS) @@ -708,7 +708,7 @@ def _get_client(self, prompt_agent: PromptAgent) -> SupportsChatGetResponse: module = __import__(module_name, fromlist=[class_name]) agent_class = getattr(module, class_name) setup_dict[mapping["model_field"]] = prompt_agent.model.id - return agent_class(**setup_dict) # type: ignore[no-any-return] + return agent_class(**setup_dict) def _parse_chat_options(self, model: Model | None) -> dict[str, Any]: """Parse ModelOptions into chat options dictionary.""" @@ -753,7 +753,7 @@ def _parse_tool(self, tool_resource: Tool) -> AFFunctionTool | dict[str, Any]: for binding in tool_resource.bindings: if binding.name and (func := self.bindings.get(binding.name)): break - return AFFunctionTool( # type: ignore + return AFFunctionTool( name=tool_resource.name, # type: ignore description=tool_resource.description, # type: ignore input_model=tool_resource.parameters.to_json_schema() if tool_resource.parameters else None, diff --git a/python/packages/declarative/agent_framework_declarative/_models.py b/python/packages/declarative/agent_framework_declarative/_models.py index 6aa359d7f79..9a23560f99f 100644 --- a/python/packages/declarative/agent_framework_declarative/_models.py +++ b/python/packages/declarative/agent_framework_declarative/_models.py @@ -121,7 +121,7 @@ def from_dict( # Only dispatch if we're being called on the base Property class if cls is not Property: # We're being called on a subclass, use the normal from_dict - return SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) # type: ignore[attr-defined, no-any-return] + return SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) # The YAML spec uses 'type' for the data type, but Property stores it as 'kind' if "type" in value: @@ -135,7 +135,7 @@ def from_dict( if kind == "object": return ObjectProperty.from_dict(value, dependencies=dependencies) # Default to Property for kind="property" or empty - return SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) # type: ignore[attr-defined, no-any-return] + return SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) class ArrayProperty(Property): @@ -235,7 +235,7 @@ def from_dict( # Filter out 'kind', 'type', 'name', and 'description' fields that may appear in YAML # but aren't PropertySchema params kwargs = {k: v for k, v in value.items() if k not in ("type", "kind", "name", "description")} - return SerializationMixin.from_dict.__func__(cls, kwargs, dependencies=dependencies) # type: ignore[attr-defined, no-any-return] + return SerializationMixin.from_dict.__func__(cls, kwargs, dependencies=dependencies) def to_json_schema(self) -> dict[str, Any]: """Get a schema out of this PropertySchema to create pydantic models.""" @@ -287,26 +287,18 @@ def from_dict( # Only dispatch if we're being called on the base Connection class if cls is not Connection: # We're being called on a subclass, use the normal from_dict - return SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) # type: ignore[attr-defined, no-any-return] + return SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) kind = value.get("kind", "").lower() if kind == "reference": - return SerializationMixin.from_dict.__func__( # type: ignore[attr-defined, no-any-return] - ReferenceConnection, value, dependencies=dependencies - ) + return SerializationMixin.from_dict.__func__(ReferenceConnection, value, dependencies=dependencies) if kind == "remote": - return SerializationMixin.from_dict.__func__( # type: ignore[attr-defined, no-any-return] - RemoteConnection, value, dependencies=dependencies - ) + return SerializationMixin.from_dict.__func__(RemoteConnection, value, dependencies=dependencies) if kind in ("key", "apikey"): - return SerializationMixin.from_dict.__func__( # type: ignore[attr-defined, no-any-return] - ApiKeyConnection, value, dependencies=dependencies - ) + return SerializationMixin.from_dict.__func__(ApiKeyConnection, value, dependencies=dependencies) if kind == "anonymous": - return SerializationMixin.from_dict.__func__( # type: ignore[attr-defined, no-any-return] - AnonymousConnection, value, dependencies=dependencies - ) - return SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) # type: ignore[attr-defined, no-any-return] + return SerializationMixin.from_dict.__func__(AnonymousConnection, value, dependencies=dependencies) + return SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) class ReferenceConnection(Connection): @@ -525,13 +517,13 @@ def from_dict( # Only dispatch if we're being called on the base AgentDefinition class if cls is not AgentDefinition: # We're being called on a subclass, use the normal from_dict - return SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) # type: ignore[attr-defined, no-any-return] + return SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) kind = value.get("kind", "") if kind == "Prompt" or kind == "Agent": return PromptAgent.from_dict(value, dependencies=dependencies) # Default to AgentDefinition - return SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) # type: ignore[attr-defined, no-any-return] + return SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) ToolT = TypeVar("ToolT", bound="Tool") @@ -571,39 +563,25 @@ def from_dict( # Only dispatch if we're being called on the base Tool class if cls is not Tool: # We're being called on a subclass, use the normal from_dict - return SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) # type: ignore[attr-defined, no-any-return] + return SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) kind = value.get("kind", "") if kind == "function": - return SerializationMixin.from_dict.__func__( # type: ignore[attr-defined, no-any-return] - FunctionTool, value, dependencies=dependencies - ) + return SerializationMixin.from_dict.__func__(FunctionTool, value, dependencies=dependencies) if kind == "custom": - return SerializationMixin.from_dict.__func__( # type: ignore[attr-defined, no-any-return] - CustomTool, value, dependencies=dependencies - ) + return SerializationMixin.from_dict.__func__(CustomTool, value, dependencies=dependencies) if kind == "web_search": - return SerializationMixin.from_dict.__func__( # type: ignore[attr-defined, no-any-return] - WebSearchTool, value, dependencies=dependencies - ) + return SerializationMixin.from_dict.__func__(WebSearchTool, value, dependencies=dependencies) if kind == "file_search": - return SerializationMixin.from_dict.__func__( # type: ignore[attr-defined, no-any-return] - FileSearchTool, value, dependencies=dependencies - ) + return SerializationMixin.from_dict.__func__(FileSearchTool, value, dependencies=dependencies) if kind == "mcp": - return SerializationMixin.from_dict.__func__( # type: ignore[attr-defined, no-any-return] - McpTool, value, dependencies=dependencies - ) + return SerializationMixin.from_dict.__func__(McpTool, value, dependencies=dependencies) if kind == "openapi": - return SerializationMixin.from_dict.__func__( # type: ignore[attr-defined, no-any-return] - OpenApiTool, value, dependencies=dependencies - ) + return SerializationMixin.from_dict.__func__(OpenApiTool, value, dependencies=dependencies) if kind == "code_interpreter": - return SerializationMixin.from_dict.__func__( # type: ignore[attr-defined, no-any-return] - CodeInterpreterTool, value, dependencies=dependencies - ) + return SerializationMixin.from_dict.__func__(CodeInterpreterTool, value, dependencies=dependencies) # Default to base Tool class - return SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) # type: ignore[attr-defined, no-any-return] + return SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) class FunctionTool(Tool): @@ -901,18 +879,14 @@ def from_dict( # Only dispatch if we're being called on the base Resource class if cls is not Resource: # We're being called on a subclass, use the normal from_dict - return SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) # type: ignore[attr-defined, no-any-return] + return SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) kind = value.get("kind", "") if kind == "model": - return SerializationMixin.from_dict.__func__( # type: ignore[attr-defined, no-any-return] - ModelResource, value, dependencies=dependencies - ) + return SerializationMixin.from_dict.__func__(ModelResource, value, dependencies=dependencies) if kind == "tool": - return SerializationMixin.from_dict.__func__( # type: ignore[attr-defined, no-any-return] - ToolResource, value, dependencies=dependencies - ) - return SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) # type: ignore[attr-defined, no-any-return] + return SerializationMixin.from_dict.__func__(ToolResource, value, dependencies=dependencies) + return SerializationMixin.from_dict.__func__(cls, value, dependencies=dependencies) class ModelResource(Resource): diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py b/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py index e6fc0a820d5..5fbc68bd04e 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py @@ -50,12 +50,12 @@ except (ImportError, RuntimeError): # ImportError: powerfx package not installed # RuntimeError: .NET runtime not available or misconfigured - Engine = None # type: ignore[assignment, misc] + Engine = None if sys.version_info >= (3, 11): - from typing import TypedDict # type: ignore # pragma: no cover + from typing import TypedDict # pragma: no cover else: - from typing_extensions import TypedDict # type: ignore # pragma: no cover + from typing_extensions import TypedDict # pragma: no cover logger = logging.getLogger(__name__) @@ -144,11 +144,11 @@ def visit(value: Any) -> None: names.update(_ENV_REFERENCE_RE.findall(value)) return if isinstance(value, Mapping): - for inner in cast(Mapping[Any, Any], value).values(): # type: ignore[redundant-cast] + for inner in cast(Mapping[Any, Any], value).values(): visit(inner) return if isinstance(value, list): - for item in cast(list[Any], value): # type: ignore[redundant-cast] + for item in cast(list[Any], value): visit(item) visit(node) @@ -240,7 +240,7 @@ def _make_powerfx_safe(value: Any) -> Any: return {str(k): _make_powerfx_safe(v) for k, v in value_dict.items()} if isinstance(value, list): - value_list = cast(list[Any], value) # type: ignore[redundant-cast] + value_list = cast(list[Any], value) return [_make_powerfx_safe(item) for item in value_list] # Try to convert objects with __dict__ or dataclass-style attributes @@ -537,10 +537,10 @@ def eval(self, expression: str) -> Any: original_culture = cast(Any, CultureInfo.CurrentCulture) # pyright: ignore[reportUnknownMemberType] try: - CultureInfo.CurrentCulture = CultureInfo(_POWERFX_EVAL_LOCALE) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + CultureInfo.CurrentCulture = CultureInfo(_POWERFX_EVAL_LOCALE) return engine.eval(formula, symbols=symbols, locale=_POWERFX_EVAL_LOCALE) finally: - CultureInfo.CurrentCulture = original_culture # pyright: ignore[reportUnknownMemberType] + CultureInfo.CurrentCulture = original_culture except ValueError as e: error_msg = str(e) # Handle undefined variable errors gracefully by returning None @@ -722,7 +722,7 @@ def _eval_and_replace_message_text(self, inner_expr: str) -> str: """ messages: Any = self.eval(f"={inner_expr}") if isinstance(messages, list) and messages: - message_list = cast(list[Any], messages) # type: ignore[redundant-cast] + message_list = cast(list[Any], messages) last_msg: Any = message_list[-1] if isinstance(last_msg, dict): last_msg_dict = cast(dict[str, Any], last_msg) @@ -733,7 +733,7 @@ def _eval_and_replace_message_text(self, inner_expr: str) -> str: # Message.text concatenates text from all TextContent items contents_obj = last_msg_dict.get("contents", []) if isinstance(contents_obj, list): - contents = cast(list[Any], contents_obj) # type: ignore[redundant-cast] + contents = cast(list[Any], contents_obj) text_parts: list[str] = [] for content in contents: if isinstance(content, dict): @@ -1154,7 +1154,7 @@ async def _ensure_state_initialized( state.set("System.LastMessageText", trigger) elif not isinstance( trigger, - (ActionTrigger, ActionComplete, ConditionResult, LoopIterationResult, LoopControl), # pyright: ignore[reportUnknownArgumentType] + (ActionTrigger, ActionComplete, ConditionResult, LoopIterationResult, LoopControl), ): # Any other type - convert to string like .NET's DefaultTransform input_str = str(cast(Any, trigger)) diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_builder.py b/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_builder.py index 21dda14116b..9a6814af60d 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_builder.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_builder.py @@ -433,7 +433,7 @@ def _create_executors_for_actions( # Store the chain for later reference if first_executor is not None: - first_executor._chain_executors = executors_in_chain # type: ignore[attr-defined] + first_executor._chain_executors = executors_in_chain return first_executor diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_basic.py b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_basic.py index a5ac585208e..d876c2c45ef 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_basic.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_basic.py @@ -543,12 +543,12 @@ def _convert_to_type(self, value: Any, target_type: str) -> Any: if value is None: return [] if isinstance(value, list): - return cast(list[Any], value) # type: ignore[redundant-cast] + return cast(list[Any], value) if isinstance(value, str): try: parsed = json.loads(value) if isinstance(parsed, list): - return cast(list[Any], parsed) # type: ignore[redundant-cast] + return cast(list[Any], parsed) return [parsed] except json.JSONDecodeError: return [value] diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_mcp.py b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_mcp.py index 73b66341ea3..1361ded0822 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_mcp.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_mcp.py @@ -173,7 +173,7 @@ def _format_outputs_for_send(parsed_results: list[Any]) -> str: if not parsed_results: return "" if all(isinstance(item, str) for item in parsed_results): - return "\n".join(parsed_results) # type: ignore[arg-type] + return "\n".join(parsed_results) if len(parsed_results) == 1: return json.dumps(parsed_results[0], ensure_ascii=False) return json.dumps(parsed_results, ensure_ascii=False) diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_http_handler.py b/python/packages/declarative/agent_framework_declarative/_workflows/_http_handler.py index 90ff5b87b4c..6c1790509d1 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_http_handler.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_http_handler.py @@ -181,7 +181,7 @@ async def send(self, info: HttpRequestInfo) -> HttpRequestResult: params=params, headers=headers or None, content=content, - timeout=timeout, # type: ignore[arg-type] + timeout=timeout, ) # Preserve multi-value headers (e.g. multiple Set-Cookie) as list[str]. diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_powerfx_functions.py b/python/packages/declarative/agent_framework_declarative/_workflows/_powerfx_functions.py index f61120a469b..eba3be4ea75 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_powerfx_functions.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_powerfx_functions.py @@ -52,7 +52,7 @@ def message_text(messages: Any) -> str: if isinstance(messages, list): # List of messages - concatenate all text texts: list[str] = [] - message_list = cast(list[Any], messages) # type: ignore[redundant-cast] + message_list = cast(list[Any], messages) for msg in message_list: if isinstance(msg, str): texts.append(msg) diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_state.py b/python/packages/declarative/agent_framework_declarative/_workflows/_state.py index 76530f50dd1..9ba9dd964b2 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_state.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_state.py @@ -284,7 +284,7 @@ def append(self, path: str, value: Any) -> None: if existing is None: self.set(path, [value]) elif isinstance(existing, list): - existing_list = cast(list[Any], existing) # type: ignore[redundant-cast] + existing_list = cast(list[Any], existing) existing_list.append(value) self.set(path, existing_list) else: diff --git a/python/packages/declarative/tests/test_declarative_loader.py b/python/packages/declarative/tests/test_declarative_loader.py index db78d33315f..823342f68ea 100644 --- a/python/packages/declarative/tests/test_declarative_loader.py +++ b/python/packages/declarative/tests/test_declarative_loader.py @@ -9,6 +9,7 @@ import pytest import yaml +from agent_framework_declarative._loader import ProviderTypeMapping from agent_framework_declarative._models import ( AgentDefinition, AgentManifest, @@ -25,6 +26,7 @@ McpServerToolNeverRequireApprovalMode, McpServerToolSpecifyApprovalMode, McpTool, + Model, ModelResource, ObjectProperty, OpenApiTool, @@ -706,6 +708,9 @@ def test_agent_factory_safe_mode_with_api_key_connection(self, monkeypatch): token = _safe_mode_context.set(True) # Ensure we're in safe mode try: result = agent_schema_dispatch(yaml_module.safe_load(yaml_content)) + assert isinstance(result, PromptAgent) + assert isinstance(result.model, Model) + assert isinstance(result.model.connection, ApiKeyConnection) # The API key should NOT be resolved (still has the PowerFx expression) assert result.model.connection.apiKey == "=Env.MY_API_KEY" @@ -741,6 +746,9 @@ def test_agent_factory_safe_mode_false_resolves_api_key(self, monkeypatch): token = _safe_mode_context.set(False) # Disable safe mode try: result = agent_schema_dispatch(yaml_module.safe_load(yaml_content)) + assert isinstance(result, PromptAgent) + assert isinstance(result.model, Model) + assert isinstance(result.model.connection, ApiKeyConnection) # The API key should be resolved from environment assert result.model.connection.apiKey == "secret-key-123" @@ -1127,11 +1135,13 @@ def test_additional_mappings_override_default(self): from agent_framework_declarative import AgentFactory # Define a custom provider mapping - custom_mappings = { + custom_mappings: dict[str, ProviderTypeMapping] = { "CustomProvider.Chat": { "package": "agent_framework.openai", "name": "OpenAIChatClient", "model_field": "model", + "endpoint_field": None, + "api_key_field": None, }, } @@ -1424,7 +1434,13 @@ async def test_response_format_in_default_options(self): prompt_agent = self._make_mock_prompt_agent(with_output_schema=True) mock_provider_class, mock_provider_instance = self._make_mock_provider() - mapping = {"package": "some_module", "name": "SomeProvider"} + mapping: ProviderTypeMapping = { + "package": "some_module", + "name": "SomeProvider", + "model_field": "model", + "endpoint_field": None, + "api_key_field": None, + } factory = AgentFactory() original_import = builtins.__import__ @@ -1458,7 +1474,13 @@ async def test_no_default_options_without_output_schema(self): prompt_agent = self._make_mock_prompt_agent(with_output_schema=False) mock_provider_class, mock_provider_instance = self._make_mock_provider() - mapping = {"package": "some_module", "name": "SomeProvider"} + mapping: ProviderTypeMapping = { + "package": "some_module", + "name": "SomeProvider", + "model_field": "model", + "endpoint_field": None, + "api_key_field": None, + } factory = AgentFactory() original_import = builtins.__import__ diff --git a/python/packages/declarative/tests/test_declarative_models.py b/python/packages/declarative/tests/test_declarative_models.py index 8768b5b01ca..85fa58a32f5 100644 --- a/python/packages/declarative/tests/test_declarative_models.py +++ b/python/packages/declarative/tests/test_declarative_models.py @@ -3,6 +3,7 @@ """Tests for MAML model classes.""" import sys +from typing import Any, cast import pytest @@ -156,6 +157,7 @@ def test_array_property_creation(self): array_prop = ArrayProperty(name="test_array", kind="array", items=items, required=True) assert array_prop.name == "test_array" assert array_prop.kind == "array" + assert array_prop.items is not None assert array_prop.items.name == "item" assert array_prop.required is True @@ -167,6 +169,7 @@ def test_array_property_from_dict(self): "required": True, } array_prop = ArrayProperty.from_dict(data) + assert isinstance(array_prop, ArrayProperty) assert array_prop.name == "test_array" assert array_prop.kind == "array" assert isinstance(array_prop.items, Property) @@ -198,6 +201,7 @@ def test_object_property_from_dict(self): "required": True, } obj_prop = ObjectProperty.from_dict(data) + assert isinstance(obj_prop, ObjectProperty) assert obj_prop.name == "test_object" assert obj_prop.kind == "object" assert len(obj_prop.properties) == 2 @@ -215,6 +219,7 @@ def test_object_property_with_dict_properties(self): }, } obj_prop = ObjectProperty.from_dict(data) + assert isinstance(obj_prop, ObjectProperty) assert obj_prop.name == "person" assert obj_prop.kind == "object" assert len(obj_prop.properties) == 3 @@ -224,7 +229,7 @@ def test_object_property_with_dict_properties(self): assert prop_names == {"name", "email", "age"} # Check specific property - name_prop = next(p for p in obj_prop.properties if p.name == "name") + name_prop = next(p for p in obj_prop.properties if p.name == "name") # pyrefly: ignore[not-iterable] assert name_prop.kind == "string" assert name_prop.required is True @@ -302,7 +307,7 @@ class TestConnection: """Tests for Connection base class.""" def test_connection_creation(self): - conn = Connection(kind="base") + conn = Connection(kind=cast("Any", "base")) assert conn.kind == "base" def test_connection_from_dict(self): @@ -412,6 +417,7 @@ def test_model_with_connection(self): } model = Model.from_dict(data) assert model.id == "gpt-4" + assert model.connection is not None assert model.connection.kind == "reference" @@ -720,6 +726,8 @@ def test_mcp_tool_approval_mode_equivalence(self): tool_full = McpTool.from_dict(data_full) # Both should produce the same result + assert tool_simplified.approvalMode is not None + assert tool_full.approvalMode is not None assert tool_simplified.approvalMode.kind == tool_full.approvalMode.kind assert tool_simplified.approvalMode.kind == "never" @@ -800,9 +808,9 @@ def test_prompt_agent_from_dict(self): "model": {"id": "gpt-4"}, } agent = PromptAgent.from_dict(data) + assert isinstance(agent, PromptAgent) assert agent.name == "prompt-agent" assert isinstance(agent.model, Model) - assert isinstance(agent.model, Model) def test_prompt_agent_with_tools(self): data = { @@ -814,6 +822,8 @@ def test_prompt_agent_with_tools(self): ], } agent = PromptAgent.from_dict(data) + assert isinstance(agent, PromptAgent) + assert agent.tools is not None assert len(agent.tools) == 2 # Tools are converted via Tool.from_dict, type depends on 'kind' assert agent.tools[0].kind == "web_search" @@ -850,6 +860,7 @@ def test_model_resource_from_dict(self): "id": "gpt-4", } resource = ModelResource.from_dict(data) + assert isinstance(resource, ModelResource) assert resource.name == "my-model" assert resource.kind == "model" assert resource.id == "gpt-4" @@ -871,6 +882,7 @@ def test_tool_resource_from_dict(self): "id": "search-tool", } resource = ToolResource.from_dict(data) + assert isinstance(resource, ToolResource) assert resource.name == "my-tool" assert resource.kind == "tool" assert resource.id == "search-tool" @@ -927,7 +939,7 @@ def test_no_evaluation_without_equals_prefix(self): def test_none_value_returns_none(self): """Test that None values are returned as None.""" - assert _try_powerfx_eval(None) is None + assert _try_powerfx_eval(cast("str", None)) is None def test_empty_string_returns_empty(self): """Test that empty strings are returned as empty.""" diff --git a/python/packages/declarative/tests/test_default_http_request_handler.py b/python/packages/declarative/tests/test_default_http_request_handler.py index ecdce3d7ff0..93cfc9b6745 100644 --- a/python/packages/declarative/tests/test_default_http_request_handler.py +++ b/python/packages/declarative/tests/test_default_http_request_handler.py @@ -239,7 +239,7 @@ def counting_ctor(*args, **kwargs): # type: ignore[no-untyped-def] import agent_framework_declarative._workflows._http_handler as hh - hh.httpx.AsyncClient = counting_ctor # type: ignore[assignment] + hh.httpx.AsyncClient = counting_ctor # type: ignore[assignment] # ty: ignore[invalid-assignment] try: handler = DefaultHttpRequestHandler() try: diff --git a/python/packages/declarative/tests/test_default_mcp_tool_handler.py b/python/packages/declarative/tests/test_default_mcp_tool_handler.py index 58ce3144ee9..1523f2261f0 100644 --- a/python/packages/declarative/tests/test_default_mcp_tool_handler.py +++ b/python/packages/declarative/tests/test_default_mcp_tool_handler.py @@ -491,7 +491,9 @@ def boom(**_a: Any) -> Any: result = await handler.invoke_tool(inv) assert result.is_error is True assert result.error_message == "server says no" - assert result.outputs[0].text.startswith("Error:") # type: ignore[reportAttributeAccessIssue] + text = result.outputs[0].text # type: ignore[reportAttributeAccessIssue] + assert text is not None + assert text.startswith("Error:") @pytest.mark.asyncio async def test_httpx_error_returns_error_result(self) -> None: @@ -536,7 +538,9 @@ async def test_connect_failure_returns_error_result(self) -> None: ): result = await handler.invoke_tool(_invocation()) assert result.is_error is True - assert result.outputs[0].text.startswith("Error:") # type: ignore[reportAttributeAccessIssue] + text = result.outputs[0].text # type: ignore[reportAttributeAccessIssue] + assert text is not None + assert text.startswith("Error:") # Failed connect must clear in-flight + cache entries. assert handler._inflight == {} assert len(handler._cache) == 0 @@ -609,7 +613,7 @@ async def test_list_tools_returns_json_catalog(self) -> None: with _patch_tool(): # Prime the cache so the FakeTool session exists. await handler.invoke_tool(_invocation()) - FakeTool.instances[0].session.list_tools_pages = [ # type: ignore[union-attr] + FakeTool.instances[0].session.list_tools_pages = [ # type: ignore[union-attr] # ty: ignore[invalid-assignment] FakeListToolsResult( tools=[ FakeMcpTool( @@ -625,7 +629,9 @@ async def test_list_tools_returns_json_catalog(self) -> None: result = await handler.invoke_tool(_invocation(tool_name=DefaultMCPToolHandler.LIST_TOOLS_TOOL_NAME)) assert result.is_error is False assert len(result.outputs) == 1 - payload = json.loads(result.outputs[0].text) # type: ignore[reportAttributeAccessIssue] + text = result.outputs[0].text # type: ignore[reportAttributeAccessIssue] + assert text is not None + payload = json.loads(text) assert payload == { "tools": [ { @@ -649,11 +655,12 @@ async def test_list_tools_property_order_is_stable(self) -> None: handler = DefaultMCPToolHandler() with _patch_tool(): await handler.invoke_tool(_invocation()) - FakeTool.instances[0].session.list_tools_pages = [ # type: ignore[union-attr] + FakeTool.instances[0].session.list_tools_pages = [ # type: ignore[union-attr] # ty: ignore[invalid-assignment] FakeListToolsResult(tools=[FakeMcpTool(name="t1", description="d")]), ] result = await handler.invoke_tool(_invocation(tool_name=DefaultMCPToolHandler.LIST_TOOLS_TOOL_NAME)) text = result.outputs[0].text # type: ignore[reportAttributeAccessIssue] + assert text is not None name_idx = text.find('"name"') desc_idx = text.find('"description"') input_idx = text.find('"inputSchema"') @@ -666,11 +673,12 @@ async def test_list_tools_indented_output(self) -> None: handler = DefaultMCPToolHandler() with _patch_tool(): await handler.invoke_tool(_invocation()) - FakeTool.instances[0].session.list_tools_pages = [ # type: ignore[union-attr] + FakeTool.instances[0].session.list_tools_pages = [ # type: ignore[union-attr] # ty: ignore[invalid-assignment] FakeListToolsResult(tools=[FakeMcpTool(name="t1")]), ] result = await handler.invoke_tool(_invocation(tool_name=DefaultMCPToolHandler.LIST_TOOLS_TOOL_NAME)) text = result.outputs[0].text # type: ignore[reportAttributeAccessIssue] + assert text is not None # Indented output contains newlines and a 2-space indented key. assert "\n " in text @@ -704,13 +712,15 @@ async def test_list_tools_paginates(self) -> None: handler = DefaultMCPToolHandler() with _patch_tool(): await handler.invoke_tool(_invocation()) - FakeTool.instances[0].session.list_tools_pages = [ # type: ignore[union-attr] + FakeTool.instances[0].session.list_tools_pages = [ # type: ignore[union-attr] # ty: ignore[invalid-assignment] FakeListToolsResult(tools=[FakeMcpTool(name="a")], next_cursor="cursor1"), FakeListToolsResult(tools=[FakeMcpTool(name="b")], next_cursor="cursor2"), FakeListToolsResult(tools=[FakeMcpTool(name="c")], next_cursor=None), ] result = await handler.invoke_tool(_invocation(tool_name=DefaultMCPToolHandler.LIST_TOOLS_TOOL_NAME)) - payload = json.loads(result.outputs[0].text) # type: ignore[reportAttributeAccessIssue] + text = result.outputs[0].text # type: ignore[reportAttributeAccessIssue] + assert text is not None + payload = json.loads(text) assert [t["name"] for t in payload["tools"]] == ["a", "b", "c"] session = FakeTool.instances[0].session assert session is not None @@ -736,7 +746,7 @@ async def test_list_tools_propagates_session_errors_as_error_result(self) -> Non handler = DefaultMCPToolHandler() with _patch_tool(): await handler.invoke_tool(_invocation()) - FakeTool.instances[0].session.list_tools_error = httpx.ReadTimeout("read timed out") # type: ignore[union-attr] + FakeTool.instances[0].session.list_tools_error = httpx.ReadTimeout("read timed out") # type: ignore[union-attr] # ty: ignore[invalid-assignment] result = await handler.invoke_tool(_invocation(tool_name=DefaultMCPToolHandler.LIST_TOOLS_TOOL_NAME)) assert result.is_error is True assert "ReadTimeout" in (result.error_message or "") @@ -766,7 +776,7 @@ def fail(**_a: Any) -> Any: with _patch_tool(): await handler.invoke_tool(_invocation()) FakeTool.instances[0].call_handler = fail - FakeTool.instances[0].session.list_tools_pages = [ # type: ignore[union-attr] + FakeTool.instances[0].session.list_tools_pages = [ # type: ignore[union-attr] # ty: ignore[invalid-assignment] FakeListToolsResult(tools=[]), ] result = await handler.invoke_tool(_invocation(tool_name=DefaultMCPToolHandler.LIST_TOOLS_TOOL_NAME)) diff --git a/python/packages/declarative/tests/test_function_tool_executor.py b/python/packages/declarative/tests/test_function_tool_executor.py index f11b3568658..b9458f8658d 100644 --- a/python/packages/declarative/tests/test_function_tool_executor.py +++ b/python/packages/declarative/tests/test_function_tool_executor.py @@ -1087,7 +1087,7 @@ async def test_approval_response_approved(self, mock_state, mock_context): """When approval response is approved, the tool should be invoked.""" self._init_state(mock_state) - call_log = [] + call_log: list[int] = [] def my_tool(x: int) -> int: call_log.append(x) diff --git a/python/packages/declarative/tests/test_graph_coverage.py b/python/packages/declarative/tests/test_graph_coverage.py index f114c8f0ae9..dbd54b7103b 100644 --- a/python/packages/declarative/tests/test_graph_coverage.py +++ b/python/packages/declarative/tests/test_graph_coverage.py @@ -5,7 +5,7 @@ # pyright: reportGeneralTypeIssues=false from dataclasses import dataclass -from typing import Any +from typing import Any, cast from unittest.mock import AsyncMock, MagicMock import pytest @@ -141,7 +141,7 @@ async def test_get_custom_namespace(self, mock_state): # Set via direct state data manipulation to create custom namespace state_data = state.get_state_data() - state_data["Custom"] = {"myns": {"value": 42}} + cast(dict[str, Any], state_data)["Custom"] = {"myns": {"value": 42}} state.set_state_data(state_data) result = state.get("myns.value") @@ -253,10 +253,10 @@ async def test_eval_non_string_returns_as_is(self, mock_state): state.initialize() # Cast to Any to test the runtime behavior with non-string inputs - result = state.eval(42) # type: ignore[arg-type] + result = state.eval(42) # type: ignore[arg-type] # ty: ignore[invalid-argument-type] assert result == 42 - result = state.eval([1, 2, 3]) # type: ignore[arg-type] + result = state.eval([1, 2, 3]) # type: ignore[arg-type] # ty: ignore[invalid-argument-type] assert result == [1, 2, 3] @_requires_powerfx @@ -1127,7 +1127,7 @@ async def test_foreach_next_continues_iteration(self, mock_context, mock_state): # Set up loop state as ForeachInitExecutor would state_data = state.get_state_data() - state_data[LOOP_STATE_KEY] = { + cast(dict[str, Any], state_data)[LOOP_STATE_KEY] = { "foreach_init": { "items": ["a", "b", "c"], "index": 0, @@ -1238,7 +1238,7 @@ async def test_foreach_next_loop_complete(self, mock_context, mock_state): # Set up loop state at last item state_data = state.get_state_data() - state_data[LOOP_STATE_KEY] = { + cast(dict[str, Any], state_data)[LOOP_STATE_KEY] = { "loop_id": { "items": ["a", "b"], "index": 1, # Already at last item @@ -1272,7 +1272,7 @@ async def test_foreach_next_handle_break_control(self, mock_context, mock_state) # Set up loop state state_data = state.get_state_data() - state_data[LOOP_STATE_KEY] = { + cast(dict[str, Any], state_data)[LOOP_STATE_KEY] = { "loop_id": { "items": ["a", "b", "c"], "index": 0, @@ -1306,7 +1306,7 @@ async def test_foreach_next_handle_continue_control(self, mock_context, mock_sta # Set up loop state state_data = state.get_state_data() - state_data[LOOP_STATE_KEY] = { + cast(dict[str, Any], state_data)[LOOP_STATE_KEY] = { "loop_id": { "items": ["a", "b", "c"], "index": 0, @@ -2340,7 +2340,7 @@ def test_get_branch_exit_with_chain(self): exec3 = SendActivityExecutor({"kind": "SendActivity", "activity": {"text": "3"}}, id="e3") # Simulate a chain by dynamically setting attribute - exec1._chain_executors = [exec1, exec2, exec3] # type: ignore[attr-defined] + exec1._chain_executors = [exec1, exec2, exec3] # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] exit_exec = graph_builder._get_branch_exit(exec1) @@ -2378,7 +2378,7 @@ def test_get_branch_exit_returns_none_for_goto_terminator(self): ) # Simulate a single-action branch chain - goto_executor._chain_executors = [goto_executor] # type: ignore[attr-defined] + goto_executor._chain_executors = [goto_executor] # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] exit_exec = graph_builder._get_branch_exit(goto_executor) assert exit_exec is None @@ -2395,7 +2395,7 @@ def test_get_branch_exit_returns_none_for_end_workflow_terminator(self): {"kind": "EndWorkflow", "id": "end"}, id="end", ) - end_executor._chain_executors = [end_executor] # type: ignore[attr-defined] + end_executor._chain_executors = [end_executor] # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] exit_exec = graph_builder._get_branch_exit(end_executor) assert exit_exec is None @@ -2419,7 +2419,7 @@ def test_get_branch_exit_returns_none_for_goto_in_chain(self): {"kind": "GotoAction", "id": "goto_target", "actionId": "some_target"}, id="goto_target", ) - activity._chain_executors = [activity, goto] # type: ignore[attr-defined] + activity._chain_executors = [activity, goto] # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] exit_exec = graph_builder._get_branch_exit(activity) assert exit_exec is None @@ -2434,7 +2434,7 @@ def test_get_branch_exit_returns_executor_for_non_terminator(self): exec1 = SendActivityExecutor({"kind": "SendActivity", "activity": {"text": "1"}}, id="e1") exec2 = SendActivityExecutor({"kind": "SendActivity", "activity": {"text": "2"}}, id="e2") - exec1._chain_executors = [exec1, exec2] # type: ignore[attr-defined] + exec1._chain_executors = [exec1, exec2] # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] exit_exec = graph_builder._get_branch_exit(exec1) assert exit_exec == exec2 diff --git a/python/packages/declarative/tests/test_graph_executors.py b/python/packages/declarative/tests/test_graph_executors.py index f6505bf4c89..885d3457c9f 100644 --- a/python/packages/declarative/tests/test_graph_executors.py +++ b/python/packages/declarative/tests/test_graph_executors.py @@ -2,7 +2,7 @@ """Tests for the graph-based declarative workflow executors.""" -from typing import Any +from typing import Any, cast from unittest.mock import AsyncMock, MagicMock import pytest @@ -1299,16 +1299,17 @@ def test_eval_raises_when_engine_unavailable(self): import agent_framework_declarative._workflows._declarative_base as base_mod mock_state = MagicMock() - mock_state._data: dict[str, Any] = {} - mock_state.get = MagicMock(side_effect=lambda k, d=None: mock_state._data.get(k, d)) - mock_state.set = MagicMock(side_effect=lambda k, v: mock_state._data.__setitem__(k, v)) + data: dict[str, Any] = {} + mock_state._data = data + mock_state.get = MagicMock(side_effect=lambda k, d=None: data.get(k, d)) + mock_state.set = MagicMock(side_effect=lambda k, v: data.__setitem__(k, v)) state = DeclarativeWorkflowState(mock_state) state.initialize({"name": "test"}) original_engine = base_mod.Engine try: - base_mod.Engine = None + base_mod.Engine = cast(Any, None) with pytest.raises(RuntimeError, match="PowerFx is not available"): state.eval("=Local.counter + 1") finally: @@ -1319,19 +1320,20 @@ def test_eval_passes_through_plain_strings_without_engine(self): import agent_framework_declarative._workflows._declarative_base as base_mod mock_state = MagicMock() - mock_state._data: dict[str, Any] = {} - mock_state.get = MagicMock(side_effect=lambda k, d=None: mock_state._data.get(k, d)) - mock_state.set = MagicMock(side_effect=lambda k, v: mock_state._data.__setitem__(k, v)) + data: dict[str, Any] = {} + mock_state._data = data + mock_state.get = MagicMock(side_effect=lambda k, d=None: data.get(k, d)) + mock_state.set = MagicMock(side_effect=lambda k, v: data.__setitem__(k, v)) state = DeclarativeWorkflowState(mock_state) state.initialize() original_engine = base_mod.Engine try: - base_mod.Engine = None + base_mod.Engine = cast(Any, None) assert state.eval("hello world") == "hello world" assert state.eval("") == "" - assert state.eval(42) == 42 + assert state.eval(cast("str", 42)) == 42 finally: base_mod.Engine = original_engine diff --git a/python/packages/declarative/tests/test_http_request_yaml_integration.py b/python/packages/declarative/tests/test_http_request_yaml_integration.py index 49cd0d15e83..7d5ac96d476 100644 --- a/python/packages/declarative/tests/test_http_request_yaml_integration.py +++ b/python/packages/declarative/tests/test_http_request_yaml_integration.py @@ -74,7 +74,7 @@ async def test_http_request_yaml_roundtrip() -> None: await workflow.run({}) decl: dict[str, Any] = workflow._state.get(DECLARATIVE_STATE_KEY) or {} - local = decl.get("Local") or {} + local: dict[str, Any] = decl.get("Local") or {} assert local.get("RepoOwner") == "dotnet" repo_info = local.get("RepoInfo") diff --git a/python/packages/declarative/tests/test_powerfx_functions.py b/python/packages/declarative/tests/test_powerfx_functions.py index 6c9752f79c7..a0b0ce1e0a7 100644 --- a/python/packages/declarative/tests/test_powerfx_functions.py +++ b/python/packages/declarative/tests/test_powerfx_functions.py @@ -2,6 +2,8 @@ """Tests for custom PowerFx-like functions.""" +from typing import cast + from agent_framework_declarative._workflows._powerfx_functions import ( CUSTOM_FUNCTIONS, assistant_message, @@ -59,7 +61,7 @@ def test_user_message_creates_dict(self): def test_user_message_with_none(self): """Test UserMessage with None.""" - msg = user_message(None) + msg = user_message(cast("str", None)) assert msg == {"role": "user", "content": ""} @@ -354,7 +356,7 @@ def test_agent_message_with_none(self): """Test AgentMessage with None.""" from agent_framework_declarative._workflows._powerfx_functions import agent_message - msg = agent_message(None) + msg = agent_message(cast("str", None)) assert msg == {"role": "assistant", "content": ""} diff --git a/python/packages/declarative/tests/test_workflow_factory.py b/python/packages/declarative/tests/test_workflow_factory.py index f163cd18f0b..ba54e50e8ea 100644 --- a/python/packages/declarative/tests/test_workflow_factory.py +++ b/python/packages/declarative/tests/test_workflow_factory.py @@ -2,6 +2,8 @@ """Unit tests for WorkflowFactory.""" +from typing import Any, cast + import pytest from agent_framework_declarative._workflows._errors import DeclarativeWorkflowError @@ -327,7 +329,7 @@ class MockAgent: name = "mock-agent" factory = WorkflowFactory() - factory.register_agent("myAgent", MockAgent()) + factory.register_agent("myAgent", cast(Any, MockAgent())) assert "myAgent" in factory._agents @@ -1024,7 +1026,7 @@ def test_agent_creation_with_file_reference(self, tmp_path): workflow = factory.create_workflow_from_yaml_path(workflow_file) assert workflow is not None - assert "TestAgent" in workflow._declarative_agents + assert "TestAgent" in cast(Any, workflow)._declarative_agents def test_agent_connection_definition_raises(self): """Test that connection-based agent definition raises error.""" @@ -1062,7 +1064,7 @@ def test_preregistered_agent_not_overwritten(self): class MockAgent: name = "PreregisteredAgent" - factory = WorkflowFactory(agents={"TestAgent": MockAgent()}) + factory = WorkflowFactory(agents={"TestAgent": cast(Any, MockAgent())}) workflow = factory.create_workflow_from_yaml(""" kind: Workflow agents: @@ -1075,7 +1077,7 @@ class MockAgent: value: 1 """) - assert workflow._declarative_agents["TestAgent"].name == "PreregisteredAgent" + assert cast(Any, workflow)._declarative_agents["TestAgent"].name == "PreregisteredAgent" class TestWorkflowFactoryInputSchema: @@ -1099,7 +1101,7 @@ def test_inputs_to_json_schema_basic(self): value: 1 """) - schema = workflow.input_schema + schema = cast(Any, workflow).input_schema assert schema["type"] == "object" assert "name" in schema["properties"] assert "age" in schema["properties"] @@ -1126,7 +1128,7 @@ def test_inputs_schema_with_optional_field(self): value: 1 """) - schema = workflow.input_schema + schema = cast(Any, workflow).input_schema assert "required_field" in schema["required"] assert "optional_field" not in schema["required"] @@ -1145,7 +1147,7 @@ def test_inputs_schema_with_default_value(self): value: 1 """) - schema = workflow.input_schema + schema = cast(Any, workflow).input_schema assert schema["properties"]["greeting"]["default"] == "Hello" def test_inputs_schema_with_enum(self): @@ -1166,7 +1168,7 @@ def test_inputs_schema_with_enum(self): value: 1 """) - schema = workflow.input_schema + schema = cast(Any, workflow).input_schema assert schema["properties"]["color"]["enum"] == ["red", "green", "blue"] def test_inputs_schema_type_mappings(self): @@ -1193,7 +1195,7 @@ def test_inputs_schema_type_mappings(self): value: 1 """) - schema = workflow.input_schema + schema = cast(Any, workflow).input_schema assert schema["properties"]["str_field"]["type"] == "string" assert schema["properties"]["int_field"]["type"] == "integer" assert schema["properties"]["float_field"]["type"] == "number" @@ -1215,7 +1217,7 @@ def test_inputs_schema_simple_format(self): value: 1 """) - schema = workflow.input_schema + schema = cast(Any, workflow).input_schema assert schema["properties"]["name"]["type"] == "string" assert schema["properties"]["count"]["type"] == "integer" assert "name" in schema["required"] @@ -1234,7 +1236,11 @@ class MockAgent1: class MockAgent2: name = "Agent2" - factory = WorkflowFactory().register_agent("agent1", MockAgent1()).register_agent("agent2", MockAgent2()) + factory = ( + WorkflowFactory() + .register_agent("agent1", cast(Any, MockAgent1())) + .register_agent("agent2", cast(Any, MockAgent2())) + ) assert "agent1" in factory._agents assert "agent2" in factory._agents @@ -1267,7 +1273,7 @@ def my_binding(): factory = ( WorkflowFactory() - .register_agent("agent", MockAgent()) + .register_agent("agent", cast(Any, MockAgent())) .register_tool("tool", my_tool) .register_binding("binding", my_binding) ) diff --git a/python/packages/devui/agent_framework_devui/_conversations.py b/python/packages/devui/agent_framework_devui/_conversations.py index 243f36cee38..9ad8534e184 100644 --- a/python/packages/devui/agent_framework_devui/_conversations.py +++ b/python/packages/devui/agent_framework_devui/_conversations.py @@ -314,7 +314,7 @@ async def add_items(self, conversation_id: str, items: list[dict[str, Any]]) -> text_obj = first_content.get("text", "") text = text_obj if isinstance(text_obj, str) else str(text_obj) - chat_msg = Message(role=role, contents=[text]) # type: ignore[arg-type] + chat_msg = Message(role=role, contents=[text]) chat_messages.append(chat_msg) # Add messages to internal storage @@ -466,7 +466,7 @@ async def list_items( message = OpenAIMessage( id=item_id, type="message", - role=role, # type: ignore + role=role, content=message_contents, # type: ignore status="completed", ) diff --git a/python/packages/devui/agent_framework_devui/_mapper.py b/python/packages/devui/agent_framework_devui/_mapper.py index f6a52ae945c..50eb03b6a1b 100644 --- a/python/packages/devui/agent_framework_devui/_mapper.py +++ b/python/packages/devui/agent_framework_devui/_mapper.py @@ -543,7 +543,7 @@ def _serialize_value(self, value: Any) -> Any: # Handle SerializationMixin (like Message) - call to_dict() if hasattr(value, "to_dict") and callable(getattr(value, "to_dict", None)): try: - return value.to_dict() # type: ignore[attr-defined, no-any-return] + return value.to_dict() except Exception as e: logger.debug(f"Failed to serialize with to_dict(): {e}") return str(value) @@ -551,7 +551,7 @@ def _serialize_value(self, value: Any) -> Any: # Handle Pydantic models - call model_dump() if hasattr(value, "model_dump") and callable(getattr(value, "model_dump", None)): try: - return value.model_dump() # type: ignore[attr-defined, no-any-return] + return value.model_dump() except Exception as e: logger.debug(f"Failed to serialize Pydantic model: {e}") return str(value) @@ -606,7 +606,7 @@ def _serialize_request_data(self, request_data: Any) -> dict[str, Any]: logger.debug(f"Failed to serialize dataclass fields: {e}") # Fallback to asdict() if our custom serialization fails try: - return asdict(request_data) # type: ignore[arg-type] + return asdict(request_data) except Exception as e2: logger.debug(f"Failed to serialize dataclass with asdict(): {e2}") diff --git a/python/packages/devui/agent_framework_devui/_server.py b/python/packages/devui/agent_framework_devui/_server.py index c4138a0aeb4..18a90e44a3e 100644 --- a/python/packages/devui/agent_framework_devui/_server.py +++ b/python/packages/devui/agent_framework_devui/_server.py @@ -566,12 +566,12 @@ async def get_entity_info(entity_id: str) -> EntityInfo: workflow_dump: dict[str, Any] | str | None = None if hasattr(entity_obj, "to_dict") and callable(getattr(entity_obj, "to_dict", None)): try: - workflow_dump = entity_obj.to_dict() # type: ignore[attr-defined] + workflow_dump = entity_obj.to_dict() except Exception: workflow_dump = None elif hasattr(entity_obj, "to_json") and callable(getattr(entity_obj, "to_json", None)): try: - raw_dump = entity_obj.to_json() # type: ignore[attr-defined] + raw_dump = entity_obj.to_json() except Exception: workflow_dump = None else: @@ -1254,16 +1254,16 @@ async def _stream_execution( # IMPORTANT: Check model_dump_json FIRST because to_json() can have newlines (pretty-printing) # which breaks SSE format. model_dump_json() returns single-line JSON. if hasattr(event, "model_dump_json"): - payload = event.model_dump_json() # type: ignore[attr-defined] + payload = event.model_dump_json() elif hasattr(event, "to_json") and callable(getattr(event, "to_json", None)): - payload = event.to_json() # type: ignore[attr-defined] + payload = event.to_json() # Strip newlines from pretty-printed JSON for SSE compatibility payload = payload.replace("\n", "").replace("\r", "") elif isinstance(event, dict): # Handle plain dict events (e.g., error events from executor) payload = json.dumps(event) elif hasattr(event, "to_dict") and callable(getattr(event, "to_dict", None)): - payload = json.dumps(event.to_dict()) # type: ignore[attr-defined] + payload = json.dumps(event.to_dict()) else: payload = json.dumps(str(event)) yield f"data: {payload}\n\n" @@ -1324,7 +1324,7 @@ async def _stream_openai_execution( # OpenAI SDK events have model_dump_json() - use it for single-line JSON if hasattr(event, "model_dump_json"): - payload = event.model_dump_json() # type: ignore[attr-defined] + payload = event.model_dump_json() yield f"data: {payload}\n\n" else: # Fallback (shouldn't happen with OpenAI SDK) diff --git a/python/packages/devui/agent_framework_devui/_utils.py b/python/packages/devui/agent_framework_devui/_utils.py index 602901995d0..f7a2f2f44b3 100644 --- a/python/packages/devui/agent_framework_devui/_utils.py +++ b/python/packages/devui/agent_framework_devui/_utils.py @@ -123,7 +123,7 @@ def extract_executor_message_types(executor: Any) -> list[Any]: try: handlers = executor._handlers if isinstance(handlers, dict): - message_types = list(handlers.keys()) # type: ignore[arg-type] # pyright: ignore[reportUnknownArgumentType] + message_types = list(handlers.keys()) # type: ignore[arg-type] except Exception as exc: # pragma: no cover - defensive logging path logger.debug(f"Failed to read executor handlers: {exc}") @@ -558,7 +558,7 @@ def _parse_string_input(input_str: str, target_type: type) -> Any: common_fields = ["text", "message", "content", "input", "data"] for field in common_fields: try: - return target_type(**{field: input_str}) # type: ignore + return target_type(**{field: input_str}) except Exception as e: logger.debug(f"Failed to parse string input with field '{field}': {e}") continue @@ -581,7 +581,7 @@ def _parse_string_input(input_str: str, target_type: type) -> Any: data = json.loads(input_str) if hasattr(target_type, "from_dict"): return target_type.from_dict(data) # type: ignore - return target_type(**data) # type: ignore + return target_type(**data) # Try other common fields common_fields = ["text", "message", "content"] @@ -590,7 +590,7 @@ def _parse_string_input(input_str: str, target_type: type) -> Any: for field in common_fields: if field in params: try: - return target_type(**{field: input_str}) # type: ignore + return target_type(**{field: input_str}) except Exception as e: logger.debug(f"Failed to create SerializationMixin with field '{field}': {e}") continue @@ -603,13 +603,13 @@ def _parse_string_input(input_str: str, target_type: type) -> Any: # Try parsing as JSON if input_str.strip().startswith("{"): data = json.loads(input_str) - return target_type(**data) # type: ignore + return target_type(**data) # Try common field names common_fields = ["text", "message", "content", "input", "data"] for field in common_fields: try: - return target_type(**{field: input_str}) # type: ignore + return target_type(**{field: input_str}) except Exception as e: logger.debug(f"Failed to create dataclass with field '{field}': {e}") continue @@ -639,12 +639,12 @@ def _parse_dict_input(input_dict: dict[str, Any], target_type: type) -> Any: # Try "input" field first (common for workflow inputs) if "input" in input_dict: - return target_type(input_dict["input"]) # type: ignore + return target_type(input_dict["input"]) # If single-key dict, extract the value if len(input_dict) == 1: value = next(iter(input_dict.values())) - return target_type(value) # type: ignore + return target_type(value) # Otherwise, return as-is return input_dict @@ -670,14 +670,14 @@ def _parse_dict_input(input_dict: dict[str, Any], target_type: type) -> Any: return _build_message_from_legacy_payload(input_dict) if hasattr(target_type, "from_dict"): return target_type.from_dict(input_dict) # type: ignore - return target_type(**input_dict) # type: ignore + return target_type(**input_dict) except Exception as e: logger.debug(f"Failed to parse dict as SerializationMixin: {e}") # Dataclasses if is_dataclass(target_type): try: - return target_type(**input_dict) # type: ignore + return target_type(**input_dict) except Exception as e: logger.debug(f"Failed to parse dict as dataclass: {e}") diff --git a/python/packages/devui/tests/devui/capture_messages.py b/python/packages/devui/tests/devui/capture_messages.py index 820fedde576..ff73d320b72 100644 --- a/python/packages/devui/tests/devui/capture_messages.py +++ b/python/packages/devui/tests/devui/capture_messages.py @@ -13,6 +13,7 @@ import logging import threading import time +from collections.abc import Mapping from pathlib import Path from typing import Any @@ -94,14 +95,16 @@ def capture_agent_stream_with_tracing(client: OpenAI, agent_id: str, scenario: s stream=True, ) - events = [] + events: list[dict[str, Any]] = [] for event in stream: # Serialize the entire event object try: - event_dict = json.loads(event.model_dump_json()) + raw_event_dict = json.loads(event.model_dump_json()) except Exception: # Fallback to dict conversion if model_dump_json fails - event_dict = event.__dict__ if hasattr(event, "__dict__") else str(event) + raw_event_dict = event.__dict__ if hasattr(event, "__dict__") else {"event": str(event)} + + event_dict = dict(raw_event_dict) if isinstance(raw_event_dict, Mapping) else {"event": str(event)} events.append(event_dict) @@ -138,14 +141,16 @@ def capture_workflow_stream_with_tracing( stream=True, ) - events = [] + events: list[dict[str, Any]] = [] for event in stream: # Serialize the entire event object try: - event_dict = json.loads(event.model_dump_json()) + raw_event_dict = json.loads(event.model_dump_json()) except Exception: # Fallback to dict conversion if model_dump_json fails - event_dict = event.__dict__ if hasattr(event, "__dict__") else str(event) + raw_event_dict = event.__dict__ if hasattr(event, "__dict__") else {"event": str(event)} + + event_dict = dict(raw_event_dict) if isinstance(raw_event_dict, Mapping) else {"event": str(event)} events.append(event_dict) diff --git a/python/packages/devui/tests/devui/conftest.py b/python/packages/devui/tests/devui/conftest.py index 114a7a7d6d6..9575dfbc270 100644 --- a/python/packages/devui/tests/devui/conftest.py +++ b/python/packages/devui/tests/devui/conftest.py @@ -12,7 +12,7 @@ import sys from collections.abc import AsyncIterable, Awaitable, Mapping, Sequence from pathlib import Path -from typing import Any, Generic +from typing import Any, Generic, TypedDict # noqa: F401 import pytest import pytest_asyncio diff --git a/python/packages/devui/tests/devui/test_approval_validation.py b/python/packages/devui/tests/devui/test_approval_validation.py index 2ef8f538728..d8de7933103 100644 --- a/python/packages/devui/tests/devui/test_approval_validation.py +++ b/python/packages/devui/tests/devui/test_approval_validation.py @@ -151,7 +151,9 @@ def test_valid_approval_accepted_with_server_data(executor: AgentFrameworkExecut # Verify SERVER-STORED data is used, not the client's forged data assert approval.function_call.name == "safe_tool" assert approval.function_call.call_id == "call_server" - fc_args = approval.function_call.parse_arguments() if hasattr(approval.function_call, "parse_arguments") else {} + fc_args: dict[str, Any] = ( + approval.function_call.parse_arguments() if hasattr(approval.function_call, "parse_arguments") else {} + ) assert fc_args.get("key") == "server_value" diff --git a/python/packages/devui/tests/devui/test_checkpoints.py b/python/packages/devui/tests/devui/test_checkpoints.py index 1e87187e415..f6f6e15c5cc 100644 --- a/python/packages/devui/tests/devui/test_checkpoints.py +++ b/python/packages/devui/tests/devui/test_checkpoints.py @@ -242,7 +242,9 @@ async def test_checkpoints_appear_as_conversation_items(self, checkpoint_manager assert item.get("checkpoint_id") in checkpoint_ids assert item.get("workflow_name") == test_workflow.name assert "timestamp" in item - assert item.get("id").startswith("checkpoint_") # ID format: checkpoint_{checkpoint_id} + item_id = item.get("id") + assert isinstance(item_id, str) + assert item_id.startswith("checkpoint_") # ID format: checkpoint_{checkpoint_id} @pytest.mark.asyncio async def test_load_checkpoint_from_session(self, checkpoint_manager, test_workflow): diff --git a/python/packages/devui/tests/devui/test_cleanup_hooks.py b/python/packages/devui/tests/devui/test_cleanup_hooks.py index 5336e98d415..22b762e35d6 100644 --- a/python/packages/devui/tests/devui/test_cleanup_hooks.py +++ b/python/packages/devui/tests/devui/test_cleanup_hooks.py @@ -320,7 +320,7 @@ async def _stream(): async def test_cleanup_execution_order(): """Test that cleanup hooks execute in registration order.""" agent = MockAgent("OrderTest") - execution_order = [] + execution_order: list[int] = [] def hook1(): execution_order.append(1) @@ -353,7 +353,7 @@ async def test_custom_cleanup_logic(): """Test registering custom cleanup function with complex logic.""" agent = MockAgent("CustomCleanup") cleanup_executed = False - resources_closed = [] + resources_closed: list[str] = [] async def custom_cleanup(): nonlocal cleanup_executed diff --git a/python/packages/devui/tests/devui/test_conversations.py b/python/packages/devui/tests/devui/test_conversations.py index a9e7ac6441c..bd999a81cef 100644 --- a/python/packages/devui/tests/devui/test_conversations.py +++ b/python/packages/devui/tests/devui/test_conversations.py @@ -5,7 +5,7 @@ from typing import cast import pytest -from openai.types.conversations import InputFileContent, InputImageContent, InputTextContent +from openai.types.conversations import InputTextContent from agent_framework_devui._conversations import InMemoryConversationStore @@ -312,10 +312,10 @@ async def test_list_items_handles_images_and_files(): assert text_content.text == "Check this image and PDF" assert items[0].content[1].type == "input_image" - image_content = cast(InputImageContent, items[0].content[1]) + image_content = items[0].content[1] assert image_content.image_url == "data:image/png;base64,iVBORw0KGgo=" assert image_content.detail == "auto" assert items[0].content[2].type == "input_file" - file_content = cast(InputFileContent, items[0].content[2]) + file_content = items[0].content[2] assert file_content.file_url == "data:application/pdf;base64,JVBERi0=" diff --git a/python/packages/devui/tests/devui/test_discovery.py b/python/packages/devui/tests/devui/test_discovery.py index 4a4efaadabf..f7cd22e7be5 100644 --- a/python/packages/devui/tests/devui/test_discovery.py +++ b/python/packages/devui/tests/devui/test_discovery.py @@ -113,6 +113,7 @@ def create_session(self, **kwargs): # Now check enriched metadata after loading enriched = discovery.get_entity_info(entity.id) + assert enriched is not None assert enriched.type == "agent" # Now correctly identified assert enriched.name == "Non-Streaming Agent" diff --git a/python/packages/devui/tests/devui/test_execution.py b/python/packages/devui/tests/devui/test_execution.py index fc3abee80d0..f1c03c1728e 100644 --- a/python/packages/devui/tests/devui/test_execution.py +++ b/python/packages/devui/tests/devui/test_execution.py @@ -18,7 +18,7 @@ from agent_framework import Agent, AgentExecutor, FunctionExecutor, WorkflowBuilder # Import mock classes from conftest for direct use in some tests -from conftest import MockBaseChatClient +from conftest import MockBaseChatClient # pyrefly: ignore[missing-import] # pyright: ignore[reportMissingImports] from agent_framework_devui._discovery import EntityDiscovery from agent_framework_devui._executor import AgentFrameworkExecutor, EntityNotFoundError @@ -43,7 +43,7 @@ async def executor(test_entities_dir): return executor -async def test_executor_entity_discovery(executor): +async def test_executor_entity_discovery(executor: AgentFrameworkExecutor) -> None: """Test executor entity discovery.""" entities = await executor.discover_entities() @@ -65,7 +65,7 @@ async def test_executor_entity_discovery(executor): ) -async def test_executor_get_entity_info(executor): +async def test_executor_get_entity_info(executor: AgentFrameworkExecutor) -> None: """Test getting entity info by ID.""" entities = await executor.discover_entities() entity_id = entities[0].id @@ -371,7 +371,7 @@ async def test_full_pipeline_workflow_events_are_json_serializable(): assert final_response is not None -async def test_get_entity_info_raises_for_invalid_id(executor): +async def test_get_entity_info_raises_for_invalid_id(executor: AgentFrameworkExecutor) -> None: """Test that get_entity_info raises EntityNotFoundError for invalid ID.""" with pytest.raises(EntityNotFoundError): executor.get_entity_info("nonexistent_agent") diff --git a/python/packages/devui/tests/devui/test_mapper.py b/python/packages/devui/tests/devui/test_mapper.py index 3ff4492d807..4c9ae7d77a5 100644 --- a/python/packages/devui/tests/devui/test_mapper.py +++ b/python/packages/devui/tests/devui/test_mapper.py @@ -24,12 +24,12 @@ ) # Import factory functions from conftest for parameterized test data creation -from conftest import ( +from conftest import ( # pyrefly: ignore[missing-import] # pyright: ignore[reportMissingImports] create_agent_run_response, create_executor_completed_event, create_executor_failed_event, create_executor_invoked_event, -) +) # pyrefly: ignore[missing-import] from agent_framework_devui._mapper import MessageMapper from agent_framework_devui.models._openai_custom import ( diff --git a/python/packages/devui/tests/devui/test_openai_sdk_integration.py b/python/packages/devui/tests/devui/test_openai_sdk_integration.py index 0de40756906..d1b4d41f34f 100644 --- a/python/packages/devui/tests/devui/test_openai_sdk_integration.py +++ b/python/packages/devui/tests/devui/test_openai_sdk_integration.py @@ -10,6 +10,7 @@ import time from collections.abc import Generator from pathlib import Path +from typing import Any, cast from urllib.parse import urlparse import pytest @@ -103,11 +104,12 @@ def test_openai_sdk_responses_create_with_entity_id(devui_server: str) -> None: # Get available entities - extract host and port from base_url parsed = urlparse(base_url) + assert parsed.hostname is not None conn = http.client.HTTPConnection(parsed.hostname, parsed.port, timeout=10) try: conn.request("GET", "/v1/entities") - response = conn.getresponse() - entities = json.loads(response.read().decode("utf-8"))["entities"] + http_response = conn.getresponse() + entities = json.loads(http_response.read().decode("utf-8"))["entities"] finally: conn.close() @@ -128,7 +130,7 @@ def test_openai_sdk_responses_create_with_entity_id(devui_server: str) -> None: assert response.object == "response" assert len(response.output) > 0 - assert response.output[0].content is not None + assert cast(Any, response.output[0]).content is not None def test_openai_sdk_responses_create_streaming(devui_server: str) -> None: @@ -138,6 +140,7 @@ def test_openai_sdk_responses_create_streaming(devui_server: str) -> None: # Get available entities - extract host and port from base_url parsed = urlparse(base_url) + assert parsed.hostname is not None conn = http.client.HTTPConnection(parsed.hostname, parsed.port, timeout=10) try: conn.request("GET", "/v1/entities") @@ -183,6 +186,7 @@ def test_openai_sdk_with_conversations(devui_server: str) -> None: # Get available entities - extract host and port from base_url parsed = urlparse(base_url) + assert parsed.hostname is not None conn = http.client.HTTPConnection(parsed.hostname, parsed.port, timeout=10) try: conn.request("GET", "/v1/entities") @@ -226,7 +230,7 @@ def test_openai_sdk_with_conversations(devui_server: str) -> None: assert len(response2.output) > 0 # The agent should remember the name from the previous turn # Note: This may not work with all agents, so we just verify we got a response - assert response2.output[0].content is not None + assert cast(Any, response2.output[0]).content is not None def test_openai_sdk_with_model_and_entity_id(devui_server: str) -> None: @@ -236,11 +240,12 @@ def test_openai_sdk_with_model_and_entity_id(devui_server: str) -> None: # Get available entities - extract host and port from base_url parsed = urlparse(base_url) + assert parsed.hostname is not None conn = http.client.HTTPConnection(parsed.hostname, parsed.port, timeout=10) try: conn.request("GET", "/v1/entities") - response = conn.getresponse() - entities = json.loads(response.read().decode("utf-8"))["entities"] + http_response = conn.getresponse() + entities = json.loads(http_response.read().decode("utf-8"))["entities"] finally: conn.close() diff --git a/python/packages/devui/tests/devui/test_server.py b/python/packages/devui/tests/devui/test_server.py index 76589216cb5..450cf9eac23 100644 --- a/python/packages/devui/tests/devui/test_server.py +++ b/python/packages/devui/tests/devui/test_server.py @@ -10,7 +10,7 @@ from typing import Any import pytest -from conftest import MockAgent +from conftest import MockAgent # pyrefly: ignore[missing-import] # pyright: ignore[reportMissingImports] from fastapi.testclient import TestClient import agent_framework_devui @@ -586,7 +586,7 @@ def test_serve_allows_non_loopback_with_explicit_token(monkeypatch): import uvicorn monkeypatch.delenv("DEVUI_AUTH_TOKEN", raising=False) - run_args = {} + run_args: dict[str, int | str] = {} def fake_run(_app, *, host, port, **_kwargs): run_args["host"] = host diff --git a/python/packages/devui/tests/devui/test_ui_memory_regression.py b/python/packages/devui/tests/devui/test_ui_memory_regression.py index 3fbf0e9db9c..ec450a94c80 100644 --- a/python/packages/devui/tests/devui/test_ui_memory_regression.py +++ b/python/packages/devui/tests/devui/test_ui_memory_regression.py @@ -359,7 +359,7 @@ def _parse_windows_process_rows(output: str) -> list[_BrowserProcessRow]: parent_pid = item.get("parent_pid") rss_kb = item.get("rss_kb") command = item.get("command") - if not all(isinstance(value, int) for value in (pid, parent_pid, rss_kb)): + if not isinstance(pid, int) or not isinstance(parent_pid, int) or not isinstance(rss_kb, int): continue if not isinstance(command, str): continue diff --git a/python/packages/durabletask/agent_framework_durabletask/_entities.py b/python/packages/durabletask/agent_framework_durabletask/_entities.py index 15fb77285e6..b86f2594145 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_entities.py +++ b/python/packages/durabletask/agent_framework_durabletask/_entities.py @@ -232,7 +232,7 @@ async def _invoke_agent( stream_candidate = await stream_candidate return await self._consume_stream( - stream=stream_candidate, # type: ignore[arg-type] + stream=stream_candidate, callback_context=callback_context, ) except TypeError as type_error: diff --git a/python/packages/durabletask/agent_framework_durabletask/_executors.py b/python/packages/durabletask/agent_framework_durabletask/_executors.py index 713c1b4e69f..31c285c8088 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_executors.py +++ b/python/packages/durabletask/agent_framework_durabletask/_executors.py @@ -513,11 +513,11 @@ def run_durable_agent( # Create a pre-completed task with acceptance response acceptance_response = self._create_acceptance_response(run_request.correlation_id) - entity_task: CompletableTask[AgentResponse] = CompletableTask() # type: ignore[no-untyped-call] + entity_task: CompletableTask[AgentResponse] = CompletableTask() entity_task.complete(acceptance_response) else: # Blocking mode: call entity and wait for response - entity_task = self._context.call_entity(entity_id, "run", run_request.to_dict()) # type: ignore + entity_task = self._context.call_entity(entity_id, "run", run_request.to_dict()) # Wrap in DurableAgentTask for response transformation return DurableAgentTask( diff --git a/python/packages/durabletask/agent_framework_durabletask/_worker.py b/python/packages/durabletask/agent_framework_durabletask/_worker.py index 728ae17629a..7951b995adc 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_worker.py +++ b/python/packages/durabletask/agent_framework_durabletask/_worker.py @@ -120,7 +120,7 @@ def start(self) -> None: The worker will block until stopped. """ logger.info("[DurableAIAgentWorker] Starting worker with %d registered agents", len(self._registered_agents)) - self._worker.start() # type: ignore[no-untyped-call] + self._worker.start() def stop(self) -> None: """Stop the worker gracefully. @@ -129,7 +129,7 @@ def stop(self) -> None: This method delegates to the underlying worker's stop method. """ logger.info("[DurableAIAgentWorker] Stopping worker") - self._worker.stop() # type: ignore[no-untyped-call] + self._worker.stop() @property def registered_agent_names(self) -> list[str]: diff --git a/python/packages/durabletask/tests/integration_tests/conftest.py b/python/packages/durabletask/tests/integration_tests/conftest.py index 65202e29079..fd0baefa74d 100644 --- a/python/packages/durabletask/tests/integration_tests/conftest.py +++ b/python/packages/durabletask/tests/integration_tests/conftest.py @@ -12,7 +12,7 @@ import uuid from collections.abc import Generator from pathlib import Path -from typing import Any, cast +from typing import Any, Protocol, cast from urllib.parse import urlparse import pytest @@ -30,6 +30,13 @@ logging.basicConfig(level=logging.WARNING) +class AgentClientFactoryProtocol(Protocol): + """Protocol for the agent client factory fixture.""" + + @classmethod + def create(cls, max_poll_retries: int = 90) -> tuple[DurableTaskSchedulerClient, DurableAIAgentClient]: ... + + # ============================================================================= # Environment and Service Checks # ============================================================================= @@ -472,7 +479,7 @@ def orchestration_helper(worker_process: dict[str, Any]) -> OrchestrationHelper: @pytest.fixture(scope="module") -def agent_client_factory(worker_process: dict[str, Any]) -> type: +def agent_client_factory(worker_process: dict[str, Any]) -> type[AgentClientFactoryProtocol]: """Return a factory class for creating agent clients. Usage in tests: diff --git a/python/packages/durabletask/tests/integration_tests/test_01_dt_single_agent.py b/python/packages/durabletask/tests/integration_tests/test_01_dt_single_agent.py index 736c7022e8d..98171fe774b 100644 --- a/python/packages/durabletask/tests/integration_tests/test_01_dt_single_agent.py +++ b/python/packages/durabletask/tests/integration_tests/test_01_dt_single_agent.py @@ -10,8 +10,20 @@ - Empty thread ID handling """ +from typing import Any, Protocol + import pytest +from agent_framework_durabletask import DurableAIAgentClient + + +class AgentClientFactoryProtocol(Protocol): + """Protocol for the agent client factory fixture.""" + + @classmethod + def create(cls, max_poll_retries: int = 90) -> tuple[Any, DurableAIAgentClient]: ... + + # Module-level markers - applied to all tests in this module pytestmark = [ pytest.mark.flaky, @@ -27,7 +39,7 @@ class TestSingleAgent: """Test suite for single agent functionality.""" @pytest.fixture(autouse=True) - def setup(self, agent_client_factory: type) -> None: + def setup(self, agent_client_factory: type[AgentClientFactoryProtocol]) -> None: """Setup test fixtures.""" # Create agent client using the factory fixture _, self.agent_client = agent_client_factory.create() diff --git a/python/packages/durabletask/tests/integration_tests/test_02_dt_multi_agent.py b/python/packages/durabletask/tests/integration_tests/test_02_dt_multi_agent.py index 2dad4e6b7a0..f5d61b2526a 100644 --- a/python/packages/durabletask/tests/integration_tests/test_02_dt_multi_agent.py +++ b/python/packages/durabletask/tests/integration_tests/test_02_dt_multi_agent.py @@ -10,8 +10,20 @@ - Agent isolation and tool routing """ +from typing import Any, Protocol + import pytest +from agent_framework_durabletask import DurableAIAgentClient + + +class AgentClientFactoryProtocol(Protocol): + """Protocol for the agent client factory fixture.""" + + @classmethod + def create(cls, max_poll_retries: int = 90) -> tuple[Any, DurableAIAgentClient]: ... + + # Agent names from the 02_multi_agent sample WEATHER_AGENT_NAME: str = "WeatherAgent" MATH_AGENT_NAME: str = "MathAgent" @@ -31,7 +43,7 @@ class TestMultiAgent: """Test suite for multi-agent functionality.""" @pytest.fixture(autouse=True) - def setup(self, agent_client_factory: type) -> None: + def setup(self, agent_client_factory: type[AgentClientFactoryProtocol]) -> None: """Setup test fixtures.""" # Create agent client using the factory fixture _, self.agent_client = agent_client_factory.create() diff --git a/python/packages/durabletask/tests/integration_tests/test_03_dt_single_agent_streaming.py b/python/packages/durabletask/tests/integration_tests/test_03_dt_single_agent_streaming.py index 1e311ecdcea..f3af85f4bec 100644 --- a/python/packages/durabletask/tests/integration_tests/test_03_dt_single_agent_streaming.py +++ b/python/packages/durabletask/tests/integration_tests/test_03_dt_single_agent_streaming.py @@ -22,15 +22,28 @@ import time from datetime import timedelta from pathlib import Path +from typing import Any, Protocol import pytest import redis.asyncio as aioredis +from agent_framework_durabletask import DurableAIAgentClient + + +class AgentClientFactoryProtocol(Protocol): + """Protocol for the agent client factory fixture.""" + + @classmethod + def create(cls, max_poll_retries: int = 90) -> tuple[Any, DurableAIAgentClient]: ... + + # Add sample directory to path to import RedisStreamResponseHandler SAMPLE_DIR = Path(__file__).parents[4] / "samples" / "04-hosting" / "durabletask" / "03_single_agent_streaming" sys.path.insert(0, str(SAMPLE_DIR)) -from redis_stream_response_handler import RedisStreamResponseHandler # type: ignore[reportMissingImports] # noqa: E402 +from redis_stream_response_handler import ( # type: ignore[reportMissingImports] # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] # noqa: E402 + RedisStreamResponseHandler, +) # Module-level markers - applied to all tests in this file pytestmark = [ @@ -48,7 +61,7 @@ class TestSampleReliableStreaming: """Tests for 03_single_agent_streaming sample.""" @pytest.fixture(autouse=True) - def setup(self, agent_client_factory: type, orchestration_helper) -> None: + def setup(self, agent_client_factory: type[AgentClientFactoryProtocol], orchestration_helper) -> None: """Setup test fixtures.""" # Create agent client using the factory fixture _, self.agent_client = agent_client_factory.create() diff --git a/python/packages/durabletask/tests/integration_tests/test_04_dt_single_agent_orchestration_chaining.py b/python/packages/durabletask/tests/integration_tests/test_04_dt_single_agent_orchestration_chaining.py index de70f9eadad..2b885bb0746 100644 --- a/python/packages/durabletask/tests/integration_tests/test_04_dt_single_agent_orchestration_chaining.py +++ b/python/packages/durabletask/tests/integration_tests/test_04_dt_single_agent_orchestration_chaining.py @@ -11,10 +11,21 @@ import json import logging +from typing import Any, Protocol import pytest from durabletask.client import OrchestrationStatus +from agent_framework_durabletask import DurableAIAgentClient + + +class AgentClientFactoryProtocol(Protocol): + """Protocol for the agent client factory fixture.""" + + @classmethod + def create(cls, max_poll_retries: int = 90) -> tuple[Any, DurableAIAgentClient]: ... + + # Agent name from the 04_single_agent_orchestration_chaining sample WRITER_AGENT_NAME: str = "WriterAgent" @@ -36,7 +47,7 @@ class TestSingleAgentOrchestrationChaining: """Test suite for single agent orchestration with chaining.""" @pytest.fixture(autouse=True) - def setup(self, agent_client_factory: type, orchestration_helper) -> None: + def setup(self, agent_client_factory: type[AgentClientFactoryProtocol], orchestration_helper) -> None: """Setup test fixtures.""" # Create agent client using the factory fixture self.dts_client, self.agent_client = agent_client_factory.create() diff --git a/python/packages/durabletask/tests/integration_tests/test_05_dt_multi_agent_orchestration_concurrency.py b/python/packages/durabletask/tests/integration_tests/test_05_dt_multi_agent_orchestration_concurrency.py index 88fe96487e1..383f0998cf9 100644 --- a/python/packages/durabletask/tests/integration_tests/test_05_dt_multi_agent_orchestration_concurrency.py +++ b/python/packages/durabletask/tests/integration_tests/test_05_dt_multi_agent_orchestration_concurrency.py @@ -11,10 +11,21 @@ import json import logging +from typing import Any, Protocol import pytest from durabletask.client import OrchestrationStatus +from agent_framework_durabletask import DurableAIAgentClient + + +class AgentClientFactoryProtocol(Protocol): + """Protocol for the agent client factory fixture.""" + + @classmethod + def create(cls, max_poll_retries: int = 90) -> tuple[Any, DurableAIAgentClient]: ... + + # Agent names from the 05_multi_agent_orchestration_concurrency sample PHYSICIST_AGENT_NAME: str = "PhysicistAgent" CHEMIST_AGENT_NAME: str = "ChemistAgent" @@ -36,7 +47,7 @@ class TestMultiAgentOrchestrationConcurrency: """Test suite for multi-agent orchestration with concurrency.""" @pytest.fixture(autouse=True) - def setup(self, agent_client_factory: type, orchestration_helper) -> None: + def setup(self, agent_client_factory: type[AgentClientFactoryProtocol], orchestration_helper) -> None: """Setup test fixtures.""" # Create agent client using the factory fixture self.dts_client, self.agent_client = agent_client_factory.create() diff --git a/python/packages/durabletask/tests/integration_tests/test_06_dt_multi_agent_orchestration_conditionals.py b/python/packages/durabletask/tests/integration_tests/test_06_dt_multi_agent_orchestration_conditionals.py index 177f4ca5f44..39421df9488 100644 --- a/python/packages/durabletask/tests/integration_tests/test_06_dt_multi_agent_orchestration_conditionals.py +++ b/python/packages/durabletask/tests/integration_tests/test_06_dt_multi_agent_orchestration_conditionals.py @@ -11,10 +11,21 @@ """ import logging +from typing import Any, Protocol import pytest from durabletask.client import OrchestrationStatus +from agent_framework_durabletask import DurableAIAgentClient + + +class AgentClientFactoryProtocol(Protocol): + """Protocol for the agent client factory fixture.""" + + @classmethod + def create(cls, max_poll_retries: int = 90) -> tuple[Any, DurableAIAgentClient]: ... + + # Agent names from the 06_multi_agent_orchestration_conditionals sample SPAM_AGENT_NAME: str = "SpamDetectionAgent" EMAIL_AGENT_NAME: str = "EmailAssistantAgent" @@ -36,7 +47,7 @@ class TestMultiAgentOrchestrationConditionals: """Test suite for multi-agent orchestration with conditionals.""" @pytest.fixture(autouse=True) - def setup(self, agent_client_factory: type, orchestration_helper) -> None: + def setup(self, agent_client_factory: type[AgentClientFactoryProtocol], orchestration_helper) -> None: """Setup test fixtures.""" # Create agent client using the factory fixture self.dts_client, self.agent_client = agent_client_factory.create() diff --git a/python/packages/durabletask/tests/integration_tests/test_07_dt_single_agent_orchestration_hitl.py b/python/packages/durabletask/tests/integration_tests/test_07_dt_single_agent_orchestration_hitl.py index 2d4a07a98f6..8c90d07e98d 100644 --- a/python/packages/durabletask/tests/integration_tests/test_07_dt_single_agent_orchestration_hitl.py +++ b/python/packages/durabletask/tests/integration_tests/test_07_dt_single_agent_orchestration_hitl.py @@ -11,10 +11,21 @@ """ import logging +from typing import Any, Protocol import pytest from durabletask.client import OrchestrationStatus +from agent_framework_durabletask import DurableAIAgentClient + + +class AgentClientFactoryProtocol(Protocol): + """Protocol for the agent client factory fixture.""" + + @classmethod + def create(cls, max_poll_retries: int = 90) -> tuple[Any, DurableAIAgentClient]: ... + + # Constants from the 07_single_agent_orchestration_hitl sample WRITER_AGENT_NAME: str = "WriterAgent" HUMAN_APPROVAL_EVENT: str = "HumanApproval" @@ -36,7 +47,7 @@ class TestSingleAgentOrchestrationHITL: """Test suite for single agent orchestration with human-in-the-loop.""" @pytest.fixture(autouse=True) - def setup(self, agent_client_factory: type, orchestration_helper) -> None: + def setup(self, agent_client_factory: type[AgentClientFactoryProtocol], orchestration_helper) -> None: """Setup test fixtures.""" # Create agent client using the factory fixture self.dts_client, self.agent_client = agent_client_factory.create() diff --git a/python/packages/durabletask/tests/test_client.py b/python/packages/durabletask/tests/test_client.py index a056d4e2549..7dc883edc7e 100644 --- a/python/packages/durabletask/tests/test_client.py +++ b/python/packages/durabletask/tests/test_client.py @@ -46,7 +46,7 @@ def test_get_agent_returns_durable_agent_shim(self, agent_client: DurableAIAgent agent = agent_client.get_agent("assistant") assert isinstance(agent, DurableAIAgent) - assert isinstance(agent, SupportsAgentRun) + assert isinstance(agent, SupportsAgentRun) # pyrefly: ignore[unsafe-overlap] def test_get_agent_shim_has_correct_name(self, agent_client: DurableAIAgentClient) -> None: """Verify retrieved agent has the correct name.""" diff --git a/python/packages/durabletask/tests/test_durable_agent_state.py b/python/packages/durabletask/tests/test_durable_agent_state.py index d3a36c9a7e4..62f0522779f 100644 --- a/python/packages/durabletask/tests/test_durable_agent_state.py +++ b/python/packages/durabletask/tests/test_durable_agent_state.py @@ -344,8 +344,8 @@ def test_usage_from_usage_details_with_extension_fields(self) -> None: "total_token_count": 300, } # Add provider-specific fields (UsageDetails is a TypedDict but allows extra keys) - usage_details["prompt_tokens"] = 100 # type: ignore[typeddict-unknown-key] - usage_details["completion_tokens"] = 200 # type: ignore[typeddict-unknown-key] + usage_details["prompt_tokens"] = 100 # type: ignore[typeddict-unknown-key] # ty: ignore[invalid-key] + usage_details["completion_tokens"] = 200 # type: ignore[typeddict-unknown-key] # ty: ignore[invalid-key] usage = DurableAgentStateUsage.from_usage(usage_details) diff --git a/python/packages/durabletask/tests/test_durable_entities.py b/python/packages/durabletask/tests/test_durable_entities.py index d2378297945..97d3e6e10c9 100644 --- a/python/packages/durabletask/tests/test_durable_entities.py +++ b/python/packages/durabletask/tests/test_durable_entities.py @@ -181,7 +181,7 @@ def _make_durabletask_entity_provider( entity = DurableTaskEntityStateProvider() ctx = MockEntityContext(initial_state) # DurableEntity provides this hook; required for get_state/set_state to work in unit tests. - entity._initialize_entity_context(ctx) # type: ignore[attr-defined] + entity._initialize_entity_context(ctx) # type: ignore[attr-defined, arg-type] # ty: ignore[invalid-argument-type] return entity, ctx def test_reset_persists_cleared_state(self) -> None: diff --git a/python/packages/durabletask/tests/test_orchestration_context.py b/python/packages/durabletask/tests/test_orchestration_context.py index 9f7cde156ca..97d5bd1386b 100644 --- a/python/packages/durabletask/tests/test_orchestration_context.py +++ b/python/packages/durabletask/tests/test_orchestration_context.py @@ -36,7 +36,7 @@ def test_get_agent_returns_durable_agent_shim(self, agent_context: DurableAIAgen agent = agent_context.get_agent("assistant") assert isinstance(agent, DurableAIAgent) - assert isinstance(agent, SupportsAgentRun) + assert isinstance(agent, SupportsAgentRun) # pyrefly: ignore[unsafe-overlap] def test_get_agent_shim_has_correct_name(self, agent_context: DurableAIAgentOrchestrationContext) -> None: """Verify retrieved agent has the correct name.""" diff --git a/python/packages/durabletask/tests/test_shim.py b/python/packages/durabletask/tests/test_shim.py index 9a265ab974d..de343f30aab 100644 --- a/python/packages/durabletask/tests/test_shim.py +++ b/python/packages/durabletask/tests/test_shim.py @@ -6,7 +6,7 @@ Run with: pytest tests/test_shim.py -v """ -from typing import Any +from typing import Any, cast from unittest.mock import Mock import pytest @@ -41,8 +41,8 @@ def create_run_request( opts = dict(options) if options else {} response_format = opts.pop("response_format", None) - enable_tool_calls = opts.pop("enable_tool_calls", True) - wait_for_response = opts.pop("wait_for_response", True) + enable_tool_calls = cast("bool", opts.pop("enable_tool_calls", True)) + wait_for_response = cast("bool", opts.pop("wait_for_response", True)) return RunRequest( message=message, correlation_id=str(uuid.uuid4()), @@ -147,7 +147,7 @@ class TestDurableAISupportsAgentRunCompliance: def test_agent_implements_protocol(self, test_agent: DurableAIAgent[Any]) -> None: """Verify DurableAIAgent implements SupportsAgentRun.""" - assert isinstance(test_agent, SupportsAgentRun) + assert isinstance(test_agent, SupportsAgentRun) # pyrefly: ignore[unsafe-overlap] def test_agent_has_required_properties(self, test_agent: DurableAIAgent[Any]) -> None: """Verify DurableAIAgent has all required SupportsAgentRun properties.""" diff --git a/python/packages/foundry/agent_framework_foundry/_agent.py b/python/packages/foundry/agent_framework_foundry/_agent.py index 7001e0bf81e..bdecb828254 100644 --- a/python/packages/foundry/agent_framework_foundry/_agent.py +++ b/python/packages/foundry/agent_framework_foundry/_agent.py @@ -41,17 +41,17 @@ from ._tools import _sanitize_foundry_response_tool # pyright: ignore[reportPrivateUsage] if sys.version_info >= (3, 13): - from typing import TypeVar # type: ignore # pragma: no cover + from typing import TypeVar # pragma: no cover else: - from typing_extensions import TypeVar # type: ignore # pragma: no cover + from typing_extensions import TypeVar # pragma: no cover if sys.version_info >= (3, 12): - from typing import override # type: ignore # pragma: no cover + from typing import override # pragma: no cover else: - from typing_extensions import override # type: ignore[import] # pragma: no cover + from typing_extensions import override # pragma: no cover if sys.version_info >= (3, 11): - from typing import TypedDict # type: ignore # pragma: no cover + from typing import TypedDict # pragma: no cover else: - from typing_extensions import TypedDict # type: ignore # pragma: no cover + from typing_extensions import TypedDict # pragma: no cover if TYPE_CHECKING: from agent_framework import ( @@ -145,7 +145,7 @@ def _build_agent_reference(agent_name: str, agent_version: str | None) -> dict[s return ref -class RawFoundryAgentChatClient( # type: ignore[misc] +class RawFoundryAgentChatClient( RawOpenAIChatClient[FoundryAgentOptionsT], Generic[FoundryAgentOptionsT], ): @@ -476,9 +476,7 @@ async def get_agent_version(self) -> str | None: return self.agent_version if not self.allow_preview: return None - agent_details = await cast(Any, self.project_client.beta.agents).get( # pyright: ignore[reportAttributeAccessIssue, reportUnknownMemberType] - agent_name=self.agent_name - ) + agent_details = await cast(Any, self.project_client.beta.agents).get(agent_name=self.agent_name) versions_object = getattr(agent_details, "versions", None) if not isinstance(versions_object, Mapping): raise TypeError("Foundry agent details did not include a versions mapping.") @@ -496,7 +494,7 @@ async def close(self) -> None: await self.project_client.close() -class _FoundryAgentChatClient( # type: ignore[misc] +class _FoundryAgentChatClient( FunctionInvocationLayer[FoundryAgentOptionsT], ChatMiddlewareLayer[FoundryAgentOptionsT], ChatTelemetryLayer[FoundryAgentOptionsT], @@ -586,7 +584,7 @@ def __init__( ) -class RawFoundryAgent( # type: ignore[misc] +class RawFoundryAgent( RawAgent[FoundryAgentOptionsT], ): """Raw Microsoft Foundry Agent without agent-level middleware or telemetry. @@ -705,7 +703,7 @@ def __init__( id=id, name=name or agent_name, description=description, - tools=tools, # type: ignore[arg-type] + tools=tools, default_options=cast(FoundryAgentOptionsT | None, default_options), context_providers=context_providers, middleware=middleware, @@ -742,7 +740,7 @@ async def _create_service_session_id( if version := await self.client.get_agent_version(): from azure.ai.projects.models import VersionRefIndicator - create_session_kwargs["version_indicator"] = VersionRefIndicator(agent_version=version) # type: ignore + create_session_kwargs["version_indicator"] = VersionRefIndicator(agent_version=version) service_session = await self.client.project_client.beta.agents.create_session(**create_session_kwargs) agent_session_id = getattr(service_session, "agent_session_id", None) diff --git a/python/packages/foundry/agent_framework_foundry/_chat_client.py b/python/packages/foundry/agent_framework_foundry/_chat_client.py index 6d7dc878ff9..b55d2d101e4 100644 --- a/python/packages/foundry/agent_framework_foundry/_chat_client.py +++ b/python/packages/foundry/agent_framework_foundry/_chat_client.py @@ -59,17 +59,17 @@ from ._tools import _sanitize_foundry_response_tool # pyright: ignore[reportPrivateUsage] if sys.version_info >= (3, 13): - from typing import TypeVar # type: ignore # pragma: no cover + from typing import TypeVar # pragma: no cover else: - from typing_extensions import TypeVar # type: ignore # pragma: no cover + from typing_extensions import TypeVar # pragma: no cover if sys.version_info >= (3, 12): - from typing import override # type: ignore # pragma: no cover + from typing import override # pragma: no cover else: - from typing_extensions import override # type: ignore # pragma: no cover + from typing_extensions import override # pragma: no cover if sys.version_info >= (3, 11): - from typing import TypedDict # type: ignore # pragma: no cover + from typing import TypedDict # pragma: no cover else: - from typing_extensions import TypedDict # type: ignore # pragma: no cover + from typing_extensions import TypedDict # pragma: no cover if TYPE_CHECKING: from agent_framework import ChatAndFunctionMiddlewareTypes, ToolTypes @@ -131,7 +131,7 @@ def resolve_file_ids(file_ids: Sequence[str | Content] | None) -> list[str] | No FoundryChatOptions = OpenAIChatOptions -class RawFoundryChatClient( # type: ignore[misc] +class RawFoundryChatClient( RawOpenAIChatClient[FoundryChatOptionsT], Generic[FoundryChatOptionsT], ): @@ -149,8 +149,8 @@ class RawFoundryChatClient( # type: ignore[misc] for a fully-featured client with middleware, telemetry, and function invocation. """ - OTEL_PROVIDER_NAME: ClassVar[str] = "azure.ai.foundry" # type: ignore[reportIncompatibleVariableOverride, misc] - SUPPORTS_RICH_FUNCTION_OUTPUT: ClassVar[bool] = False # type: ignore[reportIncompatibleVariableOverride, misc] + OTEL_PROVIDER_NAME: ClassVar[str] = "azure.ai.foundry" + SUPPORTS_RICH_FUNCTION_OUTPUT: ClassVar[bool] = False def __init__( self, @@ -219,7 +219,7 @@ def __init__( raise ValueError("Azure credential is required when using project_endpoint without a project_client.") project_client_kwargs: dict[str, Any] = { "endpoint": project_endpoint, - "credential": credential, # type: ignore[arg-type] + "credential": credential, "user_agent": get_user_agent(), } if allow_preview is not None: @@ -394,7 +394,7 @@ def get_file_search_tool( ) @staticmethod - def get_web_search_tool( # type: ignore[override] + def get_web_search_tool( *, user_location: dict[str, str] | None = None, search_context_size: Literal["low", "medium", "high"] | None = None, @@ -581,7 +581,7 @@ def get_bing_custom_search_tool( ) @staticmethod - def get_image_generation_tool( # type: ignore[override] + def get_image_generation_tool( *, model: Literal["gpt-image-1"] | str | None = None, size: Literal["1024x1024", "1024x1536", "1536x1024", "auto"] | None = None, @@ -609,8 +609,8 @@ def get_image_generation_tool( # type: ignore[override] Returns: An ImageGenTool ready to pass to an Agent. """ - return ImageGenTool( # type: ignore[misc] - model=model, # type: ignore[arg-type] + return ImageGenTool( + model=model, size=size, output_format=output_format, quality=quality, @@ -897,7 +897,7 @@ def get_a2a_tool( # endregion -class FoundryChatClient( # type: ignore[misc] +class FoundryChatClient( FunctionInvocationLayer[FoundryChatOptionsT], ChatMiddlewareLayer[FoundryChatOptionsT], ChatTelemetryLayer[FoundryChatOptionsT], @@ -952,7 +952,7 @@ class FoundryChatClient( # type: ignore[misc] ) """ - OTEL_PROVIDER_NAME: ClassVar[str] = "azure.ai.foundry" # type: ignore[reportIncompatibleVariableOverride, misc] + OTEL_PROVIDER_NAME: ClassVar[str] = "azure.ai.foundry" def __init__( self, diff --git a/python/packages/foundry/agent_framework_foundry/_embedding_client.py b/python/packages/foundry/agent_framework_foundry/_embedding_client.py index 2e63bf7f563..cc9668ec4fa 100644 --- a/python/packages/foundry/agent_framework_foundry/_embedding_client.py +++ b/python/packages/foundry/agent_framework_foundry/_embedding_client.py @@ -23,9 +23,9 @@ from azure.core.credentials import AzureKeyCredential if sys.version_info >= (3, 13): - from typing import TypeVar # type: ignore # pragma: no cover + from typing import TypeVar # pragma: no cover else: - from typing_extensions import TypeVar # type: ignore # pragma: no cover + from typing_extensions import TypeVar # pragma: no cover logger = logging.getLogger("agent_framework.foundry") @@ -205,7 +205,7 @@ async def get_embeddings( ValueError: If model is not provided or an unsupported content type is encountered. """ if not values: - return GeneratedEmbeddings([], options=options) # type: ignore[reportReturnType] + return GeneratedEmbeddings([], options=options) opts: dict[str, Any] = dict(options) if options else {} @@ -307,7 +307,7 @@ async def get_embeddings( [embedding for embedding in embeddings if embedding is not None], options=options, usage=usage_details, - ) # type: ignore[reportReturnType] + ) class FoundryEmbeddingClient( @@ -363,7 +363,7 @@ class FoundryEmbeddingClient( result = await client.get_embeddings(["hello", image]) """ - OTEL_PROVIDER_NAME: ClassVar[str] = "azure.ai.inference" # type: ignore[reportIncompatibleVariableOverride, misc] + OTEL_PROVIDER_NAME: ClassVar[str] = "azure.ai.inference" def __init__( self, diff --git a/python/packages/foundry/agent_framework_foundry/_foundry_evals.py b/python/packages/foundry/agent_framework_foundry/_foundry_evals.py index 8059c2ce990..8ee48a4cbf9 100644 --- a/python/packages/foundry/agent_framework_foundry/_foundry_evals.py +++ b/python/packages/foundry/agent_framework_foundry/_foundry_evals.py @@ -692,8 +692,8 @@ async def _evaluate_via_responses_impl( """ eval_obj = await client.evals.create( name=eval_name, - data_source_config={"type": "azure_ai_source", "scenario": "responses"}, # type: ignore[arg-type] # pyright: ignore[reportArgumentType] - testing_criteria=_build_testing_criteria(evaluators, model), # type: ignore[arg-type] # pyright: ignore[reportArgumentType] + data_source_config={"type": "azure_ai_source", "scenario": "responses"}, # type: ignore[arg-type] + testing_criteria=_build_testing_criteria(evaluators, model), # type: ignore[arg-type] ) data_source = { @@ -711,7 +711,7 @@ async def _evaluate_via_responses_impl( run = await client.evals.runs.create( eval_id=eval_obj.id, name=f"{eval_name} Run", - data_source=data_source, # type: ignore[arg-type] # pyright: ignore[reportArgumentType] + data_source=data_source, # type: ignore[arg-type] ) return await _poll_eval_run(client, eval_obj.id, run.id, poll_interval, timeout, provider=provider) @@ -926,14 +926,14 @@ async def _evaluate_via_dataset( eval_obj = await self._client.evals.create( name=eval_name, - data_source_config={ # type: ignore[arg-type] # pyright: ignore[reportArgumentType] + data_source_config={ "type": "custom", "item_schema": _build_item_schema( has_context=has_context, has_ground_truth=has_ground_truth, has_tools=has_tools ), "include_sample_schema": True, }, - testing_criteria=_build_testing_criteria( # type: ignore[arg-type] # pyright: ignore[reportArgumentType] + testing_criteria=_build_testing_criteria( # type: ignore[arg-type] evaluators, self._model, include_data_mapping=True, @@ -952,7 +952,7 @@ async def _evaluate_via_dataset( run = await self._client.evals.runs.create( eval_id=eval_obj.id, name=f"{eval_name} Run", - data_source=data_source, # type: ignore[arg-type] # pyright: ignore[reportArgumentType] + data_source=data_source, # type: ignore[arg-type] ) return await _poll_eval_run( @@ -1048,14 +1048,14 @@ async def evaluate_traces( eval_obj = await oai_client.evals.create( name=eval_name, - data_source_config={"type": "azure_ai_source", "scenario": "traces"}, # type: ignore[arg-type] # pyright: ignore[reportArgumentType] - testing_criteria=_build_testing_criteria(resolved_evaluators, model), # type: ignore[arg-type] # pyright: ignore[reportArgumentType] + data_source_config={"type": "azure_ai_source", "scenario": "traces"}, # type: ignore[arg-type] + testing_criteria=_build_testing_criteria(resolved_evaluators, model), # type: ignore[arg-type] ) run = await oai_client.evals.runs.create( eval_id=eval_obj.id, name=f"{eval_name} Run", - data_source=trace_source, # type: ignore[arg-type] # pyright: ignore[reportArgumentType] + data_source=trace_source, # type: ignore[arg-type] ) return await _poll_eval_run(oai_client, eval_obj.id, run.id, poll_interval, timeout) @@ -1111,11 +1111,11 @@ async def evaluate_foundry_target( eval_obj = await oai_client.evals.create( name=eval_name, - data_source_config={ # type: ignore[arg-type] # pyright: ignore[reportArgumentType] + data_source_config={ # type: ignore[arg-type] "type": "azure_ai_source", "scenario": "target_completions", }, - testing_criteria=_build_testing_criteria(resolved_evaluators, model), # type: ignore[arg-type] # pyright: ignore[reportArgumentType] + testing_criteria=_build_testing_criteria(resolved_evaluators, model), # type: ignore[arg-type] ) data_source: dict[str, Any] = { @@ -1130,7 +1130,7 @@ async def evaluate_foundry_target( run = await oai_client.evals.runs.create( eval_id=eval_obj.id, name=f"{eval_name} Run", - data_source=data_source, # type: ignore[arg-type] # pyright: ignore[reportArgumentType] + data_source=data_source, # type: ignore[arg-type] ) return await _poll_eval_run(oai_client, eval_obj.id, run.id, poll_interval, timeout) diff --git a/python/packages/foundry/agent_framework_foundry/_memory_provider.py b/python/packages/foundry/agent_framework_foundry/_memory_provider.py index 169da4fc853..41868cd396c 100644 --- a/python/packages/foundry/agent_framework_foundry/_memory_provider.py +++ b/python/packages/foundry/agent_framework_foundry/_memory_provider.py @@ -118,7 +118,7 @@ def __init__( raise ValueError("Azure credential is required when project_client is not provided.") project_client_kwargs: dict[str, Any] = { "endpoint": resolved_endpoint, - "credential": credential, # type: ignore[arg-type] + "credential": credential, "user_agent": get_user_agent(), } if allow_preview is not None: diff --git a/python/packages/foundry/agent_framework_foundry/_to_prompt_agent.py b/python/packages/foundry/agent_framework_foundry/_to_prompt_agent.py index bb835ff63c3..dc24a7808af 100644 --- a/python/packages/foundry/agent_framework_foundry/_to_prompt_agent.py +++ b/python/packages/foundry/agent_framework_foundry/_to_prompt_agent.py @@ -168,7 +168,7 @@ def _prepare_prompt_agent_options( ToolChoiceAllowed, ToolChoiceFunction, ) - from openai.lib._parsing._responses import ( # type: ignore[reportPrivateImportUsage] + from openai.lib._parsing._responses import ( type_to_text_format_param, ) from pydantic import BaseModel @@ -320,4 +320,4 @@ def _validate_mapping_tool(tool_item: Mapping[str, Any]) -> Tool: # ``_deserialize`` is the SDK's discriminator-aware entry point. It is marked # protected by convention but is the standard way to rehydrate polymorphic # azure-sdk-for-python models from a raw mapping. - return cast("Tool", ProjectsTool._deserialize(dict(tool_item), [])) # type: ignore[no-untyped-call] # pyright: ignore[reportPrivateUsage, reportUnknownMemberType] + return cast("Tool", ProjectsTool._deserialize(dict(tool_item), [])) # type: ignore[no-untyped-call] diff --git a/python/packages/foundry/tests/foundry/test_foundry_agent.py b/python/packages/foundry/tests/foundry/test_foundry_agent.py index 672a7aba690..87eb8f2e9e6 100644 --- a/python/packages/foundry/tests/foundry/test_foundry_agent.py +++ b/python/packages/foundry/tests/foundry/test_foundry_agent.py @@ -6,7 +6,7 @@ import os import sys from types import SimpleNamespace -from typing import Any +from typing import Any, cast from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -196,7 +196,7 @@ class CustomClient(RawFoundryAgentChatClient): named_agent = client.as_agent(name="display-name", instructions="You are helpful.") assert named_agent.name == "display-name" - assert named_agent.client.agent_name == "test-agent" + assert cast(Any, named_agent.client).agent_name == "test-agent" def test_raw_foundry_agent_chat_client_as_agent_uses_explicit_parameters() -> None: @@ -646,7 +646,7 @@ def test_raw_foundry_agent_init_creates_client() -> None: ) assert agent.client is not None - assert agent.client.agent_name == "test-agent" + assert cast(Any, agent.client).agent_name == "test-agent" def test_raw_foundry_agent_init_passes_default_headers_to_client() -> None: @@ -740,7 +740,7 @@ def test_foundry_agent_init_propagates_timeout_to_openai_client() -> None: openai_client_mock.with_options.assert_called_once_with(timeout=90.0) assert openai_client_mock.timeout == 5.0, "Original shared client must not be mutated" - assert agent.client.client is openai_client_mock.with_options.return_value + assert cast(Any, agent.client).client is openai_client_mock.with_options.return_value def test_foundry_agent_init_timeout_none_leaves_client_default() -> None: @@ -768,7 +768,7 @@ def test_raw_foundry_agent_init_rejects_invalid_client_type() -> None: RawFoundryAgent( project_client=MagicMock(), agent_name="test-agent", - client_type=object, # type: ignore[arg-type] + client_type=cast(Any, object), ) @@ -874,7 +874,7 @@ def test_foundry_agent_init() -> None: ) assert agent.client is not None - assert agent.client.agent_name == "test-agent" + assert cast(Any, agent.client).agent_name == "test-agent" def test_foundry_agent_init_with_middleware() -> None: @@ -884,7 +884,7 @@ def test_foundry_agent_init_with_middleware() -> None: mock_project.get_openai_client.return_value = MagicMock() class MyMiddleware(ChatMiddleware): - async def process(self, context: ChatContext) -> None: + async def process(self, context: ChatContext, call_next) -> None: pass agent = FoundryAgent( @@ -980,7 +980,7 @@ def _import_with_missing_azure_monitor( @skip_if_foundry_agent_integration_tests_disabled async def test_foundry_agent_basic_run() -> None: """Smoke-test FoundryAgent against a real configured agent.""" - async with FoundryAgent(credential=AzureCliCredential(), allow_preview=True) as agent: + async with FoundryAgent(credential=cast(Any, AzureCliCredential()), allow_preview=True) as agent: response = await agent.run("Please respond with exactly: 'This is a response test.'") assert isinstance(response, AgentResponse) @@ -994,7 +994,7 @@ async def test_foundry_agent_basic_run() -> None: async def test_foundry_agent_custom_client_run() -> None: """Smoke-test FoundryAgent against a real configured agent.""" async with FoundryAgent( - credential=AzureCliCredential(), client_type=RawFoundryAgentChatClient, allow_preview=True + credential=cast(Any, AzureCliCredential()), client_type=RawFoundryAgentChatClient, allow_preview=True ) as agent: response = await agent.run("Please respond with exactly: 'This is a response test.'") diff --git a/python/packages/foundry/tests/foundry/test_foundry_chat_client.py b/python/packages/foundry/tests/foundry/test_foundry_chat_client.py index 9465aa0ff2e..9e9922e3e9a 100644 --- a/python/packages/foundry/tests/foundry/test_foundry_chat_client.py +++ b/python/packages/foundry/tests/foundry/test_foundry_chat_client.py @@ -8,7 +8,7 @@ import warnings from functools import wraps from pathlib import Path -from typing import Annotated, Any +from typing import Annotated, Any, cast from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -75,7 +75,7 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: f"model={os.getenv('FOUNDRY_MODEL', '')}" ) if hasattr(exc, "add_note"): - exc.add_note(debug_message) + cast(Any, exc).add_note(debug_message) elif exc.args: exc.args = (f"{exc.args[0]}\n{debug_message}", *exc.args[1:]) else: @@ -155,8 +155,8 @@ def test_init() -> None: client = FoundryChatClient(project_client=mock_project_client, model=_TEST_FOUNDRY_MODEL) assert client.model == _TEST_FOUNDRY_MODEL - assert isinstance(client, SupportsChatGetResponse) assert client.project_client is mock_project_client + assert isinstance(client, SupportsChatGetResponse) def test_raw_foundry_chat_client_init_uses_explicit_parameters() -> None: @@ -375,6 +375,7 @@ async def test_web_search_tool_with_location() -> None: } ) + assert web_search_tool.user_location is not None assert web_search_tool.user_location.city == "Seattle" assert web_search_tool.user_location.country == "US" _, run_options, _ = await client._prepare_request( @@ -393,7 +394,7 @@ async def test_code_interpreter_tool_variations() -> None: client = FoundryChatClient(project_client=project_client, model="test-model") code_tool = FoundryChatClient.get_code_interpreter_tool() - assert code_tool.container["type"] == "auto" + assert cast(dict[str, Any], code_tool.container)["type"] == "auto" _, run_options, _ = await client._prepare_request( messages=[Message("user", ["Run some code"])], @@ -403,7 +404,7 @@ async def test_code_interpreter_tool_variations() -> None: assert run_options["tools"] == [code_tool] code_tool_with_files = FoundryChatClient.get_code_interpreter_tool(file_ids=["file1", "file2"]) - assert code_tool_with_files.container.file_ids == ["file1", "file2"] + assert cast(Any, code_tool_with_files.container).file_ids == ["file1", "file2"] _, run_options, _ = await client._prepare_request( messages=[Message(role="user", contents=["Process these files"])], @@ -487,7 +488,7 @@ async def test_content_filter_exception() -> None: body={"error": {"code": "content_filter", "message": "Content filter error"}}, ) mock_error.code = "content_filter" - client.client.responses.with_raw_response.create.side_effect = mock_error + cast(Any, client.client.responses.with_raw_response.create).side_effect = mock_error with pytest.raises(OpenAIContentFilterException) as exc_info: await client.get_response(messages=[Message(role="user", contents=["Test message"])]) @@ -852,7 +853,7 @@ async def test_integration_options( option_value: Any, needs_validation: bool, ) -> None: - client = FoundryChatClient(credential=AzureCliCredential()) + client = FoundryChatClient(credential=cast(Any, AzureCliCredential())) client.function_invocation_configuration["max_iterations"] = 2 if option_name.startswith("tools") or option_name.startswith("tool_choice"): @@ -867,7 +868,9 @@ async def test_integration_options( if option_name.startswith("tool_choice"): options["tools"] = [get_weather] - response = await client.get_response(messages=messages, options=options, stream=True).get_final_response() + response = await client.get_response( + messages=messages, options=cast(Any, options), stream=True + ).get_final_response() assert isinstance(response, ChatResponse) assert response.text is not None @@ -893,19 +896,19 @@ async def test_integration_options( @skip_if_foundry_integration_tests_disabled @_with_foundry_debug() async def test_integration_web_search() -> None: - client = FoundryChatClient(credential=AzureCliCredential()) + client = FoundryChatClient(credential=cast(Any, AzureCliCredential())) web_search_tool = FoundryChatClient.get_web_search_tool() - content = { - "messages": [ - Message( - role="user", - contents=["Where is Microsoft's headquarters? Do a web search to find the answer."], - ) - ], - "options": {"tool_choice": "auto", "tools": [web_search_tool]}, - } - response = await client.get_response(stream=True, **content).get_final_response() + messages = [ + Message( + role="user", + contents=["Where is Microsoft's headquarters? Do a web search to find the answer."], + ) + ] + options: dict[str, Any] = {"tool_choice": "auto", "tools": [web_search_tool]} + response = await client.get_response( + messages=messages, options=cast(Any, options), stream=True + ).get_final_response() assert isinstance(response, ChatResponse) assert "redmond" in response.text.lower() @@ -924,13 +927,15 @@ async def test_integration_tool_rich_content_image() -> None: def get_test_image() -> Content: return Content.from_data(data=image_bytes, media_type="image/jpeg") - client = FoundryChatClient(credential=AzureCliCredential()) + client = FoundryChatClient(credential=cast(Any, AzureCliCredential())) client.function_invocation_configuration["max_iterations"] = 2 messages = [Message(role="user", contents=["Call the get_test_image tool and describe what you see."])] options: dict[str, Any] = {"tools": [get_test_image], "tool_choice": "auto"} - response = await client.get_response(messages=messages, options=options, stream=True).get_final_response() + response = await client.get_response( + messages=messages, options=cast(Any, options), stream=True + ).get_final_response() assert isinstance(response, ChatResponse) assert response.text is not None diff --git a/python/packages/foundry/tests/foundry/test_foundry_memory_provider.py b/python/packages/foundry/tests/foundry/test_foundry_memory_provider.py index 6377dfa602f..89d9023602f 100644 --- a/python/packages/foundry/tests/foundry/test_foundry_memory_provider.py +++ b/python/packages/foundry/tests/foundry/test_foundry_memory_provider.py @@ -4,6 +4,7 @@ from __future__ import annotations import os +from typing import Any, cast from unittest.mock import AsyncMock, Mock, patch import pytest @@ -160,7 +161,7 @@ async def test_retrieves_static_memories_on_first_run(mock_project_client: Async ctx = SessionContext(input_messages=[Message(role="user", contents=["Hello"])], session_id="s1") await provider.before_run( # type: ignore[arg-type] - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) ) # Should call search_memories twice: once for static, once for contextual @@ -195,7 +196,7 @@ async def test_contextual_memories_added_to_context(mock_project_client: AsyncMo ctx = SessionContext(input_messages=[Message(role="user", contents=["Hello"])], session_id="s1") await provider.before_run( # type: ignore[arg-type] - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) ) # Check that memories were added to context @@ -222,7 +223,7 @@ async def test_empty_input_skips_contextual_search(mock_project_client: AsyncMoc ctx = SessionContext(input_messages=[Message(role="user", contents=[""])], session_id="s1") await provider.before_run( # type: ignore[arg-type] - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) ) # Should only call search_memories once for static memories @@ -244,7 +245,7 @@ async def test_empty_search_results_no_messages(mock_project_client: AsyncMock) ctx = SessionContext(input_messages=[Message(role="user", contents=["test"])], session_id="s1") await provider.before_run( # type: ignore[arg-type] - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) ) assert provider.source_id not in ctx.context_messages @@ -270,7 +271,7 @@ async def test_static_memories_only_retrieved_once(mock_project_client: AsyncMoc # First call await provider.before_run( # type: ignore[arg-type] - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) ) assert mock_project_client.beta.memory_stores.search_memories.call_count == 2 @@ -283,7 +284,7 @@ async def test_static_memories_only_retrieved_once(mock_project_client: AsyncMoc # Second call - should only search contextual, not static ctx2 = SessionContext(input_messages=[Message(role="user", contents=["World"])], session_id="s1") await provider.before_run( # type: ignore[arg-type] - agent=None, session=session, context=ctx2, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), session=session, context=ctx2, state=session.state.setdefault(provider.source_id, {}) ) assert mock_project_client.beta.memory_stores.search_memories.call_count == 1 @@ -301,7 +302,7 @@ async def test_handles_search_exception_gracefully(mock_project_client: AsyncMoc # Should not raise exception await provider.before_run( # type: ignore[arg-type] - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) ) # No memories added @@ -326,7 +327,7 @@ async def test_stores_input_and_response(mock_project_client: AsyncMock) -> None ctx._response = AgentResponse(messages=[Message(role="assistant", contents=["answer"])]) await provider.after_run( # type: ignore[arg-type] - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) ) mock_project_client.beta.memory_stores.begin_update_memories.assert_awaited_once() @@ -359,7 +360,7 @@ async def test_only_stores_user_assistant_system(mock_project_client: AsyncMock) ctx._response = AgentResponse(messages=[Message(role="assistant", contents=["reply"])]) await provider.after_run( # type: ignore[arg-type] - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) ) call_kwargs = mock_project_client.beta.memory_stores.begin_update_memories.call_args.kwargs @@ -386,7 +387,7 @@ async def test_skips_empty_messages(mock_project_client: AsyncMock) -> None: ctx._response = AgentResponse(messages=[]) await provider.after_run( # type: ignore[arg-type] - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) ) mock_project_client.beta.memory_stores.begin_update_memories.assert_not_awaited() @@ -407,7 +408,7 @@ async def test_uses_configured_update_delay(mock_project_client: AsyncMock) -> N ctx._response = AgentResponse(messages=[Message(role="assistant", contents=["hey"])]) await provider.after_run( # type: ignore[arg-type] - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) ) call_kwargs = mock_project_client.beta.memory_stores.begin_update_memories.call_args.kwargs @@ -433,7 +434,7 @@ async def test_uses_previous_update_id_for_incremental_updates(mock_project_clie # First update await provider.after_run( # type: ignore[arg-type] - agent=None, session=session, context=ctx1, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), session=session, context=ctx1, state=session.state.setdefault(provider.source_id, {}) ) assert session.state[provider.source_id]["previous_update_id"] == "update-1" @@ -442,7 +443,7 @@ async def test_uses_previous_update_id_for_incremental_updates(mock_project_clie ctx2._response = AgentResponse(messages=[Message(role="assistant", contents=["response2"])]) await provider.after_run( # type: ignore[arg-type] - agent=None, session=session, context=ctx2, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), session=session, context=ctx2, state=session.state.setdefault(provider.source_id, {}) ) call_kwargs = mock_project_client.beta.memory_stores.begin_update_memories.call_args.kwargs @@ -464,7 +465,7 @@ async def test_handles_update_exception_gracefully(mock_project_client: AsyncMoc # Should not raise exception await provider.after_run( # type: ignore[arg-type] - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) ) diff --git a/python/packages/foundry/tests/foundry/test_to_prompt_agent.py b/python/packages/foundry/tests/foundry/test_to_prompt_agent.py index 599aa73d2d5..ffa56874999 100644 --- a/python/packages/foundry/tests/foundry/test_to_prompt_agent.py +++ b/python/packages/foundry/tests/foundry/test_to_prompt_agent.py @@ -192,7 +192,7 @@ def test_to_prompt_agent_does_not_mutate_default_options() -> None: def test_to_prompt_agent_passes_through_sdk_tool_instances() -> None: """Foundry SDK tool instances (e.g. WebSearchTool) are passed through unchanged.""" ws = WebSearchTool() - ci = CodeInterpreterTool(container={"type": "auto"}) + ci = CodeInterpreterTool({"container": {"type": "auto"}}) agent = _make_agent(_make_foundry_chat_client(), instructions="x", tools=[ws, ci]) definition = to_prompt_agent(agent) @@ -616,7 +616,7 @@ def test_to_prompt_agent_forwards_structured_inputs_kwarg() -> None: def test_to_prompt_agent_forwards_rai_config_kwarg() -> None: """A ``RaiConfig`` kwarg is forwarded to the definition.""" - rai_config = RaiConfig() + rai_config = RaiConfig(rai_policy_name="test-policy") agent = _make_agent(_make_foundry_chat_client(), instructions="x") definition = to_prompt_agent(agent, rai_config=rai_config) @@ -631,7 +631,7 @@ def test_to_prompt_agent_forwards_rai_config_kwarg() -> None: def test_to_prompt_agent_combines_all_sources() -> None: """Generation params from default_options + Foundry-only kwargs combine cleanly.""" - rai_config = RaiConfig() + rai_config = RaiConfig(rai_policy_name="test-policy") structured = {"q": StructuredInputDefinition(description="query")} agent = _make_agent( _make_foundry_chat_client(), diff --git a/python/packages/foundry/tests/test_foundry_evals.py b/python/packages/foundry/tests/test_foundry_evals.py index 8734650aafb..6b74f1331ed 100644 --- a/python/packages/foundry/tests/test_foundry_evals.py +++ b/python/packages/foundry/tests/test_foundry_evals.py @@ -6,7 +6,7 @@ import json from dataclasses import dataclass -from typing import Any +from typing import Any, cast from unittest.mock import AsyncMock, MagicMock import pytest @@ -496,7 +496,7 @@ def test_split_messages_full_split(self) -> None: Message("assistant", ["Rain is expected tomorrow."]), ] item = EvalItem(conversation=conversation) - query_msgs, response_msgs = item.split_messages(split=ConversationSplit.FULL) + query_msgs, response_msgs = item.split_messages(split=cast(Any, ConversationSplit.FULL)) # query_messages: just the first user message assert len(query_msgs) == 1 assert query_msgs[0].role == "user" @@ -515,7 +515,7 @@ def test_split_messages_full_split_with_system(self) -> None: Message("assistant", ["It's sunny."]), ] item = EvalItem(conversation=conversation) - query_msgs, response_msgs = item.split_messages(split=ConversationSplit.FULL) + query_msgs, response_msgs = item.split_messages(split=cast(Any, ConversationSplit.FULL)) # query includes system + first user assert len(query_msgs) == 2 assert query_msgs[0].role == "system" @@ -533,7 +533,7 @@ def test_split_messages_full_split_with_tools(self) -> None: Message("assistant", ["You're welcome!"]), ] item = EvalItem(conversation=conversation) - query_msgs, response_msgs = item.split_messages(split=ConversationSplit.FULL) + query_msgs, response_msgs = item.split_messages(split=cast(Any, ConversationSplit.FULL)) assert len(query_msgs) == 1 assert len(response_msgs) == 5 @@ -547,7 +547,7 @@ def test_split_messages_last_turn_is_default(self) -> None: ] item = EvalItem(conversation=conversation) q_default, r_default = item.split_messages() - q_explicit, r_explicit = item.split_messages(split=ConversationSplit.LAST_TURN) + q_explicit, r_explicit = item.split_messages(split=cast(Any, ConversationSplit.LAST_TURN)) assert [m.role for m in q_default] == [m.role for m in q_explicit] assert [m.text for m in q_default] == [m.text for m in q_explicit] assert [m.role for m in r_default] == [m.role for m in r_explicit] @@ -585,7 +585,7 @@ def test_per_turn_items_with_tools(self) -> None: Message("assistant", ["You're welcome!"]), ] tool_objs = [_make_tool("get_weather")] - items = EvalItem.per_turn_items(conversation, tools=tool_objs) + items = EvalItem.per_turn_items(conversation, tools=cast(Any, tool_objs)) assert len(items) == 2 # Turn 1: response includes tool_call, tool_result, and final assistant @@ -624,13 +624,13 @@ def test_custom_splitter_callable(self) -> None: Message("assistant", ["The capital of France is Paris, Alice!"]), ] - def split_before_memory(conv): + def split_before_memory(conversation): """Split just before the memory retrieval tool call.""" - for i, msg in enumerate(conv): + for i, msg in enumerate(conversation): for c in msg.contents: if c.name == "retrieve_memory": - return conv[:i], conv[i:] - return EvalItem._split_last_turn_static(conv) + return conversation[:i], conversation[i:] + return EvalItem._split_last_turn_static(conversation) item = EvalItem(conversation=conversation) query_msgs, response_msgs = item.split_messages(split=split_before_memory) @@ -650,12 +650,12 @@ def test_custom_splitter_with_fallback(self) -> None: Message("assistant", ["Hi there!"]), ] - def split_before_memory(conv): - for i, msg in enumerate(conv): + def split_before_memory(conversation): + for i, msg in enumerate(conversation): for c in msg.contents: if c.name == "retrieve_memory": - return conv[:i], conv[i:] - return EvalItem._split_last_turn_static(conv) + return conversation[:i], conversation[i:] + return EvalItem._split_last_turn_static(conversation) item = EvalItem(conversation=conversation) query_msgs, response_msgs = item.split_messages(split=split_before_memory) @@ -675,7 +675,7 @@ def test_custom_splitter_lambda(self) -> None: ] # Split at index 2 (arbitrary) item = EvalItem(conversation=conversation) - query_msgs, response_msgs = item.split_messages(split=lambda conv: (conv[:2], conv[2:])) + query_msgs, response_msgs = item.split_messages(split=lambda conversation: (conversation[:2], conversation[2:])) assert len(query_msgs) == 2 assert len(response_msgs) == 2 @@ -689,7 +689,7 @@ def test_split_strategy_on_item_used_by_split_messages(self) -> None: ] item = EvalItem( conversation=conversation, - split_strategy=ConversationSplit.FULL, + split_strategy=cast(Any, ConversationSplit.FULL), ) # split_messages() with no split arg should use item.split_strategy query_msgs, response_msgs = item.split_messages() @@ -707,10 +707,10 @@ def test_explicit_split_overrides_item_split_strategy(self) -> None: ] item = EvalItem( conversation=conversation, - split_strategy=ConversationSplit.FULL, + split_strategy=cast(Any, ConversationSplit.FULL), ) # Explicit split= should override split_strategy - query_msgs, response_msgs = item.split_messages(split=ConversationSplit.LAST_TURN) + query_msgs, response_msgs = item.split_messages(split=cast(Any, ConversationSplit.LAST_TURN)) assert len(query_msgs) == 3 # LAST_TURN: up to last user assert query_msgs[-1].text == "Second" assert len(response_msgs) == 1 @@ -1809,7 +1809,7 @@ def test_extracts_single_agent(self) -> None: WorkflowEvent.executor_invoked("planner", "Plan a trip"), WorkflowEvent.executor_completed("planner", [aer]), ] - result = WorkflowRunResult(events, []) + result = WorkflowRunResult(cast(Any, events), []) data = _extract_agent_eval_data(result) assert len(data) == 1 @@ -1826,7 +1826,7 @@ def test_extracts_multiple_agents(self) -> None: WorkflowEvent.executor_invoked("booker", "Book flight"), WorkflowEvent.executor_completed("booker", [aer2]), ] - result = WorkflowRunResult(events, []) + result = WorkflowRunResult(cast(Any, events), []) data = _extract_agent_eval_data(result) assert len(data) == 2 @@ -1844,7 +1844,7 @@ def test_skips_internal_executors(self) -> None: WorkflowEvent.executor_invoked("end", []), WorkflowEvent.executor_completed("end", None), ] - result = WorkflowRunResult(events, []) + result = WorkflowRunResult(cast(Any, events), []) data = _extract_agent_eval_data(result) assert len(data) == 1 @@ -1857,7 +1857,7 @@ def test_resolves_agent_from_workflow(self) -> None: WorkflowEvent.executor_invoked("my-agent", "Do it"), WorkflowEvent.executor_completed("my-agent", [aer]), ] - result = WorkflowRunResult(events, []) + result = WorkflowRunResult(cast(Any, events), []) # Build a mock workflow with AgentExecutor from agent_framework import AgentExecutor @@ -1878,13 +1878,13 @@ def test_resolves_agent_from_workflow(self) -> None: class TestExtractOverallQuery: def test_extracts_string_query(self) -> None: events = [WorkflowEvent.executor_invoked("input", "Plan a trip")] - result = WorkflowRunResult(events, []) + result = WorkflowRunResult(cast(Any, events), []) assert _extract_overall_query(result) == "Plan a trip" def test_extracts_message_query(self) -> None: msgs = [Message("user", ["What's the weather?"])] events = [WorkflowEvent.executor_invoked("input", msgs)] - result = WorkflowRunResult(events, []) + result = WorkflowRunResult(cast(Any, events), []) assert "What's the weather?" in (_extract_overall_query(result) or "") def test_returns_none_for_empty(self) -> None: @@ -1932,7 +1932,7 @@ async def test_post_hoc_with_workflow_result(self) -> None: WorkflowEvent.executor_completed("reviewer", [aer2]), WorkflowEvent("output", executor_id="end", data=final_output), ] - wf_result = WorkflowRunResult(events, []) + wf_result = WorkflowRunResult(cast(Any, events), []) mock_workflow = MagicMock() mock_workflow.executors = {} @@ -1961,7 +1961,7 @@ async def test_with_queries_runs_workflow(self) -> None: WorkflowEvent.executor_completed("agent", [aer]), WorkflowEvent("output", executor_id="end", data=final_output), ] - wf_result = WorkflowRunResult(events, []) + wf_result = WorkflowRunResult(cast(Any, events), []) mock_workflow = MagicMock() mock_workflow.executors = {} @@ -1991,7 +1991,7 @@ async def test_overall_plus_per_agent(self) -> None: WorkflowEvent.executor_completed("planner", [aer]), WorkflowEvent("output", executor_id="end", data=final_output), ] - wf_result = WorkflowRunResult(events, []) + wf_result = WorkflowRunResult(cast(Any, events), []) mock_workflow = MagicMock() mock_workflow.executors = {} @@ -2028,7 +2028,7 @@ async def test_per_agent_only(self) -> None: WorkflowEvent.executor_invoked("agent-a", "Do stuff"), WorkflowEvent.executor_completed("agent-a", [aer]), ] - wf_result = WorkflowRunResult(events, []) + wf_result = WorkflowRunResult(cast(Any, events), []) mock_workflow = MagicMock() mock_workflow.executors = {} @@ -2057,7 +2057,7 @@ async def test_overall_eval_excludes_tool_evaluators(self) -> None: WorkflowEvent.executor_completed("researcher", [aer]), WorkflowEvent("output", executor_id="end", data=[Message("assistant", ["Weather is sunny"])]), ] - wf_result = WorkflowRunResult(events, []) + wf_result = WorkflowRunResult(cast(Any, events), []) mock_workflow = MagicMock() mock_workflow.executors = {} @@ -2098,7 +2098,7 @@ async def test_per_agent_excludes_tool_evaluators_when_no_tools(self) -> None: WorkflowEvent.executor_invoked("planner", "Plan based on: sunny"), WorkflowEvent.executor_completed("planner", [aer2]), ] - wf_result = WorkflowRunResult(events, []) + wf_result = WorkflowRunResult(cast(Any, events), []) from agent_framework import AgentExecutor @@ -2166,7 +2166,7 @@ async def test_expected_output_stamps_overall_items(self) -> None: WorkflowEvent.executor_completed("agent", [aer]), WorkflowEvent("output", executor_id="end", data=final_output), ] - wf_result = WorkflowRunResult(events, []) + wf_result = WorkflowRunResult(cast(Any, events), []) mock_workflow = MagicMock() mock_workflow.executors = {} @@ -2205,7 +2205,7 @@ async def test_expected_output_with_num_repetitions(self) -> None: WorkflowEvent.executor_completed("agent", [aer]), WorkflowEvent("output", executor_id="end", data=final_output), ] - wf_result = WorkflowRunResult(events, []) + wf_result = WorkflowRunResult(cast(Any, events), []) mock_workflow = MagicMock() mock_workflow.executors = {} diff --git a/python/packages/foundry_hosting/agent_framework_foundry_hosting/_invocations.py b/python/packages/foundry_hosting/agent_framework_foundry_hosting/_invocations.py index 05105ec768e..df3c09c6234 100644 --- a/python/packages/foundry_hosting/agent_framework_foundry_hosting/_invocations.py +++ b/python/packages/foundry_hosting/agent_framework_foundry_hosting/_invocations.py @@ -35,7 +35,7 @@ def __init__( self._agent = agent self._sessions: dict[str, AgentSession] = {} - self.invoke_handler(self._handle_invoke) # pyright: ignore[reportUnknownMemberType] + self.invoke_handler(self._handle_invoke) async def _handle_invoke(self, request: Request) -> Response: """Invoke the agent with the given request.""" diff --git a/python/packages/foundry_hosting/agent_framework_foundry_hosting/_responses.py b/python/packages/foundry_hosting/agent_framework_foundry_hosting/_responses.py index ec4a9d85336..2a18e697468 100644 --- a/python/packages/foundry_hosting/agent_framework_foundry_hosting/_responses.py +++ b/python/packages/foundry_hosting/agent_framework_foundry_hosting/_responses.py @@ -412,8 +412,8 @@ def __init__( # `oauth_consent_request` stream event instead of crashing the server. self._agent_stack: AsyncExitStack | None = None self._agent_init_lock = asyncio.Lock() - self.shutdown_handler(self._cleanup_agent) # pyright: ignore[reportUnknownMemberType] - self.response_handler(self._handle_response) # pyright: ignore[reportUnknownMemberType] + self.shutdown_handler(self._cleanup_agent) + self.response_handler(self._handle_response) async def _ensure_agent_ready(self) -> None: """Lazily enter the agent's async context exactly once. @@ -1717,7 +1717,7 @@ async def _to_outputs( # for round trips where the original approval request needs to be looked up. item = getattr(event, "item", None) if item is not None and getattr(item, "id", None) is not None: - approval_request_id = cast(str, item.id) # type: ignore + approval_request_id = cast(str, item.id) await approval_storage.save_approval_request(approval_request_id, content) request_saved = True yield event diff --git a/python/packages/foundry_hosting/tests/test_responses.py b/python/packages/foundry_hosting/tests/test_responses.py index 8bc09c5d13d..5d15c42e8d2 100644 --- a/python/packages/foundry_hosting/tests/test_responses.py +++ b/python/packages/foundry_hosting/tests/test_responses.py @@ -187,7 +187,24 @@ def test_init_basic(self) -> None: assert server is not None def test_init_rejects_history_provider_with_load_messages(self) -> None: - hp = HistoryProvider(source_id="test", load_messages=True) + + class _LoadMessagesHistoryProvider(HistoryProvider): + async def get_messages( + self, session_id: str | None, *, state: dict[str, Any] | None = None, **kwargs: Any + ) -> list[Message]: + return [] + + async def save_messages( + self, + session_id: str | None, + messages: Sequence[Message], + *, + state: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: + pass + + hp = _LoadMessagesHistoryProvider(source_id="test", load_messages=True) agent = _make_agent( response=AgentResponse(messages=[Message(role="assistant", contents=[Content.from_text("hi")])]) ) @@ -477,7 +494,7 @@ class HandoffLikeRequest: agent = _make_agent( stream_updates=[ AgentResponseUpdate( - contents=[Content.from_function_call("call_1", "handoff_to_refund", arguments=request)], + contents=[Content.from_function_call("call_1", "handoff_to_refund", arguments=request.__dict__)], role="assistant", ), ] @@ -778,7 +795,7 @@ async def test_function_call_output(self) -> None: from azure.ai.agentserver.responses.models import FunctionCallOutputItemParam item = FunctionCallOutputItemParam({"type": "function_call_output", "call_id": "call_1", "output": "sunny"}) - msg = await _output_item_to_message(item) # type: ignore[arg-type] + msg = await _output_item_to_message(item) # type: ignore[arg-type] # ty: ignore[invalid-argument-type] assert msg.role == "tool" assert msg.contents[0].type == "function_result" assert msg.contents[0].call_id == "call_1" @@ -2592,7 +2609,9 @@ async def test_file_based_save_and_load_persists_across_instances(self, tmp_path assert loaded.type == "function_approval_request" assert loaded.id == "apr_1" # type: ignore[attr-defined] # The embedded function_call survives the round trip. - assert loaded.function_call.name == "delete_file" # type: ignore[attr-defined] + function_call = loaded.function_call + assert function_call is not None + assert function_call.name == "delete_file" async def test_file_based_duplicate_save_raises(self, tmp_path: Any) -> None: path = tmp_path / "approvals.json" @@ -2632,7 +2651,9 @@ async def test_output_item_mcp_approval_request_loads_from_storage(self) -> None assert c.type == "function_approval_request" assert c.id == "apr-1" # type: ignore[attr-defined] # The full saved Content (incl. function_call) is restored. - assert c.function_call.name == "delete_file" # type: ignore[attr-defined] + function_call = c.function_call + assert function_call is not None + assert function_call.name == "delete_file" async def test_output_item_mcp_approval_request_without_storage_raises(self) -> None: from azure.ai.agentserver.responses.models import OutputItemMcpApprovalRequest @@ -2666,7 +2687,9 @@ async def test_output_item_mcp_approval_response_resolves_to_approval_response(s assert c.type == "function_approval_response" assert c.approved is True # type: ignore[attr-defined] assert c.id == "apr-1" # type: ignore[attr-defined] - assert c.function_call.name == "delete_file" # type: ignore[attr-defined] + function_call = c.function_call + assert function_call is not None + assert function_call.name == "delete_file" async def test_output_item_mcp_approval_response_without_storage_raises(self) -> None: from azure.ai.agentserver.responses.models import OutputItemMcpApprovalResponseResource @@ -2750,7 +2773,7 @@ async def test_non_streaming_emits_mcp_approval_request_and_persists_to_storage( approval_request_id ) assert loaded.type == "function_approval_request" - assert loaded.function_call.name == "delete_file" # type: ignore[attr-defined] + assert loaded.function_call.name == "delete_file" # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] async def test_streaming_emits_mcp_approval_request_and_persists_to_storage(self) -> None: request_content = _make_function_approval_request_content(request_id="apr_streaming") @@ -2769,7 +2792,7 @@ async def test_streaming_emits_mcp_approval_request_and_persists_to_storage(self for e in events: if e["event"] != "response.output_item.added": continue - item = e["data"].get("item") or {} + item: dict[str, Any] = e["data"].get("item") or {} if item.get("type") == "mcp_approval_request": approval_request_id = item.get("id") break @@ -3104,7 +3127,7 @@ def test_traversal_and_separator_payloads_are_rejected(self, tmp_path: Any, bad_ def test_non_string_context_id_is_rejected(self, tmp_path: Any) -> None: helper = self._helper() with pytest.raises(RuntimeError): - helper(str(tmp_path), None) # type: ignore[arg-type] + helper(str(tmp_path), None) # type: ignore[arg-type] # ty: ignore[invalid-argument-type] def test_url_encoded_traversal_is_treated_as_literal_segment(self, tmp_path: Any) -> None: """URL-encoded traversal should not decode to traversal at the filesystem layer. @@ -3516,7 +3539,7 @@ def __init__( def create_session(self, **kwargs: Any) -> AgentSession: return AgentSession() - def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession: + def get_session(self, service_session_id: str, *, session_id: str | None = None) -> AgentSession: return AgentSession() def _next_request_id(self) -> str: @@ -3650,7 +3673,7 @@ def __init__(self, name: str, text: str) -> None: def create_session(self, **kwargs: Any) -> AgentSession: return AgentSession() - def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession: + def get_session(self, service_session_id: str, *, session_id: str | None = None) -> AgentSession: return AgentSession() @overload @@ -3796,7 +3819,7 @@ async def test_non_streaming_emits_mcp_approval_request_and_persists_to_storage( approval_request_id ) assert loaded.type == "function_approval_request" - assert loaded.function_call.name == "delete_file" # type: ignore[attr-defined] + assert loaded.function_call.name == "delete_file" # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert mock_agent.run_count == 1 async def test_streaming_emits_mcp_approval_request_and_persists_to_storage(self) -> None: @@ -3815,7 +3838,7 @@ async def test_streaming_emits_mcp_approval_request_and_persists_to_storage(self for e in events: if e["event"] != "response.output_item.added": continue - item = e["data"].get("item") or {} + item: dict[str, Any] = e["data"].get("item") or {} if item.get("type") == "mcp_approval_request": approval_request_id = item.get("id") break diff --git a/python/packages/foundry_hosting/tests/test_responses_int.py b/python/packages/foundry_hosting/tests/test_responses_int.py index a67b86f00a3..26259157373 100644 --- a/python/packages/foundry_hosting/tests/test_responses_int.py +++ b/python/packages/foundry_hosting/tests/test_responses_int.py @@ -47,12 +47,12 @@ @pytest.fixture def server() -> ResponsesHostServer: """Create a ResponsesHostServer backed by a real Foundry agent.""" - client = FoundryChatClient(credential=AzureCliCredential()) + client = FoundryChatClient(credential=AzureCliCredential()) # pyrefly: ignore[bad-argument-type] agent = Agent( - client=client, + client=client, # ty: ignore[invalid-argument-type] instructions="You are a concise assistant. Keep answers very short (one or two sentences).", - default_options={"store": False}, + default_options={"store": False}, # pyrefly: ignore[bad-argument-type] ) return ResponsesHostServer(agent, store=InMemoryResponseProvider()) @@ -67,13 +67,13 @@ async def get_weather(location: Annotated[str, "The city name"]) -> str: @pytest.fixture def server_with_tools() -> ResponsesHostServer: """Create a ResponsesHostServer whose agent has a tool.""" - client = FoundryChatClient(credential=AzureCliCredential()) + client = FoundryChatClient(credential=AzureCliCredential()) # pyrefly: ignore[bad-argument-type] agent = Agent( - client=client, + client=client, # ty: ignore[invalid-argument-type] instructions="You are a concise assistant. Use the provided tools when appropriate. Keep answers very short.", tools=[get_weather], - default_options={"store": False}, + default_options={"store": False}, # pyrefly: ignore[bad-argument-type] ) return ResponsesHostServer(agent, store=InMemoryResponseProvider()) diff --git a/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py b/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py index f51a77c0544..2d627a51649 100644 --- a/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py +++ b/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py @@ -28,13 +28,13 @@ from pydantic import BaseModel if sys.version_info >= (3, 13): - from typing import TypeVar # type: ignore # pragma: no cover + from typing import TypeVar # pragma: no cover else: - from typing_extensions import TypeVar # type: ignore # pragma: no cover + from typing_extensions import TypeVar # pragma: no cover if sys.version_info >= (3, 11): - from typing import TypedDict # type: ignore # pragma: no cover + from typing import TypedDict # pragma: no cover else: - from typing_extensions import TypedDict # type: ignore # pragma: no cover + from typing_extensions import TypedDict # pragma: no cover __all__ = [ @@ -320,7 +320,7 @@ class MyOptions(FoundryLocalChatOptions, total=False): env_file_path=env_file_path, env_file_encoding=env_file_encoding, ) - model_setting: str = settings["model"] # type: ignore[assignment] # pyright: ignore[reportTypedDictNotRequiredAccess] + model_setting: str = settings["model"] # type: ignore[assignment] manager = FoundryLocalManager(bootstrap=bootstrap, timeout=timeout) model_info: FoundryModelInfo | None = manager.get_model_info( diff --git a/python/packages/gemini/agent_framework_gemini/_chat_client.py b/python/packages/gemini/agent_framework_gemini/_chat_client.py index 5ec5575fb75..fee44ee368c 100644 --- a/python/packages/gemini/agent_framework_gemini/_chat_client.py +++ b/python/packages/gemini/agent_framework_gemini/_chat_client.py @@ -35,19 +35,19 @@ from pydantic import BaseModel if sys.version_info >= (3, 13): - from typing import TypeVar # type: ignore # pragma: no cover + from typing import TypeVar # pragma: no cover else: - from typing_extensions import TypeVar # type: ignore # pragma: no cover + from typing_extensions import TypeVar # pragma: no cover if sys.version_info >= (3, 12): - from typing import override # type: ignore # pragma: no cover + from typing import override # pragma: no cover else: - from typing_extensions import override # type: ignore # pragma: no cover + from typing_extensions import override # pragma: no cover if sys.version_info >= (3, 11): - from typing import TypedDict # type: ignore # pragma: no cover + from typing import TypedDict # pragma: no cover else: - from typing_extensions import TypedDict # type: ignore # pragma: no cover + from typing_extensions import TypedDict # pragma: no cover logger = logging.getLogger("agent_framework.gemini") @@ -306,7 +306,7 @@ class RawGeminiChatClient( client with batteries included, use `GeminiChatClient` instead. """ - OTEL_PROVIDER_NAME: ClassVar[str] = "gcp.gemini" # type: ignore[reportIncompatibleVariableOverride, misc] + OTEL_PROVIDER_NAME: ClassVar[str] = "gcp.gemini" def __init__( self, @@ -535,7 +535,7 @@ def _inner_get_response( async def _stream() -> AsyncIterable[ChatResponseUpdate]: validated = await self._validate_options(options) model, contents, config = self._prepare_request(messages, validated) - async for chunk in await self._genai_client.aio.models.generate_content_stream( # pyright: ignore[reportUnknownMemberType] + async for chunk in await self._genai_client.aio.models.generate_content_stream( model=model, contents=contents, # type: ignore[arg-type] config=config, diff --git a/python/packages/gemini/tests/test_gemini_client.py b/python/packages/gemini/tests/test_gemini_client.py index 732ca6635b0..518a6f73d08 100644 --- a/python/packages/gemini/tests/test_gemini_client.py +++ b/python/packages/gemini/tests/test_gemini_client.py @@ -5,7 +5,7 @@ import datetime import logging import os -from typing import Any +from typing import Any, cast from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -45,7 +45,7 @@ def _make_part( *, text: str | None = None, thought: bool = False, - function_call: tuple[str, str, dict[str, Any]] | None = None, + function_call: tuple[str | None, str, dict[str, Any]] | None = None, executable_code: str | None = None, code_execution_result: str | None = None, ) -> MagicMock: @@ -127,7 +127,7 @@ async def _async_iter(items: list[Any]): def _make_gemini_client( - model: str = "gemini-2.5-flash", + model: str | None = "gemini-2.5-flash", mock_client: MagicMock | None = None, ) -> tuple[GeminiChatClient, MagicMock]: """Return a (GeminiChatClient, mock_genai_client) pair.""" @@ -138,6 +138,25 @@ def _make_gemini_client( return client, mock +def _parts(content: types.Content) -> list[types.Part]: + assert content.parts is not None + return content.parts + + +def _function_calling_config(config: types.GenerateContentConfig) -> types.FunctionCallingConfig: + assert config.tool_config is not None + function_calling_config = config.tool_config.function_calling_config + assert function_calling_config is not None + return function_calling_config + + +def _first_function_declaration(config: types.GenerateContentConfig) -> types.FunctionDeclaration: + assert config.tools is not None + tool = cast(types.Tool, config.tools[0]) + assert tool.function_declarations is not None + return tool.function_declarations[0] + + # settings & initialisation @@ -429,6 +448,7 @@ async def test_multiple_system_messages_concatenated() -> None: ) config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"] + assert isinstance(config.system_instruction, str) assert "Be concise." in config.system_instruction assert "Use bullet points." in config.system_instruction @@ -447,6 +467,7 @@ async def test_instructions_option_merged_with_system_instruction() -> None: ) config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"] + assert isinstance(config.system_instruction, str) assert "Always respond in French." in config.system_instruction assert "Be concise." in config.system_instruction @@ -508,7 +529,7 @@ async def test_tool_messages_collapsed_into_single_user_message() -> None: contents: list[types.Content] = mock.aio.models.generate_content.call_args.kwargs["contents"] # user, model (with 2 function calls), user (with 2 function responses) assert contents[-1].role == "user" - assert len(contents[-1].parts) == 2 + assert len(_parts(contents[-1])) == 2 async def test_function_result_name_resolved_from_call_history() -> None: @@ -530,7 +551,8 @@ async def test_function_result_name_resolved_from_call_history() -> None: contents: list[types.Content] = mock.aio.models.generate_content.call_args.kwargs["contents"] tool_user_msg = contents[-1] assert tool_user_msg.role == "user" - function_response = tool_user_msg.parts[0].function_response + function_response = _parts(tool_user_msg)[0].function_response + assert function_response is not None assert function_response.name == "get_weather" assert function_response.id == "call-42" @@ -549,7 +571,7 @@ async def test_function_result_resolved_when_call_id_was_generated() -> None: Message(role="user", contents=[Content.from_text("Go")]), Message( role="assistant", - contents=[Content.from_function_call(call_id=None, name="get_weather", arguments={})], # type: ignore[arg-type] + contents=[Content.from_function_call(call_id=cast(str, None), name="get_weather", arguments={})], ), Message( role="tool", @@ -559,9 +581,11 @@ async def test_function_result_resolved_when_call_id_was_generated() -> None: ) contents: list[types.Content] = mock.aio.models.generate_content.call_args.kwargs["contents"] - tool_turn = next(c for c in contents if c.role == "user" and any(p.function_response for p in c.parts)) - assert tool_turn.parts[0].function_response.name == "get_weather" - assert tool_turn.parts[0].function_response.id == generated_id + tool_turn = next(c for c in contents if c.role == "user" and any(p.function_response for p in _parts(c))) + function_response = _parts(tool_turn)[0].function_response + assert function_response is not None + assert function_response.name == "get_weather" + assert function_response.id == generated_id async def test_function_result_without_matching_call_is_skipped(caplog: pytest.LogCaptureFixture) -> None: @@ -598,7 +622,7 @@ async def test_message_with_only_unsupported_content_type_is_skipped() -> None: contents: list[types.Content] = mock.aio.models.generate_content.call_args.kwargs["contents"] assert len(contents) == 1 - assert contents[0].parts[0].text == "Follow up" + assert _parts(contents[0])[0].text == "Follow up" async def test_non_function_result_content_in_tool_message_is_skipped() -> None: @@ -819,7 +843,7 @@ async def test_prepare_config_unknown_key_is_forwarded() -> None: mock_config.return_value = MagicMock() await client.get_response( messages=[Message(role="user", contents=[Content.from_text("Hi")])], - options={"some_future_param": "value"}, + options=cast(Any, {"some_future_param": "value"}), ) assert mock_config.call_args.kwargs.get("some_future_param") == "value" @@ -970,7 +994,7 @@ async def test_agent_default_options_response_format_raw_schema_added_to_config( client, mock = _make_gemini_client() mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text='{"answer": "hello"}')])) schema = {"type": "object", "properties": {"answer": {"type": "string"}}, "required": ["answer"]} - agent = Agent(client=client, default_options={"response_format": schema}) + agent = Agent(client=cast(Any, client), default_options=cast(Any, {"response_format": schema})) await agent.run("Hi") @@ -1181,7 +1205,7 @@ def calculator(expression: str) -> str: config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"] assert config.response_schema == schema assert config.tools is not None - assert config.tools[0].function_declarations[0].name == "calculator" + assert _first_function_declaration(config).name == "calculator" async def test_streaming_response_format_raw_schema_added_to_config() -> None: @@ -1304,7 +1328,7 @@ def get_weather(city: str) -> str: config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"] assert config.tools is not None assert len(config.tools) == 1 - function_declaration = config.tools[0].function_declarations[0] + function_declaration = _first_function_declaration(config) assert function_declaration.name == "get_weather" @@ -1327,7 +1351,7 @@ def get_weather(city: str) -> str: config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"] assert config.tools is not None - function_declaration = config.tools[0].function_declarations[0] + function_declaration = _first_function_declaration(config) assert function_declaration.name == "get_weather" @@ -1373,7 +1397,7 @@ def test_coerce_to_dict_with_json_string_literal() -> None: def _get_function_calling_mode(config: types.GenerateContentConfig) -> str: - return config.tool_config.function_calling_config.mode + return str(_function_calling_config(config).mode) def _make_dummy_tool() -> FunctionTool: @@ -1391,7 +1415,7 @@ async def _get_config_for_tool_choice(tool_choice: str) -> types.GenerateContent await client.get_response( messages=[Message(role="user", contents=[Content.from_text("Hi")])], - options={"tools": [tool], "tool_choice": tool_choice}, + options=cast(Any, {"tools": [tool], "tool_choice": tool_choice}), ) return mock.aio.models.generate_content.call_args.kwargs["config"] @@ -1429,8 +1453,9 @@ async def test_tool_choice_required_with_name_sets_allowed_function_names() -> N ) config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"] - function_calling_config = config.tool_config.function_calling_config + function_calling_config = _function_calling_config(config) assert function_calling_config.mode == "ANY" + assert function_calling_config.allowed_function_names is not None assert "dummy" in function_calling_config.allowed_function_names @@ -1464,7 +1489,7 @@ async def test_tool_choice_auto_with_allowed_tools_uses_VALIDATED() -> None: ) config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"] - function_calling_config = config.tool_config.function_calling_config + function_calling_config = _function_calling_config(config) assert function_calling_config.mode == "VALIDATED" assert function_calling_config.allowed_function_names == ["dummy", "other"] @@ -1484,7 +1509,7 @@ async def test_tool_choice_auto_with_empty_allowed_tools_uses_VALIDATED() -> Non ) config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"] - function_calling_config = config.tool_config.function_calling_config + function_calling_config = _function_calling_config(config) assert function_calling_config.mode == "VALIDATED" assert function_calling_config.allowed_function_names == [] @@ -1504,7 +1529,7 @@ async def test_tool_choice_required_with_allowed_tools_uses_ANY() -> None: ) config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"] - function_calling_config = config.tool_config.function_calling_config + function_calling_config = _function_calling_config(config) assert function_calling_config.mode == "ANY" assert function_calling_config.allowed_function_names == ["dummy"] @@ -1524,7 +1549,7 @@ async def test_tool_choice_required_function_name_takes_precedence_over_allowed_ ) config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"] - function_calling_config = config.tool_config.function_calling_config + function_calling_config = _function_calling_config(config) assert function_calling_config.mode == "ANY" assert function_calling_config.allowed_function_names == ["dummy"] @@ -1620,7 +1645,9 @@ def test_get_mcp_tool_forwards_transport_kwargs() -> None: url="https://mcp.example.com/sse", headers={"Authorization": "Bearer token"}, ) - server = tool.mcp_servers[0] # type: ignore[index] + assert tool.mcp_servers is not None + server = tool.mcp_servers[0] + assert server.streamable_http_transport is not None assert server.streamable_http_transport.headers == {"Authorization": "Bearer token"} @@ -1637,7 +1664,7 @@ async def test_types_tool_passed_in_tools_list_is_forwarded() -> None: config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"] assert config.tools is not None - assert any(tool.google_search for tool in config.tools) + assert any(cast(types.Tool, tool).google_search for tool in config.tools) async def test_function_response_part_in_response_mapped_to_content() -> None: diff --git a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py index 51b57913729..921231a973a 100644 --- a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py +++ b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py @@ -50,9 +50,9 @@ ) from _copilot_import_error if sys.version_info >= (3, 13): - from typing import TypeVar + from typing import TypeVar # pragma: no cover else: - from typing_extensions import TypeVar + from typing_extensions import TypeVar # pragma: no cover DEFAULT_TIMEOUT_SECONDS: float = 60.0 @@ -202,6 +202,9 @@ class GitHubCopilotOptions(TypedDict, total=False): files beyond the default locations. """ + base_directory: str + """Directory where the CLI stores session state, configuration, and other persistent data.""" + on_function_approval: FunctionApprovalCallback """Approval callback for ``FunctionTool`` instances declared with ``approval_mode="always_require"``. The callback is awaited (sync or async) @@ -441,7 +444,7 @@ def run( session: AgentSession | None = None, middleware: Sequence[AgentMiddlewareTypes] | None = None, options: OptionsT | None = None, - **kwargs: Any, # type: ignore[override] + **kwargs: Any, ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: """Get a response from the agent. @@ -732,7 +735,7 @@ async def _run_before_providers( if isinstance(provider, HistoryProvider) and not provider.load_messages: continue await provider.before_run( - agent=self, # type: ignore[arg-type] + agent=self, session=session, context=session_context, state=session.state.setdefault(provider.source_id, {}), @@ -781,7 +784,7 @@ def _prepare_tools( if isinstance(tool, CopilotTool): copilot_tools.append(tool) elif isinstance(tool, FunctionTool): - copilot_tools.append(self._tool_to_copilot_tool(tool)) # type: ignore + copilot_tools.append(self._tool_to_copilot_tool(tool)) elif isinstance(tool, MutableMapping): copilot_tools.append(tool) # type: ignore[arg-type] # Note: Other tool types (e.g., dict-based hosted tools) are skipped diff --git a/python/packages/github_copilot/tests/test_github_copilot_agent.py b/python/packages/github_copilot/tests/test_github_copilot_agent.py index 518d5758207..39a7eef387b 100644 --- a/python/packages/github_copilot/tests/test_github_copilot_agent.py +++ b/python/packages/github_copilot/tests/test_github_copilot_agent.py @@ -4,8 +4,9 @@ import os import unittest.mock +from collections.abc import Awaitable, Callable, Sequence from datetime import datetime, timezone -from typing import Any +from typing import Any, cast from unittest.mock import AsyncMock, MagicMock, patch from uuid import uuid4 @@ -37,6 +38,11 @@ from agent_framework_github_copilot import GitHubCopilotAgent, GitHubCopilotOptions +def copilot_options(options: GitHubCopilotOptions) -> GitHubCopilotOptions: + """Return GitHub Copilot options with concrete TypedDict typing for tests.""" + return options + + def create_session_event( event_type: SessionEventType, content: str | None = None, @@ -136,9 +142,7 @@ def test_init_without_client(self) -> None: def test_init_with_default_options(self) -> None: """Test initialization with default_options parameter.""" - agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent( - default_options={"model": "claude-sonnet-4", "timeout": 120} - ) + agent = GitHubCopilotAgent(default_options=copilot_options({"model": "claude-sonnet-4", "timeout": 120})) assert agent._settings["model"] == "claude-sonnet-4" # type: ignore assert agent._settings["timeout"] == 120 # type: ignore @@ -161,8 +165,10 @@ def test_init_with_instructions_parameter(self) -> None: def test_init_with_system_message_in_default_options(self) -> None: """Test initialization with system_message object in default_options.""" - agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent( - default_options={"system_message": {"mode": "append", "content": "You are a helpful assistant."}} + agent = GitHubCopilotAgent( + default_options=copilot_options({ + "system_message": {"mode": "append", "content": "You are a helpful assistant."} + }) ) assert agent._default_options.get("system_message") == { # type: ignore "mode": "append", @@ -171,8 +177,8 @@ def test_init_with_system_message_in_default_options(self) -> None: def test_init_with_system_message_replace_mode(self) -> None: """Test initialization with system_message in replace mode.""" - agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent( - default_options={"system_message": {"mode": "replace", "content": "Custom system prompt."}} + agent = GitHubCopilotAgent( + default_options=copilot_options({"system_message": {"mode": "replace", "content": "Custom system prompt."}}) ) assert agent._default_options.get("system_message") == { # type: ignore "mode": "replace", @@ -181,9 +187,11 @@ def test_init_with_system_message_replace_mode(self) -> None: def test_instructions_parameter_takes_precedence_for_content(self) -> None: """Test that direct instructions parameter takes precedence for content but preserves mode.""" - agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent( + agent = GitHubCopilotAgent( instructions="Direct instructions", - default_options={"system_message": {"mode": "replace", "content": "Options system_message"}}, + default_options=copilot_options({ + "system_message": {"mode": "replace", "content": "Options system_message"} + }), ) assert agent._default_options.get("system_message") == { # type: ignore "mode": "replace", @@ -200,9 +208,7 @@ def test_instructions_parameter_defaults_to_append_mode(self) -> None: def test_default_options_includes_model_for_telemetry(self) -> None: """Test that default_options merges model from settings for AgentTelemetryLayer span attributes.""" - agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent( - default_options={"model": "claude-sonnet-4-5", "timeout": 120} - ) + agent = GitHubCopilotAgent(default_options=copilot_options({"model": "claude-sonnet-4-5", "timeout": 120})) opts = agent.default_options assert opts["model"] == "claude-sonnet-4-5" assert "timeout" not in opts # timeout is extracted into _settings, not returned in default_options @@ -216,16 +222,14 @@ def test_default_options_without_model_configured(self) -> None: def test_default_options_returns_independent_copy(self) -> None: """Test that mutating the returned dict does not affect internal state.""" - agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent(default_options={"model": "gpt-5.1-mini"}) + agent = GitHubCopilotAgent(default_options=copilot_options({"model": "gpt-5.1-mini"})) opts = agent.default_options opts["model"] = "mutated" assert agent._settings.get("model") == "gpt-5.1-mini" def test_init_stores_instruction_directories(self) -> None: """Test that instruction_directories are stored on the agent instance.""" - agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent( - default_options={"instruction_directories": ["/my/instructions"]} - ) + agent = GitHubCopilotAgent(default_options=copilot_options({"instruction_directories": ["/my/instructions"]})) assert agent._instruction_directories == ["/my/instructions"] # type: ignore def test_init_without_instruction_directories(self) -> None: @@ -306,8 +310,8 @@ async def test_start_creates_client_with_options(self) -> None: mock_client.start = AsyncMock() MockClient.return_value = mock_client - agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent( - default_options={"cli_path": "/custom/path", "log_level": "debug"} + agent = GitHubCopilotAgent( + default_options=copilot_options({"cli_path": "/custom/path", "log_level": "debug"}) ) await agent.start() @@ -322,9 +326,7 @@ async def test_start_passes_base_directory_to_client(self) -> None: mock_client.start = AsyncMock() MockClient.return_value = mock_client - agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent( - default_options={"base_directory": "/custom/copilot/home"} - ) + agent = GitHubCopilotAgent(default_options=copilot_options({"base_directory": "/custom/copilot/home"})) await agent.start() kwargs = MockClient.call_args.kwargs @@ -422,7 +424,7 @@ async def test_run_with_runtime_options( mock_session.send_and_wait.return_value = assistant_message_event agent = GitHubCopilotAgent(client=mock_client) - response = await agent.run("Hello", options={"timeout": 30}) + response = await agent.run("Hello", options=cast(Any, {"timeout": 30})) assert isinstance(response, AgentResponse) @@ -961,9 +963,7 @@ async def test_session_config_includes_model( mock_session: MagicMock, ) -> None: """Test that session config includes model setting.""" - agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent( - client=mock_client, default_options={"model": "claude-sonnet-4"} - ) + agent = GitHubCopilotAgent(client=mock_client, default_options=copilot_options({"model": "claude-sonnet-4"})) await agent.start() await agent._get_or_create_session(AgentSession()) # type: ignore @@ -1003,9 +1003,7 @@ async def test_runtime_options_take_precedence_over_default( ) await agent.start() - runtime_options: GitHubCopilotOptions = { - "system_message": {"mode": "replace", "content": "Runtime instructions"} - } + runtime_options: dict[str, Any] = {"system_message": {"mode": "replace", "content": "Runtime instructions"}} await agent._get_or_create_session( # type: ignore AgentSession(), runtime_options=runtime_options, @@ -1066,10 +1064,10 @@ def my_tool(arg: str) -> str: """A test tool.""" return arg - agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent( + agent = GitHubCopilotAgent( client=mock_client, tools=[my_tool], - default_options={"on_permission_request": my_handler}, + default_options=copilot_options({"on_permission_request": my_handler}), ) await agent.start() @@ -1090,9 +1088,9 @@ async def test_instruction_directories_passed_to_create_session( mock_session: MagicMock, ) -> None: """Test that instruction_directories are passed through to create_session.""" - agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent( + agent = GitHubCopilotAgent( client=mock_client, - default_options={"instruction_directories": ["/path/to/instructions", "/other/path"]}, + default_options=copilot_options({"instruction_directories": ["/path/to/instructions", "/other/path"]}), ) await agent.start() @@ -1108,9 +1106,9 @@ async def test_instruction_directories_runtime_override( mock_session: MagicMock, ) -> None: """Test that runtime instruction_directories take precedence over defaults.""" - agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent( + agent = GitHubCopilotAgent( client=mock_client, - default_options={"instruction_directories": ["/default/path"]}, + default_options=copilot_options({"instruction_directories": ["/default/path"]}), ) await agent.start() @@ -1142,9 +1140,9 @@ async def test_instruction_directories_empty_list_clears_defaults( mock_session: MagicMock, ) -> None: """Test that an explicit empty list at runtime clears the agent-level defaults.""" - agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent( + agent = GitHubCopilotAgent( client=mock_client, - default_options={"instruction_directories": ["/default/path"]}, + default_options=copilot_options({"instruction_directories": ["/default/path"]}), ) await agent.start() @@ -1161,9 +1159,9 @@ async def test_instruction_directories_override_on_resumed_session( mock_session: MagicMock, ) -> None: """Test that instruction_directories override works on resumed sessions.""" - agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent( + agent = GitHubCopilotAgent( client=mock_client, - default_options={"instruction_directories": ["/default/path"]}, + default_options=copilot_options({"instruction_directories": ["/default/path"]}), ) await agent.start() @@ -1204,9 +1202,9 @@ async def test_mcp_servers_passed_to_create_session( }, } - agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent( + agent = GitHubCopilotAgent( client=mock_client, - default_options={"mcp_servers": mcp_servers}, + default_options=copilot_options({"mcp_servers": mcp_servers}), ) await agent.start() @@ -1237,9 +1235,9 @@ async def test_mcp_servers_passed_to_resume_session( }, } - agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent( + agent = GitHubCopilotAgent( client=mock_client, - default_options={"mcp_servers": mcp_servers}, + default_options=copilot_options({"mcp_servers": mcp_servers}), ) await agent.start() @@ -1286,9 +1284,9 @@ async def test_provider_passed_to_create_session( "bearer_token": "test-token", } - agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent( + agent = GitHubCopilotAgent( client=mock_client, - default_options={"provider": provider}, + default_options=copilot_options({"provider": provider}), ) await agent.start() @@ -1313,9 +1311,9 @@ async def test_provider_passed_to_resume_session( "bearer_token": "test-token", } - agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent( + agent = GitHubCopilotAgent( client=mock_client, - default_options={"provider": provider}, + default_options=copilot_options({"provider": provider}), ) await agent.start() @@ -1378,9 +1376,9 @@ async def test_runtime_provider_takes_precedence( "api_key": "runtime-key", } - agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent( + agent = GitHubCopilotAgent( client=mock_client, - default_options={"provider": default_provider}, + default_options=copilot_options({"provider": default_provider}), ) await agent.start() @@ -1407,9 +1405,9 @@ async def test_provider_not_leaked_into_default_options( "bearer_token": "test-token", } - agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent( + agent = GitHubCopilotAgent( client=mock_client, - default_options={"provider": provider, "model": "gpt-5"}, + default_options=copilot_options({"provider": provider, "model": "gpt-5"}), ) assert "provider" not in agent._default_options @@ -1441,14 +1439,14 @@ def my_tool(arg: str) -> str: """A test tool.""" return arg - agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent( + agent = GitHubCopilotAgent( client=mock_client, tools=[my_tool], - default_options={ + default_options=copilot_options({ "model": "gpt-5", "provider": provider, "mcp_servers": mcp_servers, - }, + }), ) await agent.start() @@ -1539,6 +1537,7 @@ def failing_tool(arg: str) -> str: assert isinstance(result, ToolResult) assert result.result_type == "failure" assert "Something went wrong" in result.text_result_for_llm + assert result.error is not None assert "Something went wrong" in result.error async def test_tool_handler_rejects_raw_dict_invocation( @@ -1666,7 +1665,8 @@ def dangerous(path: str) -> str: agent = GitHubCopilotAgent(client=mock_client) copilot_tool = agent._tool_to_copilot_tool(dangerous) # type: ignore[reportPrivateUsage] - result = await copilot_tool.handler(ToolInvocation(arguments={"path": "/critical"})) + handler = cast("Callable[[ToolInvocation], Awaitable[ToolResult]]", copilot_tool.handler) + result = await handler(ToolInvocation(arguments={"path": "/critical"})) assert invocations == [] assert result.result_type == "failure" @@ -1695,11 +1695,12 @@ def dangerous(path: str) -> str: agent = GitHubCopilotAgent( client=mock_client, - default_options={"on_function_approval": deny}, + default_options=copilot_options({"on_function_approval": deny}), ) copilot_tool = agent._tool_to_copilot_tool(dangerous) # type: ignore[reportPrivateUsage] - result = await copilot_tool.handler(ToolInvocation(arguments={"path": "/critical"})) + handler = cast("Callable[[ToolInvocation], Awaitable[ToolResult]]", copilot_tool.handler) + result = await handler(ToolInvocation(arguments={"path": "/critical"})) assert invocations == [] assert len(seen) == 1 @@ -1726,11 +1727,12 @@ def guarded(x: int) -> str: agent = GitHubCopilotAgent( client=mock_client, - default_options={"on_function_approval": approve}, + default_options=copilot_options({"on_function_approval": approve}), ) copilot_tool = agent._tool_to_copilot_tool(guarded) # type: ignore[reportPrivateUsage] - result = await copilot_tool.handler(ToolInvocation(arguments={"x": 42})) + handler = cast("Callable[[ToolInvocation], Awaitable[ToolResult]]", copilot_tool.handler) + result = await handler(ToolInvocation(arguments={"x": 42})) assert result.result_type == "success" assert result.text_result_for_llm == "result=42" @@ -1752,11 +1754,12 @@ def guarded(x: int) -> str: agent = GitHubCopilotAgent( client=mock_client, - default_options={"on_function_approval": approve}, + default_options=copilot_options({"on_function_approval": approve}), ) copilot_tool = agent._tool_to_copilot_tool(guarded) # type: ignore[reportPrivateUsage] - result = await copilot_tool.handler(ToolInvocation(arguments={"x": 7})) + handler = cast("Callable[[ToolInvocation], Awaitable[ToolResult]]", copilot_tool.handler) + result = await handler(ToolInvocation(arguments={"x": 7})) assert result.result_type == "success" assert result.text_result_for_llm == "async=7" @@ -1781,11 +1784,12 @@ def dangerous(x: int) -> str: agent = GitHubCopilotAgent( client=mock_client, - default_options={"on_function_approval": boom}, + default_options=copilot_options({"on_function_approval": boom}), ) copilot_tool = agent._tool_to_copilot_tool(dangerous) # type: ignore[reportPrivateUsage] - result = await copilot_tool.handler(ToolInvocation(arguments={"x": 1})) + handler = cast("Callable[[ToolInvocation], Awaitable[ToolResult]]", copilot_tool.handler) + result = await handler(ToolInvocation(arguments={"x": 1})) assert invocations == [] assert result.result_type == "failure" @@ -1811,11 +1815,12 @@ def safe(x: int) -> str: agent = GitHubCopilotAgent( client=mock_client, - default_options={"on_function_approval": approve}, + default_options=copilot_options({"on_function_approval": approve}), ) copilot_tool = agent._tool_to_copilot_tool(safe) # type: ignore[reportPrivateUsage] - result = await copilot_tool.handler(ToolInvocation(arguments={"x": 5})) + handler = cast("Callable[[ToolInvocation], Awaitable[ToolResult]]", copilot_tool.handler) + result = await handler(ToolInvocation(arguments={"x": 5})) assert callback_calls == [] assert result.result_type == "success" @@ -1897,9 +1902,7 @@ def approve_shell(request: PermissionRequest, context: dict[str, str]) -> Permis return PermissionDecisionApproveOnce() return PermissionDecisionDeniedInteractivelyByUser() - agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent( - default_options={"on_permission_request": approve_shell} - ) + agent = GitHubCopilotAgent(default_options=copilot_options({"on_permission_request": approve_shell})) assert agent._permission_handler is not None # type: ignore async def test_session_config_includes_permission_handler( @@ -1917,9 +1920,9 @@ def approve_shell_read(request: PermissionRequest, context: dict[str, str]) -> P return PermissionDecisionApproveOnce() return PermissionDecisionDeniedInteractivelyByUser() - agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent( + agent = GitHubCopilotAgent( client=mock_client, - default_options={"on_permission_request": approve_shell_read}, + default_options=copilot_options({"on_permission_request": approve_shell_read}), ) await agent.start() @@ -2362,17 +2365,26 @@ async def after_run( ) -> None: self.after_run_called = True - async def get_messages(self, *, session_id: str, **kwargs: Any) -> list[Message]: + async def get_messages( + self, session_id: str | None, *, state: dict[str, Any] | None = None, **kwargs: Any + ) -> list[Message]: return [] - async def save_messages(self, *, session_id: str, messages: list[Message], **kwargs: Any) -> None: + async def save_messages( + self, + session_id: str | None, + messages: Sequence[Message], + *, + state: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: pass skipped_provider = StubHistoryProvider(load_messages=False) active_provider = StubHistoryProvider(load_messages=True) # Use unique source_ids - skipped_provider._source_id = "skipped-history" - active_provider._source_id = "active-history" + object.__setattr__(skipped_provider, "_source_id", "skipped-history") + object.__setattr__(active_provider, "_source_id", "active-history") agent = GitHubCopilotAgent(client=mock_client, context_providers=[skipped_provider, active_provider]) session = agent.create_session() @@ -2475,7 +2487,7 @@ async def after_run( provider = OptionsObserverProvider() agent = GitHubCopilotAgent(client=mock_client, context_providers=[provider]) session = agent.create_session() - await agent.run("Hello", session=session, options={"timeout": 120}) + await agent.run("Hello", session=session, options=cast(Any, {"timeout": 120})) assert observed_options.get("timeout") == 120 @@ -2483,13 +2495,17 @@ async def test_runtime_on_function_approval_rejected(self, mock_client: MagicMoc """Passing on_function_approval at runtime must raise rather than be silently ignored.""" agent = GitHubCopilotAgent(client=mock_client) with pytest.raises(ValueError, match="on_function_approval"): - await agent.run("hello", options={"on_function_approval": lambda _c: True}) + await agent.run("hello", options=cast(Any, {"on_function_approval": lambda _c: True})) async def test_runtime_on_function_approval_rejected_streaming(self, mock_client: MagicMock) -> None: """Passing on_function_approval at runtime must raise on the streaming path too.""" agent = GitHubCopilotAgent(client=mock_client) with pytest.raises(ValueError, match="on_function_approval"): - async for _ in agent.run("hello", stream=True, options={"on_function_approval": lambda _c: True}): + async for _ in agent.run( + "hello", + stream=True, + options=cast(Any, {"on_function_approval": lambda _c: True}), + ): pass async def test_provider_tools_forwarded_to_session( @@ -2744,7 +2760,7 @@ async def test_integration_run_with_simple_prompt_returns_response() -> None: """Integration test: basic non-streaming response.""" agent = GitHubCopilotAgent( instructions="You are a helpful assistant. Keep your answers short.", - default_options={"on_permission_request": PermissionHandler.approve_all}, + default_options=copilot_options({"on_permission_request": PermissionHandler.approve_all}), ) async with agent: @@ -2766,7 +2782,7 @@ async def test_integration_run_streaming_returns_updates() -> None: """Integration test: streaming response yields updates.""" agent = GitHubCopilotAgent( instructions="You are a helpful assistant. Keep your answers short.", - default_options={"on_permission_request": PermissionHandler.approve_all}, + default_options=copilot_options({"on_permission_request": PermissionHandler.approve_all}), ) async with agent: @@ -2791,7 +2807,7 @@ async def test_integration_run_with_function_tool_invokes_tool() -> None: agent = GitHubCopilotAgent( instructions="You are a helpful weather agent. Use the get_weather tool to answer weather questions.", tools=[get_weather], - default_options={"on_permission_request": PermissionHandler.approve_all}, + default_options=copilot_options({"on_permission_request": PermissionHandler.approve_all}), ) async with agent: @@ -2813,7 +2829,7 @@ async def test_integration_run_with_session_maintains_context() -> None: """Integration test: session maintains conversation context across turns.""" agent = GitHubCopilotAgent( instructions="You are a helpful assistant. Keep your answers short.", - default_options={"on_permission_request": PermissionHandler.approve_all}, + default_options=copilot_options({"on_permission_request": PermissionHandler.approve_all}), ) async with agent: @@ -2838,7 +2854,7 @@ async def test_integration_run_with_session_resume_continues_conversation() -> N """Integration test: session can be resumed by ID.""" agent = GitHubCopilotAgent( instructions="You are a helpful assistant. Keep your answers short.", - default_options={"on_permission_request": PermissionHandler.approve_all}, + default_options=copilot_options({"on_permission_request": PermissionHandler.approve_all}), ) async with agent: @@ -2867,7 +2883,7 @@ async def test_integration_run_with_shell_permissions_executes_command() -> None """Integration test: shell commands can be executed with permission handler.""" agent = GitHubCopilotAgent( instructions="You are a helpful assistant that can execute shell commands.", - default_options={"on_permission_request": PermissionHandler.approve_all}, + default_options=copilot_options({"on_permission_request": PermissionHandler.approve_all}), ) async with agent: diff --git a/python/packages/hyperlight/tests/hyperlight/test_hyperlight_codeact.py b/python/packages/hyperlight/tests/hyperlight/test_hyperlight_codeact.py index 03e3c2269cb..ed8f6262c16 100644 --- a/python/packages/hyperlight/tests/hyperlight/test_hyperlight_codeact.py +++ b/python/packages/hyperlight/tests/hyperlight/test_hyperlight_codeact.py @@ -13,10 +13,10 @@ import sys import threading import time -from collections.abc import Awaitable, Callable, Mapping, MutableSequence +from collections.abc import Awaitable, Callable, Coroutine, Mapping, Sequence from dataclasses import dataclass from pathlib import Path -from typing import Any +from typing import Any, cast import pytest from agent_framework import ( @@ -216,7 +216,7 @@ def _invoke_tool(self, name: str, **kwargs: Any) -> Any: result = callback(**kwargs) if inspect.isawaitable(result): - return _run_in_thread(lambda: asyncio.run(result)) + return _run_in_thread(lambda: asyncio.run(cast(Coroutine[Any, Any, Any], result))) return result def run(self, code: str) -> _FakeResult: @@ -264,10 +264,11 @@ def run(self, code: str) -> _FakeResult: if 'Path("/output/report.txt").write_text("artifact", encoding="utf-8")' in code: if self.output_dir is None: raise AssertionError("Expected output directory for delayed output test.") + output_dir = self.output_dir def _write_file() -> None: time.sleep(0.15) - Path(self.output_dir, "report.txt").write_text("artifact", encoding="utf-8") + Path(output_dir, "report.txt").write_text("artifact", encoding="utf-8") writer_thread = threading.Thread(target=_write_file) writer_thread.start() @@ -317,7 +318,7 @@ def __init__(self) -> None: def _inner_get_response( self, *, - messages: MutableSequence[Message], + messages: Sequence[Message], stream: bool, options: Mapping[str, Any], **kwargs: Any, @@ -789,7 +790,7 @@ async def test_provider_injects_run_scoped_execute_code_tool() -> None: context = _FakeSessionContext(tools=[dangerous_compute]) state: dict[str, Any] = {} - await provider.before_run(agent=object(), session=None, context=context, state=state) + await provider.before_run(agent=object(), session=None, context=cast(Any, context), state=state) assert context.options["tools"] == [dangerous_compute] assert len(context.instructions) == 1 @@ -850,7 +851,7 @@ async def test_provider_run_tool_writes_files_with_real_sandbox(tmp_path: Path) context = _FakeSessionContext() state: dict[str, Any] = {} - await provider.before_run(agent=object(), session=None, context=context, state=state) + await provider.before_run(agent=object(), session=None, context=cast(Any, context), state=state) run_tool = context.tools[0][1][0] assert isinstance(run_tool, HyperlightExecuteCodeTool) @@ -903,7 +904,7 @@ async def test_provider_run_tool_pings_bing_with_real_sandbox() -> None: context = _FakeSessionContext() state: dict[str, Any] = {} - await provider.before_run(agent=object(), session=None, context=context, state=state) + await provider.before_run(agent=object(), session=None, context=cast(Any, context), state=state) run_tool = context.tools[0][1][0] assert isinstance(run_tool, HyperlightExecuteCodeTool) @@ -1042,7 +1043,7 @@ async def test_output_dir_cleared_between_invocations() -> None: provider = HyperlightCodeActProvider(workspace_root=Path(__file__).parent) context = _FakeSessionContext() state: dict[str, Any] = {} - await provider.before_run(agent=object(), session=None, context=context, state=state) + await provider.before_run(agent=object(), session=None, context=cast(Any, context), state=state) run_tool = context.tools[0][1][0] assert isinstance(run_tool, HyperlightExecuteCodeTool) @@ -1076,7 +1077,7 @@ async def test_run_code_does_not_block_event_loop() -> None: provider = HyperlightCodeActProvider() context = _FakeSessionContext() state: dict[str, Any] = {} - await provider.before_run(agent=object(), session=None, context=context, state=state) + await provider.before_run(agent=object(), session=None, context=cast(Any, context), state=state) run_tool = context.tools[0][1][0] assert isinstance(run_tool, HyperlightExecuteCodeTool) @@ -1177,8 +1178,9 @@ async def test_sandbox_calls_are_pinned_to_owning_worker_thread( sandbox = _ThreadAffinityFakeSandbox.instances[0] # All sandbox-touching calls must have stayed on a single owning thread, distinct from the # caller thread that asyncio.to_thread used for dispatch. - assert sandbox.thread_ids == {sandbox._owner_thread} - assert sandbox._owner_thread != threading.get_ident() + sandbox_with_thread_data = cast(Any, sandbox) + assert sandbox_with_thread_data.thread_ids == {sandbox_with_thread_data._owner_thread} + assert sandbox_with_thread_data._owner_thread != threading.get_ident() async def test_sandbox_owner_thread_persists_across_dispatch_threads( diff --git a/python/packages/lab/gaia/agent_framework_lab_gaia/gaia.py b/python/packages/lab/gaia/agent_framework_lab_gaia/gaia.py index 78f64ff5651..5c253b06777 100644 --- a/python/packages/lab/gaia/agent_framework_lab_gaia/gaia.py +++ b/python/packages/lab/gaia/agent_framework_lab_gaia/gaia.py @@ -33,7 +33,7 @@ def loads(self, obj: str | bytes | bytearray, /) -> object: ... @lru_cache(maxsize=1) def _get_orjson() -> _OrjsonModule | None: try: - import orjson as runtime_orjson # pyright: ignore[reportMissingImports] + import orjson as runtime_orjson except ImportError: return None return cast(_OrjsonModule, runtime_orjson) @@ -273,7 +273,7 @@ def _load_gaia_local(repo_dir: Path, wanted_levels: list[int] | None = None, max for p in parquet_files: try: - import pyarrow.parquet as pq # type: ignore[reportMissingImports] + import pyarrow.parquet as pq pq_any = cast(Any, pq) table: Any = pq_any.read_table(p) diff --git a/python/packages/lab/lightning/agent_framework_lab_lightning/__init__.py b/python/packages/lab/lightning/agent_framework_lab_lightning/__init__.py index 3da1121910e..688f89825c0 100644 --- a/python/packages/lab/lightning/agent_framework_lab_lightning/__init__.py +++ b/python/packages/lab/lightning/agent_framework_lab_lightning/__init__.py @@ -7,8 +7,8 @@ import importlib.metadata from agent_framework.observability import enable_instrumentation -from agentlightning.tracer import ( # type: ignore[reportMissingImports] - AgentOpsTracer, # type: ignore[reportMissingImports, import-not-found] +from agentlightning.tracer import ( + AgentOpsTracer, ) try: @@ -17,7 +17,7 @@ __version__ = "0.0.0" # Fallback for development mode -class AgentFrameworkTracer(AgentOpsTracer): # type: ignore +class AgentFrameworkTracer(AgentOpsTracer): """Tracer for Agent-framework. Tracer that enables OpenTelemetry observability for the Agent-framework, @@ -27,11 +27,11 @@ class AgentFrameworkTracer(AgentOpsTracer): # type: ignore def init(self) -> None: """Initialize the agent-framework-lab-lightning for training.""" enable_instrumentation() - super().init() # pyright: ignore[reportUnknownMemberType] + super().init() def teardown(self) -> None: """Teardown the agent-framework-lab-lightning for training.""" - super().teardown() # pyright: ignore[reportUnknownMemberType] + super().teardown() __all__: list[str] = ["AgentFrameworkTracer"] diff --git a/python/packages/lab/tau2/agent_framework_lab_tau2/_tau2_utils.py b/python/packages/lab/tau2/agent_framework_lab_tau2/_tau2_utils.py index 5b1390c3dc3..ccae03a6244 100644 --- a/python/packages/lab/tau2/agent_framework_lab_tau2/_tau2_utils.py +++ b/python/packages/lab/tau2/agent_framework_lab_tau2/_tau2_utils.py @@ -10,7 +10,7 @@ from agent_framework._types import Message from loguru import logger from pydantic import BaseModel -from tau2.data_model.message import ( # type: ignore[import-untyped] +from tau2.data_model.message import ( AssistantMessage, SystemMessage, ToolCall, @@ -20,9 +20,9 @@ from tau2.data_model.message import ( Message as Tau2Message, ) -from tau2.data_model.tasks import EnvFunctionCall, InitializationData # type: ignore[import-untyped] -from tau2.environment.environment import Environment # type: ignore[import-untyped] -from tau2.environment.tool import Tool # type: ignore[import-untyped] +from tau2.data_model.tasks import EnvFunctionCall, InitializationData +from tau2.environment.environment import Environment +from tau2.environment.tool import Tool _original_set_state = Environment.set_state diff --git a/python/packages/lab/tau2/agent_framework_lab_tau2/runner.py b/python/packages/lab/tau2/agent_framework_lab_tau2/runner.py index b78131db664..084c52598de 100644 --- a/python/packages/lab/tau2/agent_framework_lab_tau2/runner.py +++ b/python/packages/lab/tau2/agent_framework_lab_tau2/runner.py @@ -20,17 +20,17 @@ WorkflowContext, ) from loguru import logger -from tau2.data_model.simulation import SimulationRun, TerminationReason # type: ignore[import-untyped] -from tau2.data_model.tasks import Task # type: ignore[import-untyped] -from tau2.domains.airline.environment import get_environment # type: ignore[import-untyped] -from tau2.evaluator.evaluator import EvaluationType, RewardInfo, evaluate_simulation # type: ignore[import-untyped] -from tau2.user.user_simulator import ( # type: ignore[import-untyped] +from tau2.data_model.simulation import SimulationRun, TerminationReason +from tau2.data_model.tasks import Task +from tau2.domains.airline.environment import get_environment +from tau2.evaluator.evaluator import EvaluationType, RewardInfo, evaluate_simulation +from tau2.user.user_simulator import ( OUT_OF_SCOPE, STOP, TRANSFER, get_global_user_sim_guidelines, ) -from tau2.utils.utils import get_now # type: ignore[import-untyped] +from tau2.utils.utils import get_now from ._message_utils import flip_messages, log_messages from ._sliding_window import SlidingWindowHistoryProvider @@ -372,7 +372,7 @@ async def run( session_state: dict[str, Any] = self._assistant_executor._session.state # type: ignore all_messages: list[Message] = list( session_state.get(InMemoryHistoryProvider.DEFAULT_SOURCE_ID, {}).get("messages", []) - ) # type: ignore + ) full_conversation = [first_message, *all_messages] if self._final_user_message is not None: full_conversation.extend(self._final_user_message) diff --git a/python/packages/lab/tau2/tests/test_message_utils.py b/python/packages/lab/tau2/tests/test_message_utils.py index 601c7168548..4300bd70d56 100644 --- a/python/packages/lab/tau2/tests/test_message_utils.py +++ b/python/packages/lab/tau2/tests/test_message_utils.py @@ -156,7 +156,7 @@ def test_flip_messages_mixed_conversation(): def test_flip_messages_empty_list(): """Test flipping empty message list.""" - messages = [] + messages: list[Message] = [] flipped = flip_messages(messages) assert len(flipped) == 0 diff --git a/python/packages/lab/tau2/tests/test_tau2_utils.py b/python/packages/lab/tau2/tests/test_tau2_utils.py index 671ce0a6959..f6643e9db6a 100644 --- a/python/packages/lab/tau2/tests/test_tau2_utils.py +++ b/python/packages/lab/tau2/tests/test_tau2_utils.py @@ -44,7 +44,7 @@ def test_convert_tau2_tool_to_function_tool_basic(): tau2_tool = _DummyTau2Tool(name="lookup_booking", description="Lookup booking by id.") # Convert the tool - tool = convert_tau2_tool_to_function_tool(tau2_tool) + tool = convert_tau2_tool_to_function_tool(tau2_tool) # ty: ignore[invalid-argument-type] # pyrefly: ignore[bad-argument-type] # pyright: ignore[reportArgumentType] # Verify the conversion assert isinstance(tool, FunctionTool) @@ -52,6 +52,7 @@ def test_convert_tau2_tool_to_function_tool_basic(): assert tool.description == tau2_tool._get_description() assert tool.input_model == tau2_tool.params + assert tool.func is not None result = tool.func(param="ABC123") assert isinstance(result, _DummyToolResult) assert result.output == "ABC123" @@ -67,7 +68,7 @@ def test_convert_tau2_tool_to_function_tool_multiple_tools(): ] # Convert multiple tools - function_tools = [convert_tau2_tool_to_function_tool(tool) for tool in tools] + function_tools = [convert_tau2_tool_to_function_tool(tool) for tool in tools] # ty: ignore[invalid-argument-type] # pyrefly: ignore[bad-argument-type] # pyright: ignore[reportArgumentType] # Verify all conversions for tool, tau2_tool in zip(function_tools, tools, strict=False): @@ -158,9 +159,7 @@ def test_convert_agent_framework_messages_to_tau2_messages_with_function_result( def test_convert_agent_framework_messages_to_tau2_messages_with_error(): """Test converting function result with error.""" - function_result = Content.from_function_result( - call_id="call_456", result="Error occurred", exception=Exception("Test error") - ) + function_result = Content.from_function_result(call_id="call_456", result="Error occurred", exception="Test error") messages = [Message(role="tool", contents=[function_result])] diff --git a/python/packages/mem0/agent_framework_mem0/_context_provider.py b/python/packages/mem0/agent_framework_mem0/_context_provider.py index 071c054824b..08c9d9bc008 100644 --- a/python/packages/mem0/agent_framework_mem0/_context_provider.py +++ b/python/packages/mem0/agent_framework_mem0/_context_provider.py @@ -162,11 +162,7 @@ async def before_run( elif isinstance(search_response, dict): results_field = search_response.get("results") if isinstance(results_field, list): - current_memories = [ - item - for item in results_field - if isinstance(item, dict) # pyright: ignore[reportUnknownVariableType] - ] + current_memories = [item for item in results_field if isinstance(item, dict)] else: logger.warning( "Unexpected Mem0 search response format: %s", diff --git a/python/packages/mem0/tests/test_mem0_context_provider.py b/python/packages/mem0/tests/test_mem0_context_provider.py index a047af1638f..2fa1739c231 100644 --- a/python/packages/mem0/tests/test_mem0_context_provider.py +++ b/python/packages/mem0/tests/test_mem0_context_provider.py @@ -3,6 +3,7 @@ from __future__ import annotations +from typing import Any, cast from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -97,7 +98,10 @@ async def test_memories_added_to_context(self, mock_mem0_client: AsyncMock) -> N ctx = SessionContext(input_messages=[Message(role="user", contents=["Hello"])], session_id="s1") await provider.before_run( - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), + session=session, + context=ctx, + state=session.state.setdefault(provider.source_id, {}), ) # type: ignore[arg-type] mock_mem0_client.search.assert_awaited_once() @@ -115,7 +119,10 @@ async def test_empty_input_skips_search(self, mock_mem0_client: AsyncMock) -> No ctx = SessionContext(input_messages=[Message(role="user", contents=[""])], session_id="s1") await provider.before_run( - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), + session=session, + context=ctx, + state=session.state.setdefault(provider.source_id, {}), ) # type: ignore[arg-type] mock_mem0_client.search.assert_not_awaited() @@ -129,7 +136,10 @@ async def test_empty_search_results_no_messages(self, mock_mem0_client: AsyncMoc ctx = SessionContext(input_messages=[Message(role="user", contents=["test"])], session_id="s1") await provider.before_run( - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), + session=session, + context=ctx, + state=session.state.setdefault(provider.source_id, {}), ) # type: ignore[arg-type] assert "mem0" not in ctx.context_messages @@ -141,7 +151,12 @@ async def test_validates_filters_before_search(self, mock_mem0_client: AsyncMock ctx = SessionContext(input_messages=[Message(role="user", contents=["test"])], session_id="s1") with pytest.raises(ValueError, match="At least one of the filters"): - await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] + await provider.before_run( + agent=cast(Any, None), + session=session, + context=ctx, + state=session.state, + ) # type: ignore[arg-type] async def test_v1_1_response_format(self, mock_mem0_client: AsyncMock) -> None: """Search response in v1.1 dict format with 'results' key.""" @@ -151,7 +166,10 @@ async def test_v1_1_response_format(self, mock_mem0_client: AsyncMock) -> None: ctx = SessionContext(input_messages=[Message(role="user", contents=["test"])], session_id="s1") await provider.before_run( - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), + session=session, + context=ctx, + state=session.state.setdefault(provider.source_id, {}), ) # type: ignore[arg-type] added = ctx.context_messages["mem0"] @@ -171,7 +189,10 @@ async def test_search_query_combines_input_messages(self, mock_mem0_client: Asyn ) await provider.before_run( - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), + session=session, + context=ctx, + state=session.state.setdefault(provider.source_id, {}), ) # type: ignore[arg-type] call_kwargs = mock_mem0_client.search.call_args.kwargs @@ -185,7 +206,10 @@ async def test_oss_client_passes_direct_kwargs(self, mock_oss_mem0_client: Async ctx = SessionContext(input_messages=[Message(role="user", contents=["Hello"])], session_id="s1") await provider.before_run( - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), + session=session, + context=ctx, + state=session.state.setdefault(provider.source_id, {}), ) # type: ignore[arg-type] call_kwargs = mock_oss_mem0_client.search.call_args.kwargs @@ -198,12 +222,7 @@ async def test_oss_client_all_scoping_params_except_app_id(self, mock_oss_mem0_c """OSS client with all scoping parameters passes them as isolated concurrent kwargs.""" mock_oss_mem0_client.search.return_value = [] - provider = Mem0ContextProvider( - source_id="mem0", - mem0_client=mock_oss_mem0_client, - user_id="u1", - agent_id="a1" - ) + provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_oss_mem0_client, user_id="u1", agent_id="a1") mock_context = MagicMock(spec=SessionContext) mock_msg = MagicMock() @@ -262,7 +281,10 @@ async def test_stores_input_and_response(self, mock_mem0_client: AsyncMock) -> N ctx._response = AgentResponse(messages=[Message(role="assistant", contents=["answer"])]) await provider.after_run( - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), + session=session, + context=ctx, + state=session.state.setdefault(provider.source_id, {}), ) # type: ignore[arg-type] mock_mem0_client.add.assert_awaited_once() @@ -288,7 +310,10 @@ async def test_only_stores_user_assistant_system(self, mock_mem0_client: AsyncMo ctx._response = AgentResponse(messages=[Message(role="assistant", contents=["reply"])]) await provider.after_run( - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), + session=session, + context=ctx, + state=session.state.setdefault(provider.source_id, {}), ) # type: ignore[arg-type] call_kwargs = mock_mem0_client.add.call_args.kwargs @@ -310,7 +335,10 @@ async def test_skips_empty_messages(self, mock_mem0_client: AsyncMock) -> None: ctx._response = AgentResponse(messages=[]) await provider.after_run( - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), + session=session, + context=ctx, + state=session.state.setdefault(provider.source_id, {}), ) # type: ignore[arg-type] mock_mem0_client.add.assert_not_awaited() @@ -323,7 +351,10 @@ async def test_no_run_id_in_storage(self, mock_mem0_client: AsyncMock) -> None: ctx._response = AgentResponse(messages=[Message(role="assistant", contents=["hey"])]) await provider.after_run( - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), + session=session, + context=ctx, + state=session.state.setdefault(provider.source_id, {}), ) # type: ignore[arg-type] assert "run_id" not in mock_mem0_client.add.call_args.kwargs @@ -336,7 +367,12 @@ async def test_validates_filters(self, mock_mem0_client: AsyncMock) -> None: ctx._response = AgentResponse(messages=[Message(role="assistant", contents=["hey"])]) with pytest.raises(ValueError, match="At least one of the filters"): - await provider.after_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type] + await provider.after_run( + agent=cast(Any, None), + session=session, + context=ctx, + state=session.state, + ) # type: ignore[arg-type] async def test_stores_with_application_id_filters(self, mock_mem0_client: AsyncMock) -> None: """application_id is passed in filters.""" @@ -348,7 +384,10 @@ async def test_stores_with_application_id_filters(self, mock_mem0_client: AsyncM ctx._response = AgentResponse(messages=[]) await provider.after_run( - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), + session=session, + context=ctx, + state=session.state.setdefault(provider.source_id, {}), ) # type: ignore[arg-type] assert mock_mem0_client.add.call_args.kwargs["filters"] == {"app_id": "app1"} diff --git a/python/packages/mistral/agent_framework_mistral/_embedding_client.py b/python/packages/mistral/agent_framework_mistral/_embedding_client.py index 4b07d7d120c..5cb59fe5419 100644 --- a/python/packages/mistral/agent_framework_mistral/_embedding_client.py +++ b/python/packages/mistral/agent_framework_mistral/_embedding_client.py @@ -20,9 +20,9 @@ from mistralai.client import Mistral if sys.version_info >= (3, 13): - from typing import TypeVar # type: ignore # pragma: no cover + from typing import TypeVar # pragma: no cover else: - from typing_extensions import TypeVar # type: ignore # pragma: no cover + from typing_extensions import TypeVar # pragma: no cover logger = logging.getLogger("agent_framework.mistral") diff --git a/python/packages/mistral/tests/mistral/test_mistral_embedding_client.py b/python/packages/mistral/tests/mistral/test_mistral_embedding_client.py index cecd03b3e28..e593582feb4 100644 --- a/python/packages/mistral/tests/mistral/test_mistral_embedding_client.py +++ b/python/packages/mistral/tests/mistral/test_mistral_embedding_client.py @@ -184,7 +184,7 @@ async def test_mistral_embedding_get_embeddings_no_model_raises() -> None: mock_cls.return_value = mock_client client = MistralEmbeddingClient(model="mistral-embed", api_key="test-key") - client.model = None # type: ignore[assignment] + client.model = None # type: ignore[assignment] # ty: ignore[invalid-assignment] with pytest.raises(ValueError, match="model is required"): await client.get_embeddings(["hello"]) diff --git a/python/packages/monty/agent_framework_monty/_monty_bridge.py b/python/packages/monty/agent_framework_monty/_monty_bridge.py index 1d9cd46c400..798aad4ec46 100644 --- a/python/packages/monty/agent_framework_monty/_monty_bridge.py +++ b/python/packages/monty/agent_framework_monty/_monty_bridge.py @@ -186,7 +186,7 @@ def load_monty() -> Any: ``FunctionSnapshot``, ``FutureSnapshot``, ``NameLookupSnapshot`` from it. """ try: - import pydantic_monty # type: ignore[import-not-found] + import pydantic_monty except ImportError as exc: raise RuntimeError( "The `pydantic-monty` package is required to execute Monty CodeAct code. " diff --git a/python/packages/monty/tests/monty/test_monty_codeact.py b/python/packages/monty/tests/monty/test_monty_codeact.py index 43c8e3acacc..3a0428100c5 100644 --- a/python/packages/monty/tests/monty/test_monty_codeact.py +++ b/python/packages/monty/tests/monty/test_monty_codeact.py @@ -143,11 +143,11 @@ def start(self, *, print_callback: Any) -> Any: def fake_monty_module(monkeypatch: pytest.MonkeyPatch) -> Iterator[None]: """Install a fake ``pydantic_monty`` module for the duration of each test.""" fake = types.ModuleType("pydantic_monty") - fake.Monty = _FakeMonty # type: ignore[attr-defined] - fake.MontyComplete = _FakeMontyComplete # type: ignore[attr-defined] - fake.FunctionSnapshot = _FakeFunctionSnapshot # type: ignore[attr-defined] - fake.FutureSnapshot = _FakeFutureSnapshot # type: ignore[attr-defined] - fake.NameLookupSnapshot = _FakeNameLookupSnapshot # type: ignore[attr-defined] + fake.Monty = _FakeMonty # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + fake.MontyComplete = _FakeMontyComplete # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + fake.FunctionSnapshot = _FakeFunctionSnapshot # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + fake.FutureSnapshot = _FakeFutureSnapshot # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] + fake.NameLookupSnapshot = _FakeNameLookupSnapshot # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] monkeypatch.setitem(sys.modules, "pydantic_monty", fake) _current_script[0] = None diff --git a/python/packages/monty/tests/monty/test_monty_codeact_integration.py b/python/packages/monty/tests/monty/test_monty_codeact_integration.py index 2728936c9de..2eeffb02a67 100644 --- a/python/packages/monty/tests/monty/test_monty_codeact_integration.py +++ b/python/packages/monty/tests/monty/test_monty_codeact_integration.py @@ -524,7 +524,7 @@ async def test_approval_required_tool_gates_execute_code_end_to_end() -> None: async def test_agent_runs_monty_codeact_end_to_end() -> None: """A fake chat client emits one execute_code tool call; Monty runs it end-to-end.""" - from collections.abc import Awaitable, Mapping, MutableSequence + from collections.abc import Awaitable, Mapping, Sequence from agent_framework import ( BaseChatClient, @@ -543,7 +543,7 @@ def __init__(self) -> None: def _inner_get_response( self, *, - messages: MutableSequence[Message], + messages: Sequence[Message], stream: bool, options: Mapping[str, Any], **kwargs: Any, diff --git a/python/packages/ollama/agent_framework_ollama/_chat_client.py b/python/packages/ollama/agent_framework_ollama/_chat_client.py index ee517fe53fe..931011d74c1 100644 --- a/python/packages/ollama/agent_framework_ollama/_chat_client.py +++ b/python/packages/ollama/agent_framework_ollama/_chat_client.py @@ -44,19 +44,19 @@ from pydantic import BaseModel if sys.version_info >= (3, 13): - from typing import TypeVar # type: ignore # pragma: no cover + from typing import TypeVar # pragma: no cover else: - from typing_extensions import TypeVar # type: ignore # pragma: no cover + from typing_extensions import TypeVar # pragma: no cover if sys.version_info >= (3, 12): - from typing import override # type: ignore # pragma: no cover + from typing import override # pragma: no cover else: - from typing_extensions import override # type: ignore # pragma: no cover + from typing_extensions import override # pragma: no cover if sys.version_info >= (3, 11): - from typing import TypedDict # type: ignore # pragma: no cover + from typing import TypedDict # pragma: no cover else: - from typing_extensions import TypedDict # type: ignore # pragma: no cover + from typing_extensions import TypedDict # pragma: no cover __all__ = ["OllamaChatClient", "OllamaChatOptions"] diff --git a/python/packages/ollama/agent_framework_ollama/_embedding_client.py b/python/packages/ollama/agent_framework_ollama/_embedding_client.py index e921e6172a2..cb8a004efee 100644 --- a/python/packages/ollama/agent_framework_ollama/_embedding_client.py +++ b/python/packages/ollama/agent_framework_ollama/_embedding_client.py @@ -19,9 +19,9 @@ from ollama import AsyncClient if sys.version_info >= (3, 13): - from typing import TypeVar # type: ignore # pragma: no cover + from typing import TypeVar # pragma: no cover else: - from typing_extensions import TypeVar # type: ignore # pragma: no cover + from typing_extensions import TypeVar # pragma: no cover logger = logging.getLogger("agent_framework.ollama") @@ -120,7 +120,7 @@ async def get_embeddings( self, values: Sequence[str], *, - options: OllamaEmbeddingOptionsT | None = None, # type: ignore + options: OllamaEmbeddingOptionsT | None = None, ) -> GeneratedEmbeddings[list[float], OllamaEmbeddingOptionsT]: """Call the Ollama embed API. @@ -156,7 +156,7 @@ async def get_embeddings( Embedding( vector=list(emb), dimensions=len(emb), - model=response.get("model") or model, # type: ignore[assignment] + model=response.get("model") or model, ) for emb in response.get("embeddings", []) ] diff --git a/python/packages/ollama/tests/test_ollama_chat_client.py b/python/packages/ollama/tests/test_ollama_chat_client.py index 98ec78475d1..f3061f60369 100644 --- a/python/packages/ollama/tests/test_ollama_chat_client.py +++ b/python/packages/ollama/tests/test_ollama_chat_client.py @@ -2,7 +2,7 @@ import os from collections.abc import AsyncIterable -from typing import Any +from typing import Any, cast from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -123,7 +123,7 @@ def mock_streaming_chat_completion_tool_call() -> AsyncStream[OllamaChatResponse message=OllamaMessage( content="", role="assistant", - tool_calls=[{"function": {"name": "hello_world", "arguments": {"arg1": "value1"}}}], + tool_calls=cast(Any, [{"function": {"name": "hello_world", "arguments": {"arg1": "value1"}}}]), ), model="test", ) @@ -138,7 +138,7 @@ def mock_chat_completion_tool_call() -> OllamaChatResponse: message=OllamaMessage( content="", role="assistant", - tool_calls=[{"function": {"name": "hello_world", "arguments": {"arg1": "value1"}}}], + tool_calls=cast(Any, [{"function": {"name": "hello_world", "arguments": {"arg1": "value1"}}}]), ), model="test", created_at="2024-01-01T00:00:00Z", @@ -294,7 +294,7 @@ async def test_cmc_reasoning( ollama_client = OllamaChatClient() result = await ollama_client.get_response(messages=chat_history) - reasoning = "".join(c.text for c in result.messages.pop().contents if c.type == "text_reasoning") + reasoning = "".join(cast("str", c.text) for c in result.messages.pop().contents if c.type == "text_reasoning") assert reasoning == "test" @@ -349,7 +349,7 @@ async def test_cmc_streaming_reasoning( result = ollama_client.get_response(messages=chat_history, stream=True) async for chunk in result: - reasoning = "".join(c.text for c in chunk.contents if c.type == "text_reasoning") + reasoning = "".join(cast("str", c.text) for c in chunk.contents if c.type == "text_reasoning") assert reasoning == "test" @@ -473,7 +473,7 @@ async def test_cmc_with_invalid_data_content_media_type( ) ollama_client = OllamaChatClient() - ollama_client.client.chat = AsyncMock(return_value=mock_streaming_chat_completion_response) + ollama_client.client.chat = AsyncMock(return_value=mock_streaming_chat_completion_response) # type: ignore[method-assign] # ty: ignore[invalid-assignment] await ollama_client.get_response(messages=chat_history) diff --git a/python/packages/openai/agent_framework_openai/_chat_client.py b/python/packages/openai/agent_framework_openai/_chat_client.py index 2d5cda9ee5d..2797d7f13cb 100644 --- a/python/packages/openai/agent_framework_openai/_chat_client.py +++ b/python/packages/openai/agent_framework_openai/_chat_client.py @@ -92,17 +92,17 @@ ) if sys.version_info >= (3, 13): - from typing import TypeVar # type: ignore # pragma: no cover + from typing import TypeVar # pragma: no cover else: - from typing_extensions import TypeVar # type: ignore # pragma: no cover + from typing_extensions import TypeVar # pragma: no cover if sys.version_info >= (3, 12): - from typing import override # type: ignore # pragma: no cover + from typing import override # pragma: no cover else: - from typing_extensions import override # type: ignore[import] # pragma: no cover + from typing_extensions import override # pragma: no cover if sys.version_info >= (3, 11): - from typing import TypedDict # type: ignore # pragma: no cover + from typing import TypedDict # pragma: no cover else: - from typing_extensions import TypedDict # type: ignore # pragma: no cover + from typing_extensions import TypedDict # pragma: no cover if TYPE_CHECKING: from azure.core.credentials import TokenCredential @@ -336,7 +336,7 @@ def _annotations_to_output_text(annotations: Sequence[Annotation] | None) -> lis # region ResponsesClient -class RawOpenAIChatClient( # type: ignore[misc] +class RawOpenAIChatClient( BaseChatClient[OpenAIChatOptionsT], Generic[OpenAIChatOptionsT], ): @@ -356,7 +356,7 @@ class RawOpenAIChatClient( # type: ignore[misc] """ INJECTABLE: ClassVar[set[str]] = {"client"} - STORES_BY_DEFAULT: ClassVar[bool] = True # type: ignore[reportIncompatibleVariableOverride, misc] + STORES_BY_DEFAULT: ClassVar[bool] = True SUPPORTS_RICH_FUNCTION_OUTPUT: ClassVar[bool] = True # Azure OpenAI Responses API may include this header in responses naming the actual model that @@ -616,7 +616,7 @@ def _inner_get_response( stream: bool = False, **kwargs: Any, ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]: - continuation_token: OpenAIContinuationToken | None = options.get("continuation_token") # type: ignore[assignment] + continuation_token: OpenAIContinuationToken | None = options.get("continuation_token") if stream: function_call_ids: dict[int, tuple[str, str]] = {} @@ -735,9 +735,9 @@ async def _get_response() -> ChatResponse: client, run_options, validated_options = await self._prepare_request(messages, options) try: if "text_format" in run_options: - raw_response = await client.responses.with_raw_response.parse(stream=False, **run_options) # type: ignore + raw_response = await client.responses.with_raw_response.parse(stream=False, **run_options) else: - raw_response = await client.responses.with_raw_response.create(stream=False, **run_options) # type: ignore + raw_response = await client.responses.with_raw_response.create(stream=False, **run_options) response = raw_response.parse() except Exception as ex: self._handle_request_error(ex) @@ -898,7 +898,7 @@ def _prepare_tools_for_openai( response_tools.append( FunctionShellTool( type="shell", - environment=shell_env, # type: ignore[typeddict-item] + environment=shell_env, ) ) continue @@ -1080,7 +1080,7 @@ def get_image_generation_tool( if output_format: tool["output_format"] = output_format if model: - tool["model"] = model # type: ignore + tool["model"] = model if quality: tool["quality"] = quality if partial_images is not None: @@ -1746,7 +1746,7 @@ def _prepare_content_for_openai( case "function_approval_request": return { "type": "mcp_approval_request", - "id": content.id, # type: ignore[union-attr] + "id": content.id, "arguments": content.function_call.arguments, # type: ignore[union-attr] "name": content.function_call.name, # type: ignore[union-attr] "server_label": content.function_call.additional_properties.get("server_label") # type: ignore[union-attr] @@ -1882,7 +1882,7 @@ def _stringify_mcp_output(output: Any) -> str: if isinstance(output, Sequence) and not isinstance(output, (str, bytes, bytearray)): # cast is for pyright (reportUnknownVariableType); mypy considers # it redundant after the isinstance narrowing. - entries = cast(Sequence[Any], output) # type: ignore[redundant-cast] + entries = cast(Sequence[Any], output) parts: list[str] = [] for entry in entries: if isinstance(entry, str): @@ -2027,7 +2027,7 @@ def _parse_response_from_openai( case "output_text": text_content = Content.from_text( text=message_content.text, - raw_representation=message_content, # type: ignore[reportUnknownArgumentType] + raw_representation=message_content, ) metadata.update(self._get_metadata_from_response(message_content)) if message_content.annotations: @@ -2127,7 +2127,7 @@ def _parse_response_from_openai( Content.from_text_reasoning( id=item.id, text=summary.text, - raw_representation=summary, # type: ignore[arg-type] + raw_representation=summary, ) ) added_reasoning = True @@ -2980,9 +2980,9 @@ def _parse_usage_from_openai(self, usage: ResponseUsage) -> UsageDetails | None: total_token_count=usage.total_tokens, ) if usage.input_tokens_details and usage.input_tokens_details.cached_tokens: - details["openai.cached_input_tokens"] = usage.input_tokens_details.cached_tokens # type: ignore[typeddict-unknown-key] + details["openai.cached_input_tokens"] = usage.input_tokens_details.cached_tokens if usage.output_tokens_details and usage.output_tokens_details.reasoning_tokens: - details["openai.reasoning_tokens"] = usage.output_tokens_details.reasoning_tokens # type: ignore[typeddict-unknown-key] + details["openai.reasoning_tokens"] = usage.output_tokens_details.reasoning_tokens return details def _get_metadata_from_response(self, output: Any) -> dict[str, Any]: @@ -2994,7 +2994,7 @@ def _get_metadata_from_response(self, output: Any) -> dict[str, Any]: return {} -class OpenAIChatClient( # type: ignore[misc] +class OpenAIChatClient( FunctionInvocationLayer[OpenAIChatOptionsT], ChatMiddlewareLayer[OpenAIChatOptionsT], ChatTelemetryLayer[OpenAIChatOptionsT], @@ -3003,7 +3003,7 @@ class OpenAIChatClient( # type: ignore[misc] ): """OpenAI Responses client class with middleware, telemetry, and function invocation support.""" - OTEL_PROVIDER_NAME: ClassVar[str] = "openai" # type: ignore[reportIncompatibleVariableOverride, misc] + OTEL_PROVIDER_NAME: ClassVar[str] = "openai" @overload def __init__( diff --git a/python/packages/openai/agent_framework_openai/_chat_completion_client.py b/python/packages/openai/agent_framework_openai/_chat_completion_client.py index 0fd14aa2ef5..e66ed47ea65 100644 --- a/python/packages/openai/agent_framework_openai/_chat_completion_client.py +++ b/python/packages/openai/agent_framework_openai/_chat_completion_client.py @@ -65,17 +65,17 @@ ) if sys.version_info >= (3, 13): - from typing import TypeVar # type: ignore # pragma: no cover + from typing import TypeVar # pragma: no cover else: - from typing_extensions import TypeVar # type: ignore # pragma: no cover + from typing_extensions import TypeVar # pragma: no cover if sys.version_info >= (3, 12): - from typing import override # type: ignore # pragma: no cover + from typing import override # pragma: no cover else: - from typing_extensions import override # type: ignore # pragma: no cover + from typing_extensions import override # pragma: no cover if sys.version_info >= (3, 11): - from typing import TypedDict # type: ignore # pragma: no cover + from typing import TypedDict # pragma: no cover else: - from typing_extensions import TypedDict # type: ignore # pragma: no cover + from typing_extensions import TypedDict # pragma: no cover if TYPE_CHECKING: from azure.core.credentials import TokenCredential @@ -141,7 +141,7 @@ class OpenAIChatCompletionOptions(ChatOptions[ResponseModelT], Generic[ResponseM """ # OpenAI-specific generation parameters (supported by all models) - logit_bias: dict[str | int, float] # type: ignore[misc] + logit_bias: dict[str | int, float] logprobs: bool top_logprobs: int prediction: Prediction @@ -164,7 +164,7 @@ class OpenAIChatCompletionOptions(ChatOptions[ResponseModelT], Generic[ResponseM # region Base Client -class RawOpenAIChatCompletionClient( # type: ignore[misc] +class RawOpenAIChatCompletionClient( BaseChatClient[OpenAIChatCompletionOptionsT], Generic[OpenAIChatCompletionOptionsT], ): @@ -488,9 +488,9 @@ def get_response( """Get a response from the raw OpenAI chat client.""" super_get_response = cast( "Callable[..., Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]]", - super().get_response, # type: ignore[misc] + super().get_response, ) - return super_get_response( # type: ignore[no-any-return] + return super_get_response( messages=messages, stream=stream, options=options, @@ -762,18 +762,18 @@ def _parse_usage_from_openai(self, usage: CompletionUsage) -> UsageDetails: ) if usage.completion_tokens_details: if tokens := usage.completion_tokens_details.accepted_prediction_tokens: - details["completion/accepted_prediction_tokens"] = tokens # type: ignore[typeddict-unknown-key] + details["completion/accepted_prediction_tokens"] = tokens if tokens := usage.completion_tokens_details.audio_tokens: - details["completion/audio_tokens"] = tokens # type: ignore[typeddict-unknown-key] + details["completion/audio_tokens"] = tokens if tokens := usage.completion_tokens_details.reasoning_tokens: - details["completion/reasoning_tokens"] = tokens # type: ignore[typeddict-unknown-key] + details["completion/reasoning_tokens"] = tokens if tokens := usage.completion_tokens_details.rejected_prediction_tokens: - details["completion/rejected_prediction_tokens"] = tokens # type: ignore[typeddict-unknown-key] + details["completion/rejected_prediction_tokens"] = tokens if usage.prompt_tokens_details: if tokens := usage.prompt_tokens_details.audio_tokens: - details["prompt/audio_tokens"] = tokens # type: ignore[typeddict-unknown-key] + details["prompt/audio_tokens"] = tokens if tokens := usage.prompt_tokens_details.cached_tokens: - details["prompt/cached_tokens"] = tokens # type: ignore[typeddict-unknown-key] + details["prompt/cached_tokens"] = tokens return details def _parse_text_from_openai(self, choice: Choice | ChunkChoice) -> Content | None: @@ -887,7 +887,7 @@ def _prepare_message_for_openai(self, message: Message) -> list[dict[str, Any]]: # If the last message already has tool calls, append to it all_messages[-1]["tool_calls"].append(self._prepare_content_for_openai(content)) else: - args["tool_calls"] = [self._prepare_content_for_openai(content)] # type: ignore + args["tool_calls"] = [self._prepare_content_for_openai(content)] case "function_result": args["tool_call_id"] = content.call_id if content.items: @@ -1032,7 +1032,7 @@ def service_url(self) -> str: # region Public client -class OpenAIChatCompletionClient( # type: ignore[misc] +class OpenAIChatCompletionClient( FunctionInvocationLayer[OpenAIChatCompletionOptionsT], ChatMiddlewareLayer[OpenAIChatCompletionOptionsT], ChatTelemetryLayer[OpenAIChatCompletionOptionsT], @@ -1041,7 +1041,7 @@ class OpenAIChatCompletionClient( # type: ignore[misc] ): """OpenAI Chat completion class with middleware, telemetry, and function invocation support.""" - OTEL_PROVIDER_NAME: ClassVar[str] = "openai" # type: ignore[reportIncompatibleVariableOverride, misc] + OTEL_PROVIDER_NAME: ClassVar[str] = "openai" @overload def __init__( @@ -1297,12 +1297,12 @@ def get_response( """Get a response from the OpenAI chat client with all standard layers enabled.""" super_get_response = cast( "Callable[..., Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]]", - super().get_response, # type: ignore[misc] + super().get_response, ) effective_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {} if middleware is not None: effective_client_kwargs["middleware"] = middleware - return super_get_response( # type: ignore[no-any-return] + return super_get_response( messages=messages, stream=stream, options=options, diff --git a/python/packages/openai/agent_framework_openai/_embedding_client.py b/python/packages/openai/agent_framework_openai/_embedding_client.py index b11304dd27a..b847eb0fd15 100644 --- a/python/packages/openai/agent_framework_openai/_embedding_client.py +++ b/python/packages/openai/agent_framework_openai/_embedding_client.py @@ -18,9 +18,9 @@ from ._shared import AzureTokenProvider, load_openai_service_settings if sys.version_info >= (3, 13): - from typing import TypeVar # type: ignore # pragma: no cover + from typing import TypeVar # pragma: no cover else: - from typing_extensions import TypeVar # type: ignore # pragma: no cover + from typing_extensions import TypeVar # pragma: no cover if TYPE_CHECKING: from azure.core.credentials import TokenCredential @@ -268,7 +268,7 @@ async def get_embeddings( ValueError: If model is not provided or values is empty. """ if not values: - return GeneratedEmbeddings([], options=options) # type: ignore + return GeneratedEmbeddings([], options=options) opts: dict[str, Any] = options or {} # type: ignore model = opts.get("model") or self.model @@ -283,7 +283,7 @@ async def get_embeddings( if user := opts.get("user"): kwargs["user"] = user - response = await self.client.embeddings.create(**kwargs) # type: ignore[union-attr] + response = await self.client.embeddings.create(**kwargs) encoding = kwargs.get("encoding_format", "float") embeddings: list[Embedding[list[float]]] = [] @@ -294,7 +294,7 @@ async def get_embeddings( raw = base64.b64decode(item.embedding) vector = list(struct.unpack(f"<{len(raw) // 4}f", raw)) else: - vector = item.embedding # type: ignore[assignment] + vector = item.embedding embeddings.append( Embedding( vector=vector, @@ -320,7 +320,7 @@ class OpenAIEmbeddingClient( ): """OpenAI embedding client with telemetry support.""" - OTEL_PROVIDER_NAME: ClassVar[str] = "openai" # type: ignore[reportIncompatibleVariableOverride, misc] + OTEL_PROVIDER_NAME: ClassVar[str] = "openai" @overload def __init__( diff --git a/python/packages/openai/agent_framework_openai/_shared.py b/python/packages/openai/agent_framework_openai/_shared.py index 894ee3b6122..1cf0d487f67 100644 --- a/python/packages/openai/agent_framework_openai/_shared.py +++ b/python/packages/openai/agent_framework_openai/_shared.py @@ -19,9 +19,9 @@ from openai.types.responses.response_stream_event import ResponseStreamEvent if sys.version_info >= (3, 11): - from typing import TypedDict # type: ignore # pragma: no cover + from typing import TypedDict # pragma: no cover else: - from typing_extensions import TypedDict # type: ignore # pragma: no cover + from typing_extensions import TypedDict # pragma: no cover if TYPE_CHECKING: from azure.core.credentials import TokenCredential @@ -349,7 +349,7 @@ def _resolve_azure_credential_to_token_provider( if isinstance(credential, AsyncTokenCredential): return get_async_bearer_token_provider(credential, AZURE_OPENAI_TOKEN_SCOPE) if isinstance(credential, TokenCredential): - return get_bearer_token_provider(credential, AZURE_OPENAI_TOKEN_SCOPE) # type: ignore[arg-type] + return get_bearer_token_provider(credential, AZURE_OPENAI_TOKEN_SCOPE) raise ValueError( "The 'credential' parameter must be an Azure TokenCredential, AsyncTokenCredential, or a " "callable token provider." diff --git a/python/packages/openai/tests/openai/conftest.py b/python/packages/openai/tests/openai/conftest.py index 34c81952e37..dc95fa9e5b5 100644 --- a/python/packages/openai/tests/openai/conftest.py +++ b/python/packages/openai/tests/openai/conftest.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. from collections.abc import Generator -from typing import Any +from typing import Any, cast from unittest.mock import patch from opentelemetry.sdk.trace.export import SimpleSpanProcessor, SpanExporter @@ -180,7 +180,7 @@ def span_exporter(monkeypatch, enable_instrumentation: bool, enable_sensitive_da if enable_instrumentation or enable_sensitive_data: from opentelemetry.sdk.trace import TracerProvider - tracer_provider = TracerProvider(resource=observability_settings._resource) + tracer_provider = TracerProvider(resource=cast(Any, observability_settings)._resource) trace.set_tracer_provider(tracer_provider) monkeypatch.setattr(observability, "OBSERVABILITY_SETTINGS", observability_settings, raising=False) # type: ignore @@ -191,11 +191,11 @@ def span_exporter(monkeypatch, enable_instrumentation: bool, enable_sensitive_da ): exporter = InMemorySpanExporter() if enable_instrumentation or enable_sensitive_data: - tracer_provider = trace.get_tracer_provider() - if not hasattr(tracer_provider, "add_span_processor"): + current_tracer_provider = trace.get_tracer_provider() + if not hasattr(current_tracer_provider, "add_span_processor"): raise RuntimeError("Tracer provider does not support adding span processors.") - tracer_provider.add_span_processor(SimpleSpanProcessor(exporter)) # type: ignore + current_tracer_provider.add_span_processor(SimpleSpanProcessor(exporter)) # type: ignore yield exporter exporter.clear() diff --git a/python/packages/openai/tests/openai/test_openai_chat_client.py b/python/packages/openai/tests/openai/test_openai_chat_client.py index 9bc598d3cbe..83378f7e9a2 100644 --- a/python/packages/openai/tests/openai/test_openai_chat_client.py +++ b/python/packages/openai/tests/openai/test_openai_chat_client.py @@ -4,9 +4,10 @@ import inspect import json import os +from collections.abc import Iterator, Sequence from datetime import datetime, timezone from pathlib import Path -from typing import Annotated, Any +from typing import Annotated, Any, cast from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -18,6 +19,7 @@ Content, FunctionTool, Message, + ResponseStream, SupportsChatGetResponse, SupportsCodeInterpreterTool, SupportsFileSearchTool, @@ -72,9 +74,9 @@ class OutputStruct(BaseModel): class _FakeAsyncEventStream: - def __init__(self, events: list[object], headers: dict[str, str] | None = None) -> None: + def __init__(self, events: Sequence[object], headers: dict[str, str] | None = None) -> None: self._events = events - self._iterator = iter(()) + self._iterator: Iterator[object] = iter(()) self._headers = headers or {} def __aiter__(self) -> "_FakeAsyncEventStream": @@ -111,6 +113,16 @@ async def __aexit__( return None +def _as_chat_response_stream( + stream: Any, +) -> ResponseStream[ChatResponseUpdate, ChatResponse[None]]: + return cast("ResponseStream[ChatResponseUpdate, ChatResponse[None]]", stream) + + +def _response_id_from_token(token: Any) -> str: + return _response_id_from_token(token) + + def _as_raw(mock_response: MagicMock, *, headers: dict[str, str] | None = None) -> MagicMock: """Make ``mock_response`` look like an OpenAI ``with_raw_response`` wrapper. @@ -215,18 +227,18 @@ def test_raw_openai_chat_client_accepts_preconfigured_client_with_timeout() -> N def test_openai_chat_client_supports_all_tool_protocols() -> None: - assert isinstance(OpenAIChatClient, SupportsCodeInterpreterTool) - assert isinstance(OpenAIChatClient, SupportsWebSearchTool) - assert isinstance(OpenAIChatClient, SupportsImageGenerationTool) - assert isinstance(OpenAIChatClient, SupportsMCPTool) - assert isinstance(OpenAIChatClient, SupportsFileSearchTool) + assert isinstance(OpenAIChatClient, SupportsCodeInterpreterTool) # pyrefly: ignore[unsafe-overlap] + assert isinstance(OpenAIChatClient, SupportsWebSearchTool) # pyrefly: ignore[unsafe-overlap] + assert isinstance(OpenAIChatClient, SupportsImageGenerationTool) # pyrefly: ignore[unsafe-overlap] + assert isinstance(OpenAIChatClient, SupportsMCPTool) # pyrefly: ignore[unsafe-overlap] + assert isinstance(OpenAIChatClient, SupportsFileSearchTool) # pyrefly: ignore[unsafe-overlap] def test_protocol_isinstance_with_openai_chat_client_instance() -> None: client = object.__new__(OpenAIChatClient) - assert isinstance(client, SupportsCodeInterpreterTool) - assert isinstance(client, SupportsWebSearchTool) + assert isinstance(client, SupportsCodeInterpreterTool) # pyrefly: ignore[unsafe-overlap] + assert isinstance(client, SupportsWebSearchTool) # pyrefly: ignore[unsafe-overlap] def test_openai_chat_client_tool_methods_return_dict() -> None: @@ -752,7 +764,9 @@ async def test_served_model_header_propagated_to_streaming_updates() -> None: patch.object(client.client.responses, "create", new=AsyncMock(return_value=fake_stream)), patch.object(client, "_get_metadata_from_response", return_value={}), ): - stream = client._inner_get_response(messages=[Message(role="user", contents=["Hi"])], options={}, stream=True) + stream = _as_chat_response_stream( + client._inner_get_response(messages=[Message(role="user", contents=["Hi"])], options={}, stream=True) + ) updates = [update async for update in stream] assert updates, "Expected at least one streaming update" @@ -783,7 +797,9 @@ async def test_served_model_header_aggregates_into_final_streaming_response() -> patch.object(client.client.responses, "create", new=AsyncMock(return_value=fake_stream)), patch.object(client, "_get_metadata_from_response", return_value={}), ): - stream = client._inner_get_response(messages=[Message(role="user", contents=["Hi"])], options={}, stream=True) + stream = _as_chat_response_stream( + client._inner_get_response(messages=[Message(role="user", contents=["Hi"])], options={}, stream=True) + ) updates = [update async for update in stream] final = ChatResponse.from_updates(updates) @@ -813,7 +829,9 @@ async def test_served_model_header_absent_in_streaming_updates() -> None: patch.object(client.client.responses, "create", new=AsyncMock(return_value=fake_stream)), patch.object(client, "_get_metadata_from_response", return_value={}), ): - stream = client._inner_get_response(messages=[Message(role="user", contents=["Hi"])], options={}, stream=True) + stream = _as_chat_response_stream( + client._inner_get_response(messages=[Message(role="user", contents=["Hi"])], options={}, stream=True) + ) updates = [update async for update in stream] assert updates, "Expected at least one streaming update" @@ -852,7 +870,9 @@ async def test_served_model_header_not_captured_for_streaming_text_format() -> N patch.object(client.client.responses, "stream", return_value=fake_stream_ctx), patch.object(client, "_get_metadata_from_response", return_value={}), ): - stream = client._inner_get_response(messages=[Message(role="user", contents=["Hi"])], options={}, stream=True) + stream = _as_chat_response_stream( + client._inner_get_response(messages=[Message(role="user", contents=["Hi"])], options={}, stream=True) + ) updates = [update async for update in stream] assert updates, "Expected at least one streaming update" @@ -897,11 +917,11 @@ class _StreamWrapperWithoutHeaders: falls through to the default — matching the real instrumentor's class layout. """ - def __init__(self, events: list[object]) -> None: + def __init__(self, events: Sequence[object]) -> None: self._events = events - self._iterator = iter(()) + self._iterator: Iterator[object] = iter(()) - def __aiter__(self) -> "_StreamWrapperWithoutHeaders": + def __aiter__(self) -> Any: self._iterator = iter(self._events) return self @@ -911,10 +931,10 @@ async def __anext__(self) -> object: except StopIteration as exc: raise StopAsyncIteration from exc - def parse(self) -> "_StreamWrapperWithoutHeaders": + def parse(self) -> Any: return self - async def __aenter__(self) -> "_StreamWrapperWithoutHeaders": + async def __aenter__(self) -> Any: return self async def __aexit__( @@ -934,7 +954,9 @@ async def __aexit__( patch.object(client.client.responses, "create", new=AsyncMock(return_value=headerless_stream)), patch.object(client, "_get_metadata_from_response", return_value={}), ): - stream = client._inner_get_response(messages=[Message(role="user", contents=["Hi"])], options={}, stream=True) + stream = _as_chat_response_stream( + client._inner_get_response(messages=[Message(role="user", contents=["Hi"])], options={}, stream=True) + ) updates = [update async for update in stream] assert updates, "Expected the stream to complete even when the wrapper lacks .headers" @@ -1950,7 +1972,7 @@ def test_parse_response_from_openai_with_mcp_server_tool_result() -> None: def test_parse_chunk_from_openai_with_web_search_call_added() -> None: """Test that response.output_item.added for web_search_call emits search tool call content.""" client = OpenAIChatClient(model="test-model", api_key="test-key") - chat_options = ChatOptions() + chat_options: dict[str, Any] = {} function_call_ids: dict[int, tuple[str, str]] = {} mock_event = MagicMock() @@ -1978,7 +2000,7 @@ def test_parse_chunk_from_openai_with_web_search_call_added() -> None: def test_parse_chunk_from_openai_with_file_search_call_done() -> None: """Test that response.output_item.done for file_search_call emits search tool result content.""" client = OpenAIChatClient(model="test-model", api_key="test-key") - chat_options = ChatOptions() + chat_options: dict[str, Any] = {} function_call_ids: dict[int, tuple[str, str]] = {} mock_event = MagicMock() @@ -2016,7 +2038,7 @@ def test_parse_chunk_from_openai_with_file_search_call_done() -> None: def test_parse_chunk_from_openai_ignores_search_progress_events(event_type: str) -> None: """Search progress events should be explicitly ignored instead of logged as unparsed.""" client = OpenAIChatClient(model="test-model", api_key="test-key") - chat_options = ChatOptions() + chat_options: dict[str, Any] = {} function_call_ids: dict[int, tuple[str, str]] = {} mock_event = MagicMock() @@ -2533,7 +2555,7 @@ def test_streamed_file_citation_coalesces_onto_surrounding_text() -> None: } client = OpenAIChatClient(model="test-model", api_key="test-key") - chat_options = ChatOptions() + chat_options: dict[str, Any] = {} function_call_ids: dict[int, tuple[str, str]] = {} update1 = client._parse_chunk_from_openai(text_event, chat_options, function_call_ids) @@ -2560,7 +2582,7 @@ def test_streamed_file_citation_roundtrips_as_assistant_history() -> None: rejected by the Responses API. """ client = OpenAIChatClient(model="test-model", api_key="test-key") - chat_options = ChatOptions() + chat_options: dict[str, Any] = {} function_call_ids: dict[int, tuple[str, str]] = {} text_event = MagicMock() @@ -2898,7 +2920,7 @@ def test_parse_response_uses_response_id_when_no_conversation() -> None: def test_streaming_chunk_with_usage_only() -> None: """Test streaming chunk that only contains usage info.""" client = OpenAIChatClient(model="test-model", api_key="test-key") - chat_options = ChatOptions() + chat_options: dict[str, Any] = {} function_call_ids: dict[int, tuple[str, str]] = {} mock_event = MagicMock() @@ -2920,6 +2942,7 @@ def test_streaming_chunk_with_usage_only() -> None: # Should have usage content assert len(update.contents) == 1 assert update.contents[0].type == "usage" + assert update.contents[0].usage_details is not None assert update.contents[0].usage_details["total_token_count"] == 75 @@ -3022,6 +3045,7 @@ def test_parse_response_from_openai_with_mcp_approval_request() -> None: assert response.messages[0].contents[0].type == "function_approval_request" req = response.messages[0].contents[0] + assert req.function_call is not None assert req.id == "approval-1" assert req.function_call.name == "do_sensitive_action" assert req.function_call.arguments == {"arg": 1} @@ -3158,7 +3182,7 @@ def test_prepare_tools_for_openai_with_image_generation_options() -> None: # Use static method to create image generation tool tool = OpenAIChatClient.get_image_generation_tool( output_format="png", - size="512x512", + size="1024x1024", quality="high", ) @@ -3167,7 +3191,7 @@ def test_prepare_tools_for_openai_with_image_generation_options() -> None: image_tool = resp_tools[0] assert image_tool["type"] == "image_generation" assert image_tool["output_format"] == "png" - assert image_tool["size"] == "512x512" + assert image_tool["size"] == "1024x1024" assert image_tool["quality"] == "high" @@ -3187,7 +3211,7 @@ def test_prepare_tools_for_openai_with_custom_image_generation_model() -> None: def test_parse_chunk_from_openai_with_mcp_approval_request() -> None: """Test that a streaming mcp_approval_request event is parsed into FunctionApprovalRequestContent.""" client = OpenAIChatClient(model="test-model", api_key="test-key") - chat_options = ChatOptions() + chat_options: dict[str, Any] = {} function_call_ids: dict[int, tuple[str, str]] = {} mock_event = MagicMock() @@ -3203,6 +3227,7 @@ def test_parse_chunk_from_openai_with_mcp_approval_request() -> None: update = client._parse_chunk_from_openai(mock_event, chat_options, function_call_ids) assert any(c.type == "function_approval_request" for c in update.contents) fa = next(c for c in update.contents if c.type == "function_approval_request") + assert fa.function_call is not None assert fa.id == "approval-stream-1" assert fa.function_call.name == "do_stream_action" @@ -3256,6 +3281,7 @@ async def test_end_to_end_mcp_approval_flow(span_exporter) -> None: response = await client.get_response(messages=[Message(role="user", contents=["Trigger approval"])]) assert response.messages[0].contents[0].type == "function_approval_request" req = response.messages[0].contents[0] + assert req.function_call is not None assert req.id == "approval-1" # Build a user approval and send it (include required function_call) @@ -3299,8 +3325,9 @@ def test_usage_details_with_cached_tokens() -> None: details = client._parse_usage_from_openai(mock_usage) # type: ignore assert details is not None + details_dict = cast("dict[str, Any]", details) assert details["input_token_count"] == 200 - assert details["openai.cached_input_tokens"] == 25 + assert details_dict["openai.cached_input_tokens"] == 25 def test_usage_details_with_reasoning_tokens() -> None: @@ -3317,8 +3344,9 @@ def test_usage_details_with_reasoning_tokens() -> None: details = client._parse_usage_from_openai(mock_usage) # type: ignore assert details is not None + details_dict = cast("dict[str, Any]", details) assert details["output_token_count"] == 80 - assert details["openai.reasoning_tokens"] == 30 + assert details_dict["openai.reasoning_tokens"] == 30 def test_get_metadata_from_response() -> None: @@ -3344,7 +3372,7 @@ def test_get_metadata_from_response() -> None: def test_streaming_response_basic_structure() -> None: """Test that _parse_chunk_from_openai returns proper structure.""" client = OpenAIChatClient(model="test-model", api_key="test-key") - chat_options = ChatOptions(store=True) + chat_options: dict[str, Any] = {"store": True} function_call_ids: dict[int, tuple[str, str]] = {} # Test with a basic mock event to ensure the method returns proper structure @@ -3363,7 +3391,7 @@ def test_streaming_response_basic_structure() -> None: def test_streaming_response_created_type() -> None: """Test streaming response with created type""" client = OpenAIChatClient(model="test-model", api_key="test-key") - chat_options = ChatOptions() + chat_options: dict[str, Any] = {} function_call_ids: dict[int, tuple[str, str]] = {} mock_event = MagicMock() @@ -3382,7 +3410,7 @@ def test_streaming_response_created_type() -> None: def test_streaming_response_in_progress_type() -> None: """Test streaming response with in_progress type""" client = OpenAIChatClient(model="test-model", api_key="test-key") - chat_options = ChatOptions() + chat_options: dict[str, Any] = {} function_call_ids: dict[int, tuple[str, str]] = {} mock_event = MagicMock() @@ -3401,7 +3429,7 @@ def test_streaming_response_in_progress_type() -> None: def test_streaming_annotation_added_with_file_path() -> None: """Streaming `file_path` should attach as a text annotation, matching non-streaming.""" client = OpenAIChatClient(model="test-model", api_key="test-key") - chat_options = ChatOptions() + chat_options: dict[str, Any] = {} function_call_ids: dict[int, tuple[str, str]] = {} mock_event = MagicMock() @@ -3435,7 +3463,7 @@ def test_streaming_annotation_added_with_file_citation() -> None: Annotations on text content roundtrip cleanly. """ client = OpenAIChatClient(model="test-model", api_key="test-key") - chat_options = ChatOptions() + chat_options: dict[str, Any] = {} function_call_ids: dict[int, tuple[str, str]] = {} mock_event = MagicMock() @@ -3466,7 +3494,7 @@ def test_streaming_annotation_added_with_file_citation() -> None: def test_streaming_annotation_added_with_container_file_citation() -> None: """Streaming `container_file_citation` should attach as a text annotation.""" client = OpenAIChatClient(model="test-model", api_key="test-key") - chat_options = ChatOptions() + chat_options: dict[str, Any] = {} function_call_ids: dict[int, tuple[str, str]] = {} mock_event = MagicMock() @@ -3503,7 +3531,7 @@ def test_streaming_annotation_added_with_container_file_citation() -> None: def test_streaming_annotation_added_with_url_citation() -> None: """Test streaming annotation added event with url_citation type produces citation annotation.""" client = OpenAIChatClient(model="test-model", api_key="test-key") - chat_options = ChatOptions() + chat_options: dict[str, Any] = {} function_call_ids: dict[int, tuple[str, str]] = {} mock_event = MagicMock() @@ -3544,7 +3572,7 @@ def test_streaming_annotation_added_with_url_citation() -> None: def test_streaming_annotation_added_with_url_citation_no_url() -> None: """Test streaming annotation added event with url_citation but missing url is ignored.""" client = OpenAIChatClient(model="test-model", api_key="test-key") - chat_options = ChatOptions() + chat_options: dict[str, Any] = {} function_call_ids: dict[int, tuple[str, str]] = {} mock_event = MagicMock() @@ -3563,7 +3591,7 @@ def test_streaming_annotation_added_with_url_citation_no_url() -> None: def test_streaming_annotation_added_with_url_citation_no_indices() -> None: """Test streaming annotation with url_citation that has url but no start_index/end_index.""" client = OpenAIChatClient(model="test-model", api_key="test-key") - chat_options = ChatOptions() + chat_options: dict[str, Any] = {} function_call_ids: dict[int, tuple[str, str]] = {} mock_event = MagicMock() @@ -3578,6 +3606,7 @@ def test_streaming_annotation_added_with_url_citation_no_indices() -> None: response = client._parse_chunk_from_openai(mock_event, chat_options, function_call_ids) assert len(response.contents) == 1 + assert response.contents[0].annotations is not None annotation = response.contents[0].annotations[0] assert annotation["type"] == "citation" assert annotation["title"] == "Example" @@ -3589,7 +3618,7 @@ def test_streaming_annotation_added_with_url_citation_no_indices() -> None: def test_streaming_annotation_added_with_unknown_type() -> None: """Test streaming annotation added event with unknown type is ignored.""" client = OpenAIChatClient(model="test-model", api_key="test-key") - chat_options = ChatOptions() + chat_options: dict[str, Any] = {} function_call_ids: dict[int, tuple[str, str]] = {} mock_event = MagicMock() @@ -3682,7 +3711,7 @@ async def test_inner_get_response_streaming_with_response_format_tracks_reasonin patch.object(client.client.responses, "stream", return_value=_FakeAsyncEventStreamContext(events)), patch.object(client, "_get_metadata_from_response", return_value={}), ): - stream = client._inner_get_response(messages=messages, options={}, stream=True) + stream = _as_chat_response_stream(client._inner_get_response(messages=messages, options={}, stream=True)) updates = [update async for update in stream] reasoning_chunks = [ @@ -3801,7 +3830,7 @@ def test_prepare_content_for_openai_function_result_without_items() -> None: def test_parse_chunk_from_openai_code_interpreter() -> None: """Test _parse_chunk_from_openai with code_interpreter_call.""" client = OpenAIChatClient(model="test-model", api_key="test-key") - chat_options = ChatOptions() + chat_options: dict[str, Any] = {} function_call_ids: dict[int, tuple[str, str]] = {} mock_event_image = MagicMock() @@ -3825,7 +3854,7 @@ def test_parse_chunk_from_openai_code_interpreter() -> None: def test_parse_chunk_from_openai_code_interpreter_delta() -> None: """Test _parse_chunk_from_openai with code_interpreter_call_code delta events.""" client = OpenAIChatClient(model="test-model", api_key="test-key") - chat_options = ChatOptions() + chat_options: dict[str, Any] = {} function_call_ids: dict[int, tuple[str, str]] = {} # Test delta event @@ -3854,7 +3883,7 @@ def test_parse_chunk_from_openai_code_interpreter_delta() -> None: def test_parse_chunk_from_openai_code_interpreter_done() -> None: """Test _parse_chunk_from_openai with code_interpreter_call_code done event.""" client = OpenAIChatClient(model="test-model", api_key="test-key") - chat_options = ChatOptions() + chat_options: dict[str, Any] = {} function_call_ids: dict[int, tuple[str, str]] = {} # Test done event @@ -3873,6 +3902,7 @@ def test_parse_chunk_from_openai_code_interpreter_done() -> None: assert result.contents[0].call_id == "ci_456" assert result.contents[0].inputs assert result.contents[0].inputs[0].type == "text" + assert result.contents[0].inputs[0].text is not None assert "import pandas as pd" in result.contents[0].inputs[0].text # Verify additional_properties for stream ordering assert result.contents[0].additional_properties["output_index"] == 0 @@ -3883,7 +3913,7 @@ def test_parse_chunk_from_openai_code_interpreter_done() -> None: def test_parse_chunk_from_openai_reasoning() -> None: """Test _parse_chunk_from_openai with reasoning content.""" client = OpenAIChatClient(model="test-model", api_key="test-key") - chat_options = ChatOptions() + chat_options: dict[str, Any] = {} function_call_ids: dict[int, tuple[str, str]] = {} mock_event_reasoning = MagicMock() @@ -3931,7 +3961,7 @@ def test_prepare_content_for_openai_text_reasoning_comprehensive() -> None: def test_streaming_reasoning_text_delta_event() -> None: """Test reasoning text delta event creates TextReasoningContent.""" client = OpenAIChatClient(model="test-model", api_key="test-key") - chat_options = ChatOptions() + chat_options: dict[str, Any] = {} function_call_ids: dict[int, tuple[str, str]] = {} event = ResponseReasoningTextDeltaEvent( @@ -3957,7 +3987,7 @@ def test_streaming_reasoning_text_delta_event() -> None: def test_streaming_reasoning_text_done_event_skipped_after_deltas() -> None: """Test reasoning text done event does not emit content when deltas were already received.""" client = OpenAIChatClient(model="test-model", api_key="test-key") - chat_options = ChatOptions() + chat_options: dict[str, Any] = {} function_call_ids: dict[int, tuple[str, str]] = {} seen_reasoning_delta_item_ids: set[str] = {"reasoning_456"} @@ -3983,7 +4013,7 @@ def test_streaming_reasoning_text_done_event_skipped_after_deltas() -> None: def test_streaming_reasoning_text_done_event_fallback_without_deltas() -> None: """Test reasoning text done event emits content when no deltas were received for this item_id.""" client = OpenAIChatClient(model="test-model", api_key="test-key") - chat_options = ChatOptions() + chat_options: dict[str, Any] = {} function_call_ids: dict[int, tuple[str, str]] = {} seen_reasoning_delta_item_ids: set[str] = set() @@ -4012,7 +4042,7 @@ def test_streaming_reasoning_text_done_event_fallback_without_deltas() -> None: def test_streaming_reasoning_summary_text_delta_event() -> None: """Test reasoning summary text delta event creates TextReasoningContent.""" client = OpenAIChatClient(model="test-model", api_key="test-key") - chat_options = ChatOptions() + chat_options: dict[str, Any] = {} function_call_ids: dict[int, tuple[str, str]] = {} event = ResponseReasoningSummaryTextDeltaEvent( @@ -4037,7 +4067,7 @@ def test_streaming_reasoning_summary_text_delta_event() -> None: def test_streaming_reasoning_summary_text_done_event_skipped_after_deltas() -> None: """Test reasoning summary text done event does not emit content when deltas were already received.""" client = OpenAIChatClient(model="test-model", api_key="test-key") - chat_options = ChatOptions() + chat_options: dict[str, Any] = {} function_call_ids: dict[int, tuple[str, str]] = {} seen_reasoning_delta_item_ids: set[str] = {"summary_012"} @@ -4063,7 +4093,7 @@ def test_streaming_reasoning_summary_text_done_event_skipped_after_deltas() -> N def test_streaming_reasoning_summary_text_done_event_fallback_without_deltas() -> None: """Test reasoning summary text done event emits content when no deltas were received for this item_id.""" client = OpenAIChatClient(model="test-model", api_key="test-key") - chat_options = ChatOptions() + chat_options: dict[str, Any] = {} function_call_ids: dict[int, tuple[str, str]] = {} seen_reasoning_delta_item_ids: set[str] = set() @@ -4092,7 +4122,7 @@ def test_streaming_reasoning_summary_text_done_event_fallback_without_deltas() - def test_streaming_reasoning_deltas_then_done_no_duplication() -> None: """Sending delta events followed by a done event produces content only from deltas.""" client = OpenAIChatClient(model="test-model", api_key="test-key") - chat_options = ChatOptions() + chat_options: dict[str, Any] = {} function_call_ids: dict[int, tuple[str, str]] = {} seen_reasoning_delta_item_ids: set[str] = set() item_id = "reasoning_seq" @@ -4124,7 +4154,7 @@ def test_streaming_reasoning_deltas_then_done_no_duplication() -> None: all_contents = [] with patch.object(client, "_get_metadata_from_response", return_value={}): - for event in [delta1, delta2, done]: + for event in cast("tuple[Any, ...]", (delta1, delta2, done)): response = client._parse_chunk_from_openai( event, chat_options, @@ -4136,7 +4166,7 @@ def test_streaming_reasoning_deltas_then_done_no_duplication() -> None: assert len(all_contents) == 2 assert all_contents[0].text == "Hello " assert all_contents[1].text == "world" - assert "".join(c.text for c in all_contents) == "Hello world" + assert "".join(c.text or "" for c in all_contents) == "Hello world" async def test_inner_get_response_streaming_create_tracks_reasoning_delta_ids() -> None: @@ -4168,7 +4198,7 @@ async def test_inner_get_response_streaming_create_tracks_reasoning_delta_ids() patch.object(client.client.responses, "create", new=AsyncMock(return_value=_FakeAsyncEventStream(events))), patch.object(client, "_get_metadata_from_response", return_value={}), ): - stream = client._inner_get_response(messages=messages, options={}, stream=True) + stream = _as_chat_response_stream(client._inner_get_response(messages=messages, options={}, stream=True)) updates = [update async for update in stream] reasoning_chunks = [ @@ -4180,7 +4210,7 @@ async def test_inner_get_response_streaming_create_tracks_reasoning_delta_ids() def test_streaming_reasoning_events_preserve_metadata() -> None: """Test that reasoning events preserve metadata like regular text events.""" client = OpenAIChatClient(model="test-model", api_key="test-key") - chat_options = ChatOptions() + chat_options: dict[str, Any] = {} function_call_ids: dict[int, tuple[str, str]] = {} text_event = ResponseTextDeltaEvent( @@ -4248,7 +4278,9 @@ def test_parse_response_from_openai_image_generation_raw_base64(): assert result_content.type == "image_generation_tool_result" assert result_content.outputs data_out = result_content.outputs + assert isinstance(data_out, Content) assert data_out.type == "data" + assert data_out.uri is not None assert data_out.uri.startswith("data:image/png;base64,") assert data_out.media_type == "image/png" @@ -4285,6 +4317,7 @@ def test_parse_response_from_openai_image_generation_existing_data_uri(): assert result_content.type == "image_generation_tool_result" assert result_content.outputs data_out = result_content.outputs + assert isinstance(data_out, Content) assert data_out.type == "data" assert data_out.uri == f"data:image/webp;base64,{valid_webp_base64}" assert data_out.media_type == "image/webp" @@ -4316,8 +4349,10 @@ def test_parse_response_from_openai_image_generation_format_detection(): result_contents = response_jpeg.messages[0].contents assert result_contents[1].type == "image_generation_tool_result" outputs = result_contents[1].outputs - assert outputs and outputs.type == "data" + assert isinstance(outputs, Content) + assert outputs.type == "data" assert outputs.media_type == "image/jpeg" + assert outputs.uri is not None assert "data:image/jpeg;base64," in outputs.uri # Test WEBP detection @@ -4340,8 +4375,10 @@ def test_parse_response_from_openai_image_generation_format_detection(): with patch.object(client, "_get_metadata_from_response", return_value={}): response_webp = client._parse_response_from_openai(mock_response_webp, options={}) # type: ignore outputs_webp = response_webp.messages[0].contents[1].outputs - assert outputs_webp and outputs_webp.type == "data" + assert isinstance(outputs_webp, Content) + assert outputs_webp.type == "data" assert outputs_webp.media_type == "image/webp" + assert outputs_webp.uri is not None assert "data:image/webp;base64," in outputs_webp.uri @@ -4376,6 +4413,7 @@ def test_parse_response_from_openai_image_generation_fallback(): assert result_content.type == "image_generation_tool_result" assert result_content.outputs content = result_content.outputs + assert isinstance(content, Content) assert content.media_type == "image/png" assert f"data:image/png;base64,{unrecognized_base64}" == content.uri @@ -4394,7 +4432,7 @@ async def test_prepare_options_store_parameter_handling() -> None: options = await client._prepare_options(messages, chat_options) # type: ignore assert options["store"] is False - chat_options = ChatOptions(store=None, conversation_id=None) + chat_options = cast(Any, {"store": None, "conversation_id": None}) options = await client._prepare_options(messages, chat_options) # type: ignore assert "store" not in options assert "previous_response_id" not in options @@ -4779,7 +4817,9 @@ async def test_integration_options( options["tools"] = [get_weather] # Test streaming mode - response = await client.get_response(stream=True, messages=messages, options=options).get_final_response() + response = ( + await cast(Any, client).get_response(stream=True, messages=messages, options=options).get_final_response() + ) assert response is not None assert isinstance(response, ChatResponse) @@ -4828,7 +4868,7 @@ async def test_integration_web_search() -> None: "tools": [web_search_tool_with_location], }, } - response = await client.get_response(stream=True, **content).get_final_response() + response = await cast(Any, client).get_response(stream=True, **content).get_final_response() assert response.text is not None @@ -4845,8 +4885,10 @@ async def test_integration_file_search() -> None: assert isinstance(openai_responses_client, SupportsChatGetResponse) file_id, vector_store = await create_vector_store(openai_responses_client) + vector_store_id = vector_store.vector_store_id + assert vector_store_id is not None # Use static method for file search tool - file_search_tool = OpenAIChatClient.get_file_search_tool(vector_store_ids=[vector_store.vector_store_id]) + file_search_tool = OpenAIChatClient.get_file_search_tool(vector_store_ids=[vector_store_id]) # Test that the client will use the file search tool response = await openai_responses_client.get_response( messages=[ @@ -4861,7 +4903,7 @@ async def test_integration_file_search() -> None: }, ) - await delete_vector_store(openai_responses_client, file_id, vector_store.vector_store_id) + await delete_vector_store(openai_responses_client, file_id, vector_store_id) assert "sunny" in response.text.lower() assert "75" in response.text @@ -4879,10 +4921,12 @@ async def test_integration_streaming_file_search() -> None: assert isinstance(openai_responses_client, SupportsChatGetResponse) file_id, vector_store = await create_vector_store(openai_responses_client) + vector_store_id = vector_store.vector_store_id + assert vector_store_id is not None # Use static method for file search tool - file_search_tool = OpenAIChatClient.get_file_search_tool(vector_store_ids=[vector_store.vector_store_id]) + file_search_tool = OpenAIChatClient.get_file_search_tool(vector_store_ids=[vector_store_id]) # Test that the client will use the web search tool - response = openai_responses_client.get_streaming_response( + response = cast(Any, openai_responses_client).get_streaming_response( messages=[ Message( role="user", @@ -4904,7 +4948,7 @@ async def test_integration_streaming_file_search() -> None: if content.type == "text" and content.text: full_message += content.text - await delete_vector_store(openai_responses_client, file_id, vector_store.vector_store_id) + await delete_vector_store(openai_responses_client, file_id, vector_store_id) assert "sunny" in full_message.lower() assert "75" in full_message @@ -4934,7 +4978,9 @@ def get_test_image() -> Content: ] options: dict[str, Any] = {"tools": [get_test_image], "tool_choice": "auto"} - response = await client.get_response(messages=messages, stream=True, options=options).get_final_response() + response = ( + await cast(Any, client).get_response(messages=messages, stream=True, options=options).get_final_response() + ) assert response is not None assert isinstance(response, ChatResponse) @@ -4959,7 +5005,7 @@ async def search_hotels(city: Annotated[str, "The city to search for hotels in"] client = OpenAIChatClient(model="gpt-5.4") client.function_invocation_configuration["max_iterations"] = 2 - agent = Agent(client=client, tools=[search_hotels], default_options={"store": False}) + agent = Agent(client=cast(Any, client), tools=[search_hotels], default_options=ChatOptions(store=False)) session = agent.create_session() first_response = await agent.run( @@ -4992,7 +5038,7 @@ def test_continuation_token_json_serializable() -> None: from agent_framework_openai import OpenAIContinuationToken token = OpenAIContinuationToken(response_id="resp_abc123") - assert token["response_id"] == "resp_abc123" + assert _response_id_from_token(token) == "resp_abc123" # JSON round-trip serialized = json.dumps(token) @@ -5011,7 +5057,7 @@ def test_chat_response_with_continuation_token() -> None: continuation_token=token, ) assert response.continuation_token is not None - assert response.continuation_token["response_id"] == "resp_123" + assert _response_id_from_token(response.continuation_token) == "resp_123" def test_chat_response_without_continuation_token() -> None: @@ -5033,7 +5079,7 @@ def test_chat_response_update_with_continuation_token() -> None: continuation_token=token, ) assert update.continuation_token is not None - assert update.continuation_token["response_id"] == "resp_456" + assert _response_id_from_token(update.continuation_token) == "resp_456" def test_agent_response_with_continuation_token() -> None: @@ -5048,7 +5094,7 @@ def test_agent_response_with_continuation_token() -> None: continuation_token=token, ) assert response.continuation_token is not None - assert response.continuation_token["response_id"] == "resp_789" + assert _response_id_from_token(response.continuation_token) == "resp_789" def test_agent_response_update_with_continuation_token() -> None: @@ -5064,7 +5110,7 @@ def test_agent_response_update_with_continuation_token() -> None: continuation_token=token, ) assert update.continuation_token is not None - assert update.continuation_token["response_id"] == "resp_012" + assert _response_id_from_token(update.continuation_token) == "resp_012" def test_parse_response_from_openai_with_background_in_progress() -> None: @@ -5089,7 +5135,7 @@ def test_parse_response_from_openai_with_background_in_progress() -> None: result = client._parse_response_from_openai(mock_response, options=options) assert result.continuation_token is not None - assert result.continuation_token["response_id"] == "resp_bg_123" + assert _response_id_from_token(result.continuation_token) == "resp_bg_123" def test_parse_response_from_openai_with_background_queued() -> None: @@ -5114,7 +5160,7 @@ def test_parse_response_from_openai_with_background_queued() -> None: result = client._parse_response_from_openai(mock_response, options=options) assert result.continuation_token is not None - assert result.continuation_token["response_id"] == "resp_bg_456" + assert _response_id_from_token(result.continuation_token) == "resp_bg_456" def test_parse_response_from_openai_with_background_completed() -> None: @@ -5164,7 +5210,7 @@ def test_streaming_response_in_progress_sets_continuation_token() -> None: update = client._parse_chunk_from_openai(mock_event, chat_options, function_call_ids) assert update.continuation_token is not None - assert update.continuation_token["response_id"] == "resp_stream_123" + assert _response_id_from_token(update.continuation_token) == "resp_stream_123" def test_streaming_response_created_with_in_progress_status_sets_continuation_token() -> None: @@ -5184,7 +5230,7 @@ def test_streaming_response_created_with_in_progress_status_sets_continuation_to update = client._parse_chunk_from_openai(mock_event, chat_options, function_call_ids) assert update.continuation_token is not None - assert update.continuation_token["response_id"] == "resp_created_123" + assert _response_id_from_token(update.continuation_token) == "resp_created_123" def test_streaming_response_completed_no_continuation_token() -> None: @@ -5234,7 +5280,7 @@ def test_map_chat_to_agent_update_preserves_continuation_token() -> None: """Test that map_chat_to_agent_update propagates continuation_token.""" from agent_framework._types import map_chat_to_agent_update - token = {"response_id": "resp_map_123"} + token = cast(Any, {"response_id": "resp_map_123"}) chat_update = ChatResponseUpdate( contents=[Content.from_text(text="chunk")], role="assistant", @@ -5245,7 +5291,7 @@ def test_map_chat_to_agent_update_preserves_continuation_token() -> None: agent_update = map_chat_to_agent_update(chat_update, agent_name="test-agent") assert agent_update.continuation_token is not None - assert agent_update.continuation_token["response_id"] == "resp_map_123" + assert _response_id_from_token(agent_update.continuation_token) == "resp_map_123" async def test_prepare_options_excludes_continuation_token() -> None: @@ -5439,7 +5485,7 @@ async def test_prepare_messages_for_openai_does_not_replay_fc_id_when_loaded_fro context = SessionContext(session_id=session.session_id, input_messages=[next_turn_input]) await provider.before_run( - agent=None, + agent=cast(Any, None), session=session, context=context, state=session.state.setdefault(provider.source_id, {}), @@ -5459,7 +5505,7 @@ async def test_prepare_messages_for_openai_does_not_replay_fc_id_when_loaded_fro restored = AgentSession.from_dict(json.loads(json.dumps(session.to_dict()))) restored_context = SessionContext(session_id=restored.session_id, input_messages=[next_turn_input]) await provider.before_run( - agent=None, + agent=cast(Any, None), session=restored, context=restored_context, state=restored.state.setdefault(provider.source_id, {}), diff --git a/python/packages/openai/tests/openai/test_openai_chat_client_azure.py b/python/packages/openai/tests/openai/test_openai_chat_client_azure.py index fe25a7202d7..70d9de234e6 100644 --- a/python/packages/openai/tests/openai/test_openai_chat_client_azure.py +++ b/python/packages/openai/tests/openai/test_openai_chat_client_azure.py @@ -6,7 +6,8 @@ import os from functools import wraps from pathlib import Path -from typing import Any +from types import TracebackType +from typing import Any, cast from unittest.mock import MagicMock, patch import pytest @@ -41,7 +42,7 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: endpoint = os.getenv("AZURE_OPENAI_ENDPOINT", "") debug_message = f"Azure OpenAI debug: endpoint={endpoint}, model={model}, api_version={api_version}" if hasattr(exc, "add_note"): - exc.add_note(debug_message) + cast(Any, exc).add_note(debug_message) elif exc.args: exc.args = (f"{exc.args[0]}\n{debug_message}", *exc.args[1:]) else: @@ -104,12 +105,13 @@ async def get_weather(location: str) -> str: def test_init_with_azure_endpoint(azure_openai_unit_test_env: dict[str, str]) -> None: - client = OpenAIChatClient(credential=AzureCliCredential()) + client = OpenAIChatClient(credential=cast(Any, AzureCliCredential())) assert client.model == azure_openai_unit_test_env["AZURE_OPENAI_CHAT_MODEL"] assert isinstance(client, SupportsChatGetResponse) assert isinstance(client.client, AsyncAzureOpenAI) assert client.OTEL_PROVIDER_NAME == "azure.ai.openai" + assert client.azure_endpoint is not None assert client.azure_endpoint.startswith(azure_openai_unit_test_env["AZURE_OPENAI_ENDPOINT"]) @@ -197,13 +199,24 @@ class TestAsyncTokenCredential(AsyncTokenCredential): async def get_token(self, *scopes: str, **kwargs: object): raise NotImplementedError + async def close(self) -> None: + pass + + async def __aexit__( + self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: TracebackType | None = None, + ) -> None: + pass + monkeypatch.setenv("OPENAI_API_KEY", "test-dummy-key") monkeypatch.setenv("OPENAI_MODEL", "gpt-5") credential = TestAsyncTokenCredential() token_provider = MagicMock() with patch("azure.identity.aio.get_bearer_token_provider", return_value=token_provider) as mock_provider: - client = OpenAIChatClient(credential=credential) + client = OpenAIChatClient(credential=cast(Any, credential)) assert isinstance(client.client, AsyncAzureOpenAI) mock_provider.assert_called_once_with(credential, "https://cognitiveservices.azure.com/.default") @@ -211,7 +224,7 @@ async def get_token(self, *scopes: str, **kwargs: object): @pytest.mark.parametrize("exclude_list", [["AZURE_OPENAI_API_VERSION"]], indirect=True) def test_init_uses_default_azure_api_version(azure_openai_unit_test_env: dict[str, str]) -> None: - client = OpenAIChatClient(credential=AzureCliCredential()) + client = OpenAIChatClient(credential=cast(Any, AzureCliCredential())) assert client.model == azure_openai_unit_test_env["AZURE_OPENAI_CHAT_MODEL"] assert client.api_version is not None @@ -290,7 +303,7 @@ async def test_integration_options( needs_validation: bool, ) -> None: async with AzureCliCredential() as credential: - client = OpenAIChatClient(credential=credential) + client = OpenAIChatClient(credential=cast(Any, credential)) client.function_invocation_configuration["max_iterations"] = 2 for streaming in [False, True]: @@ -309,13 +322,17 @@ async def test_integration_options( options["tools"] = [get_weather] if streaming: - response = await client.get_response( - messages=messages, - stream=True, - options=options, - ).get_final_response() + response = ( + await cast(Any, client) + .get_response( + messages=messages, + stream=True, + options=options, + ) + .get_final_response() + ) else: - response = await client.get_response(messages=messages, options=options) + response = await cast(Any, client).get_response(messages=messages, options=options) assert isinstance(response, ChatResponse) assert response.text is not None @@ -343,7 +360,7 @@ async def test_integration_options( @_with_azure_openai_debug() async def test_integration_web_search() -> None: async with AzureCliCredential() as credential: - client = OpenAIChatClient(credential=credential) + client = OpenAIChatClient(credential=cast(Any, credential)) response = await client.get_response( messages=[ @@ -367,15 +384,17 @@ async def test_integration_web_search() -> None: @_with_azure_openai_debug() async def test_integration_client_file_search() -> None: async with AzureCliCredential() as credential: - client = OpenAIChatClient(credential=credential) + client = OpenAIChatClient(credential=cast(Any, credential)) file_id, vector_store = await create_vector_store(client) + vector_store_id = vector_store.vector_store_id + assert vector_store_id is not None try: - response = await client.get_response( + response = await cast(Any, client).get_response( messages=[ Message(role="user", contents=["What is the weather today? Do a file search to find the answer."]) ], options={ - "tools": [OpenAIChatClient.get_file_search_tool(vector_store_ids=[vector_store.vector_store_id])], + "tools": [OpenAIChatClient.get_file_search_tool(vector_store_ids=[vector_store_id])], "tool_choice": "auto", }, ) @@ -383,7 +402,7 @@ async def test_integration_client_file_search() -> None: assert "sunny" in response.text.lower() assert "75" in response.text finally: - await delete_vector_store(client, file_id, vector_store.vector_store_id) + await delete_vector_store(client, file_id, vector_store_id) @pytest.mark.flaky @@ -392,16 +411,18 @@ async def test_integration_client_file_search() -> None: @_with_azure_openai_debug() async def test_integration_client_file_search_streaming() -> None: async with AzureCliCredential() as credential: - client = OpenAIChatClient(credential=credential) + client = OpenAIChatClient(credential=cast(Any, credential)) file_id, vector_store = await create_vector_store(client) + vector_store_id = vector_store.vector_store_id + assert vector_store_id is not None try: - response_stream = client.get_response( + response_stream = cast(Any, client).get_response( messages=[ Message(role="user", contents=["What is the weather today? Do a file search to find the answer."]) ], stream=True, options={ - "tools": [OpenAIChatClient.get_file_search_tool(vector_store_ids=[vector_store.vector_store_id])], + "tools": [OpenAIChatClient.get_file_search_tool(vector_store_ids=[vector_store_id])], "tool_choice": "auto", }, ) @@ -410,7 +431,7 @@ async def test_integration_client_file_search_streaming() -> None: assert "sunny" in full_response.text.lower() assert "75" in full_response.text finally: - await delete_vector_store(client, file_id, vector_store.vector_store_id) + await delete_vector_store(client, file_id, vector_store_id) @pytest.mark.flaky @@ -419,7 +440,7 @@ async def test_integration_client_file_search_streaming() -> None: @_with_azure_openai_debug() async def test_integration_client_agent_hosted_mcp_tool() -> None: async with AzureCliCredential() as credential: - client = OpenAIChatClient(credential=credential) + client = OpenAIChatClient(credential=cast(Any, credential)) response = await client.get_response( messages=[Message(role="user", contents=["How to create an Azure storage account using az cli?"])], options={ @@ -443,7 +464,7 @@ async def test_integration_client_agent_hosted_mcp_tool() -> None: @_with_azure_openai_debug() async def test_integration_client_agent_hosted_code_interpreter_tool() -> None: async with AzureCliCredential() as credential: - client = OpenAIChatClient(credential=credential) + client = OpenAIChatClient(credential=cast(Any, credential)) response = await client.get_response( messages=[Message(role="user", contents=["Calculate the sum of numbers from 1 to 10 using Python code."])], @@ -465,7 +486,7 @@ async def test_integration_client_agent_existing_session() -> None: preserved_session = None async with Agent( - client=OpenAIChatClient(credential=credential), + client=OpenAIChatClient(credential=cast(Any, credential)), instructions="You are a helpful assistant with good memory.", ) as first_agent: session = first_agent.create_session() @@ -480,7 +501,7 @@ async def test_integration_client_agent_existing_session() -> None: if preserved_session: async with Agent( - client=OpenAIChatClient(credential=credential), + client=OpenAIChatClient(credential=cast(Any, credential)), instructions="You are a helpful assistant with good memory.", ) as second_agent: second_response = await second_agent.run( @@ -507,7 +528,7 @@ def get_test_image() -> Content: return Content.from_data(data=image_bytes, media_type="image/jpeg") async with AzureCliCredential() as credential: - client = OpenAIChatClient(credential=credential) + client = OpenAIChatClient(credential=cast(Any, credential)) client.function_invocation_configuration["max_iterations"] = 2 response = await client.get_response( diff --git a/python/packages/openai/tests/openai/test_openai_chat_completion_client.py b/python/packages/openai/tests/openai/test_openai_chat_completion_client.py index 85e12b8626d..29b4a7e3fa7 100644 --- a/python/packages/openai/tests/openai/test_openai_chat_completion_client.py +++ b/python/packages/openai/tests/openai/test_openai_chat_completion_client.py @@ -3,7 +3,7 @@ import inspect import json import os -from typing import Any +from typing import Any, cast from unittest.mock import MagicMock, patch import pytest @@ -72,7 +72,7 @@ def test_init_uses_explicit_parameters() -> None: def test_supports_web_search_only() -> None: assert not isinstance(OpenAIChatCompletionClient, SupportsCodeInterpreterTool) - assert isinstance(OpenAIChatCompletionClient, SupportsWebSearchTool) + assert isinstance(OpenAIChatCompletionClient, SupportsWebSearchTool) # pyrefly: ignore[unsafe-overlap] assert not isinstance(OpenAIChatCompletionClient, SupportsImageGenerationTool) assert not isinstance(OpenAIChatCompletionClient, SupportsMCPTool) assert not isinstance(OpenAIChatCompletionClient, SupportsFileSearchTool) @@ -459,7 +459,7 @@ def test_function_result_exception_handling(openai_unit_test_env: dict[str, str] Content.from_function_result( call_id="call-123", result="Error: Function failed.", - exception=test_exception, + exception=str(test_exception), ) ], ) @@ -791,7 +791,7 @@ def test_parse_text_reasoning_content_from_response( choices=[ Choice( index=0, - message=ChatCompletionMessage( + message=cast(Any, ChatCompletionMessage)( role="assistant", content="The answer is 42.", reasoning_details=mock_reasoning_details, @@ -843,7 +843,7 @@ def test_parse_text_reasoning_content_from_streaming_chunk( choices=[ ChunkChoice( index=0, - delta=ChunkChoiceDelta( + delta=cast(Any, ChunkChoiceDelta)( role="assistant", content="Partial answer", reasoning_details=mock_reasoning_details, @@ -1181,7 +1181,7 @@ def test_parse_text_with_refusal(openai_unit_test_env: dict[str, str]) -> None: def test_prepare_options_without_model(openai_unit_test_env: dict[str, str]) -> None: """Test that prepare_options raises error when model is not set.""" client = OpenAIChatCompletionClient() - client.model = None # Remove model + cast(Any, client).model = None # Remove model messages = [Message(role="user", contents=["test"])] @@ -1735,11 +1735,15 @@ async def test_integration_options( options["tools"] = [get_weather] # Test streaming mode - response = await client.get_response( - messages=messages, - stream=True, - options=options, - ).get_final_response() + response = ( + await cast(Any, client) + .get_response( + messages=messages, + stream=True, + options=options, + ) + .get_final_response() + ) assert response is not None assert isinstance(response, ChatResponse) @@ -1779,7 +1783,7 @@ async def test_integration_web_search() -> None: for streaming in [False, True]: # Use static method for web search tool web_search_tool = OpenAIChatCompletionClient.get_web_search_tool() - content = { + weather_content: dict[str, Any] = { "messages": [ Message( role="user", @@ -1792,9 +1796,9 @@ async def test_integration_web_search() -> None: }, } if streaming: - response = await client.get_response(stream=True, **content).get_final_response() + response = await client.get_response(stream=True, **weather_content).get_final_response() else: - response = await client.get_response(**content) + response = await client.get_response(**weather_content) assert response is not None assert isinstance(response, ChatResponse) @@ -1811,7 +1815,7 @@ async def test_integration_web_search() -> None: }, } ) - content = { + content: dict[str, Any] = { "messages": [ Message( role="user", @@ -1856,7 +1860,7 @@ def test_streaming_chunk_with_null_delta_is_skipped( choices=[ Choice.model_construct( index=0, - delta=None, + delta=None, # type: ignore[arg-type] finish_reason="stop", ) ], @@ -1889,7 +1893,7 @@ def test_streaming_chunk_with_null_delta_preserves_finish_reason( choices=[ Choice.model_construct( index=0, - delta=None, + delta=None, # type: ignore[arg-type] finish_reason="length", ) ], @@ -1951,7 +1955,7 @@ def test_streaming_chunk_with_null_delta_and_usage( choices=[ Choice.model_construct( index=0, - delta=None, + delta=None, # type: ignore[arg-type] finish_reason="stop", ) ], @@ -1982,7 +1986,7 @@ def test_streaming_chunk_with_null_delta_no_tool_calls_parsed( choices=[ Choice.model_construct( index=0, - delta=None, + delta=None, # type: ignore[arg-type] finish_reason="tool_calls", ) ], diff --git a/python/packages/openai/tests/openai/test_openai_chat_completion_client_azure.py b/python/packages/openai/tests/openai/test_openai_chat_completion_client_azure.py index b650a3bd509..b92566f36f2 100644 --- a/python/packages/openai/tests/openai/test_openai_chat_completion_client_azure.py +++ b/python/packages/openai/tests/openai/test_openai_chat_completion_client_azure.py @@ -4,7 +4,8 @@ import os from functools import wraps -from typing import Any +from types import TracebackType +from typing import Any, cast from unittest.mock import MagicMock, patch import pytest @@ -46,7 +47,7 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: endpoint = os.getenv("AZURE_OPENAI_ENDPOINT", "") debug_message = f"Azure OpenAI debug: endpoint={endpoint}, model={model}, api_version={api_version}" if hasattr(exc, "add_note"): - exc.add_note(debug_message) + cast(Any, exc).add_note(debug_message) elif exc.args: exc.args = (f"{exc.args[0]}\n{debug_message}", *exc.args[1:]) else: @@ -159,13 +160,24 @@ class TestAsyncTokenCredential(AsyncTokenCredential): async def get_token(self, *scopes: str, **kwargs: object): raise NotImplementedError + async def close(self) -> None: + pass + + async def __aexit__( + self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: TracebackType | None = None, + ) -> None: + pass + monkeypatch.setenv("OPENAI_API_KEY", "test-dummy-key") monkeypatch.setenv("OPENAI_MODEL", "gpt-5") credential = TestAsyncTokenCredential() token_provider = MagicMock() with patch("azure.identity.aio.get_bearer_token_provider", return_value=token_provider) as mock_provider: - client = OpenAIChatCompletionClient(credential=credential) + client = OpenAIChatCompletionClient(credential=cast(Any, credential)) assert isinstance(client.client, AsyncAzureOpenAI) mock_provider.assert_called_once_with(credential, "https://cognitiveservices.azure.com/.default") @@ -189,7 +201,7 @@ def test_openai_base_url_wins_over_azure_aliases(monkeypatch, azure_openai_unit_ @_with_azure_openai_debug() async def test_azure_openai_chat_completion_client_response() -> None: async with AzureCliCredential() as credential: - client = OpenAIChatCompletionClient(credential=credential) + client = OpenAIChatCompletionClient(credential=cast(Any, credential)) assert isinstance(client, SupportsChatGetResponse) messages = [ @@ -222,7 +234,7 @@ async def test_azure_openai_chat_completion_client_response() -> None: @_with_azure_openai_debug() async def test_azure_openai_chat_completion_client_response_tools() -> None: async with AzureCliCredential() as credential: - client = OpenAIChatCompletionClient(credential=credential) + client = OpenAIChatCompletionClient(credential=cast(Any, credential)) response = await client.get_response( messages=[Message(role="user", contents=["who are Emily and David?"])], @@ -240,7 +252,7 @@ async def test_azure_openai_chat_completion_client_response_tools() -> None: @_with_azure_openai_debug() async def test_azure_openai_chat_completion_client_streaming() -> None: async with AzureCliCredential() as credential: - client = OpenAIChatCompletionClient(credential=credential) + client = OpenAIChatCompletionClient(credential=cast(Any, credential)) response = client.get_response( messages=[ @@ -278,7 +290,7 @@ async def test_azure_openai_chat_completion_client_streaming() -> None: @_with_azure_openai_debug() async def test_azure_openai_chat_completion_client_streaming_tools() -> None: async with AzureCliCredential() as credential: - client = OpenAIChatCompletionClient(credential=credential) + client = OpenAIChatCompletionClient(credential=cast(Any, credential)) response = client.get_response( messages=[Message(role="user", contents=["who are Emily and David?"])], @@ -304,7 +316,7 @@ async def test_azure_openai_chat_completion_client_agent_basic_run() -> None: async with ( AzureCliCredential() as credential, Agent( - client=OpenAIChatCompletionClient(credential=credential), + client=OpenAIChatCompletionClient(credential=cast(Any, credential)), ) as agent, ): response = await agent.run("Please respond with exactly: 'This is a response test.'") @@ -321,7 +333,7 @@ async def test_azure_openai_chat_completion_client_agent_basic_run() -> None: async def test_azure_openai_chat_completion_client_agent_basic_run_streaming() -> None: async with ( AzureCliCredential() as credential, - Agent(client=OpenAIChatCompletionClient(credential=credential)) as agent, + Agent(client=OpenAIChatCompletionClient(credential=cast(Any, credential))) as agent, ): full_text = "" async for chunk in agent.run("Please respond with exactly: 'This is a streaming response test.'", stream=True): @@ -340,7 +352,7 @@ async def test_azure_openai_chat_completion_client_agent_session_persistence() - async with ( AzureCliCredential() as credential, Agent( - client=OpenAIChatCompletionClient(credential=credential), + client=OpenAIChatCompletionClient(credential=cast(Any, credential)), instructions="You are a helpful assistant with good memory.", ) as agent, ): @@ -363,7 +375,7 @@ async def test_azure_openai_chat_completion_client_agent_existing_session() -> N preserved_session = None async with Agent( - client=OpenAIChatCompletionClient(credential=credential), + client=OpenAIChatCompletionClient(credential=cast(Any, credential)), instructions="You are a helpful assistant with good memory.", ) as first_agent: session = first_agent.create_session() @@ -374,7 +386,7 @@ async def test_azure_openai_chat_completion_client_agent_existing_session() -> N if preserved_session: async with Agent( - client=OpenAIChatCompletionClient(credential=credential), + client=OpenAIChatCompletionClient(credential=cast(Any, credential)), instructions="You are a helpful assistant with good memory.", ) as second_agent: second_response = await second_agent.run("What is my name?", session=preserved_session) @@ -392,7 +404,7 @@ async def test_azure_chat_completion_client_agent_level_tool_persistence() -> No async with ( AzureCliCredential() as credential, Agent( - client=OpenAIChatCompletionClient(credential=credential), + client=OpenAIChatCompletionClient(credential=cast(Any, credential)), instructions="You are a helpful assistant that uses available tools.", tools=[get_weather], ) as agent, diff --git a/python/packages/openai/tests/openai/test_openai_chat_completion_client_base.py b/python/packages/openai/tests/openai/test_openai_chat_completion_client_base.py index 1f25762105a..8bce1622ea7 100644 --- a/python/packages/openai/tests/openai/test_openai_chat_completion_client_base.py +++ b/python/packages/openai/tests/openai/test_openai_chat_completion_client_base.py @@ -2,6 +2,7 @@ from copy import deepcopy from datetime import datetime, timezone +from typing import Any, cast from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -197,7 +198,10 @@ async def test_cmc_additional_properties( chat_history.append(Message(role="user", contents=["hello world"])) openai_chat_completion = OpenAIChatCompletionClient() - await openai_chat_completion.get_response(messages=chat_history, options={"reasoning_effort": "low"}) + await cast(Any, openai_chat_completion).get_response( + messages=chat_history, + options={"reasoning_effort": "low"}, + ) mock_create.assert_awaited_once_with( model=openai_unit_test_env["OPENAI_MODEL"], stream=False, diff --git a/python/packages/openai/tests/openai/test_openai_embedding_client.py b/python/packages/openai/tests/openai/test_openai_embedding_client.py index 1690c2b70e9..8347f9fd3f1 100644 --- a/python/packages/openai/tests/openai/test_openai_embedding_client.py +++ b/python/packages/openai/tests/openai/test_openai_embedding_client.py @@ -4,6 +4,7 @@ import inspect import os +from typing import Any, cast from unittest.mock import AsyncMock, MagicMock import pytest @@ -186,7 +187,7 @@ async def test_openai_base64_decoding(openai_unit_test_env: dict[str, str]) -> N async def test_openai_error_when_no_model() -> None: - client = OpenAIEmbeddingClient.__new__(OpenAIEmbeddingClient) + client = cast(Any, object.__new__(OpenAIEmbeddingClient)) client.model = None client.client = MagicMock() client.additional_properties = {} @@ -232,7 +233,9 @@ async def test_integration_openai_get_embeddings() -> None: assert all(isinstance(v, float) for v in result[0].vector) assert result[0].model is not None assert result.usage is not None - assert result.usage["input_token_count"] > 0 + input_token_count = result.usage["input_token_count"] + assert input_token_count is not None + assert input_token_count > 0 @skip_if_openai_integration_tests_disabled diff --git a/python/packages/openai/tests/openai/test_openai_embedding_client_azure.py b/python/packages/openai/tests/openai/test_openai_embedding_client_azure.py index 4e7a584874d..c6f10892948 100644 --- a/python/packages/openai/tests/openai/test_openai_embedding_client_azure.py +++ b/python/packages/openai/tests/openai/test_openai_embedding_client_azure.py @@ -4,7 +4,8 @@ import os from functools import wraps -from typing import Any +from types import TracebackType +from typing import Any, cast from unittest.mock import MagicMock, patch import pytest @@ -36,7 +37,7 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: endpoint = os.getenv("AZURE_OPENAI_ENDPOINT", "") debug_message = f"Azure OpenAI debug: endpoint={endpoint}, model={model}, api_version={api_version}" if hasattr(exc, "add_note"): - exc.add_note(debug_message) + cast(Any, exc).add_note(debug_message) elif exc.args: exc.args = (f"{exc.args[0]}\n{debug_message}", *exc.args[1:]) else: @@ -65,7 +66,7 @@ def _create_azure_embedding_client( api_key=resolved_api_key, azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"], api_version=os.getenv("AZURE_OPENAI_API_VERSION"), - credential=credential, + credential=cast(Any, credential), ) @@ -163,13 +164,24 @@ class TestAsyncTokenCredential(AsyncTokenCredential): async def get_token(self, *scopes: str, **kwargs: object): raise NotImplementedError + async def close(self) -> None: + pass + + async def __aexit__( + self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: TracebackType | None = None, + ) -> None: + pass + monkeypatch.setenv("OPENAI_API_KEY", "test-dummy-key") monkeypatch.setenv("OPENAI_EMBEDDING_MODEL", "text-embedding-3-small") credential = TestAsyncTokenCredential() token_provider = MagicMock() with patch("azure.identity.aio.get_bearer_token_provider", return_value=token_provider) as mock_provider: - client = OpenAIEmbeddingClient(credential=credential) + client = OpenAIEmbeddingClient(credential=cast(Any, credential)) assert isinstance(client.client, AsyncAzureOpenAI) assert client.model == azure_openai_unit_test_env["AZURE_OPENAI_EMBEDDING_MODEL"] @@ -274,7 +286,7 @@ def test_init_with_azure_endpoint_still_uses_azure_client(azure_openai_unit_test @_with_azure_openai_debug() async def test_azure_openai_get_embeddings() -> None: async with AzureCliCredential() as credential: - client = _create_azure_embedding_client(credential=credential) + client = _create_azure_embedding_client(credential=cast(Any, credential)) result = await client.get_embeddings(["hello world"]) @@ -284,7 +296,9 @@ async def test_azure_openai_get_embeddings() -> None: assert all(isinstance(v, float) for v in result[0].vector) assert result[0].model is not None assert result.usage is not None - assert result.usage["input_token_count"] > 0 + input_token_count = result.usage["input_token_count"] + assert input_token_count is not None + assert input_token_count > 0 @pytest.mark.flaky @@ -293,7 +307,7 @@ async def test_azure_openai_get_embeddings() -> None: @_with_azure_openai_debug() async def test_azure_openai_get_embeddings_multiple() -> None: async with AzureCliCredential() as credential: - client = _create_azure_embedding_client(credential=credential) + client = _create_azure_embedding_client(credential=cast(Any, credential)) result = await client.get_embeddings(["hello", "world", "test"]) @@ -308,7 +322,7 @@ async def test_azure_openai_get_embeddings_multiple() -> None: @_with_azure_openai_debug() async def test_azure_openai_get_embeddings_with_dimensions() -> None: async with AzureCliCredential() as credential: - client = _create_azure_embedding_client(credential=credential) + client = _create_azure_embedding_client(credential=cast(Any, credential)) options: OpenAIEmbeddingOptions = {"dimensions": 256} result = await client.get_embeddings(["hello world"], options=options) diff --git a/python/packages/openai/tests/openai/test_openai_shared.py b/python/packages/openai/tests/openai/test_openai_shared.py index 86d43bc43b5..c3e17e8ac8b 100644 --- a/python/packages/openai/tests/openai/test_openai_shared.py +++ b/python/packages/openai/tests/openai/test_openai_shared.py @@ -2,6 +2,8 @@ from __future__ import annotations +from types import TracebackType +from typing import Any, cast from unittest.mock import MagicMock, patch import pytest @@ -19,6 +21,17 @@ class _AsyncTokenCredentialStub(AsyncTokenCredential): async def get_token(self, *scopes: str, **kwargs: object): raise NotImplementedError + async def close(self) -> None: + pass + + async def __aexit__( + self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: TracebackType | None = None, + ) -> None: + pass + class _TokenCredentialStub(TokenCredential): def get_token(self, *scopes: str, **kwargs: object): @@ -55,7 +68,7 @@ def test_resolve_azure_callable_token_provider_passthrough() -> None: def test_resolve_azure_invalid_credential_raises() -> None: with pytest.raises(ValueError, match="credential"): - _resolve_azure_credential_to_token_provider(object()) # type: ignore[arg-type] + _resolve_azure_credential_to_token_provider(cast(Any, object())) async def test_ensure_async_token_provider_wraps_sync_provider() -> None: diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_base_group_chat_orchestrator.py b/python/packages/orchestrations/agent_framework_orchestrations/_base_group_chat_orchestrator.py index a4108b23f08..16061b726b6 100644 --- a/python/packages/orchestrations/agent_framework_orchestrations/_base_group_chat_orchestrator.py +++ b/python/packages/orchestrations/agent_framework_orchestrations/_base_group_chat_orchestrator.py @@ -22,9 +22,9 @@ from ._orchestration_request_info import AgentApprovalExecutor if sys.version_info >= (3, 12): - from typing import override # type: ignore # pragma: no cover + from typing import override # pragma: no cover else: - from typing_extensions import override # type: ignore # pragma: no cover + from typing_extensions import override # pragma: no cover logger = logging.getLogger(__name__) @@ -482,7 +482,7 @@ async def _send_request_to_participant( ) else: # Custom executors receive full context envelope - request = GroupChatRequestMessage(additional_instruction=additional_instruction, metadata=metadata) # type: ignore[assignment] + request = GroupChatRequestMessage(additional_instruction=additional_instruction, metadata=metadata) await ctx.send_message(request, target_id=target) await ctx.add_event( WorkflowEvent( diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_concurrent.py b/python/packages/orchestrations/agent_framework_orchestrations/_concurrent.py index 9db8878ef44..a1c299bf723 100644 --- a/python/packages/orchestrations/agent_framework_orchestrations/_concurrent.py +++ b/python/packages/orchestrations/agent_framework_orchestrations/_concurrent.py @@ -160,12 +160,12 @@ async def aggregate(self, results: list[AgentExecutorResponse], ctx: WorkflowCon # Call according to provided signature, always non-blocking for sync callbacks if self._param_count >= 2: if inspect.iscoroutinefunction(self._callback): - ret = await self._callback(results, ctx) # type: ignore[misc] + ret = await self._callback(results, ctx) else: ret = await asyncio.to_thread(self._callback, results, ctx) else: if inspect.iscoroutinefunction(self._callback): - ret = await self._callback(results) # type: ignore[misc] + ret = await self._callback(results) else: ret = await asyncio.to_thread(self._callback, results) diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_group_chat.py b/python/packages/orchestrations/agent_framework_orchestrations/_group_chat.py index 728f3e388cc..3e5a2a97758 100644 --- a/python/packages/orchestrations/agent_framework_orchestrations/_group_chat.py +++ b/python/packages/orchestrations/agent_framework_orchestrations/_group_chat.py @@ -61,9 +61,9 @@ ) if sys.version_info >= (3, 12): - from typing import override # type: ignore # pragma: no cover + from typing import override # pragma: no cover else: - from typing_extensions import override # type: ignore # pragma: no cover + from typing_extensions import override # pragma: no cover logger = logging.getLogger(__name__) diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py b/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py index 65da3b8709c..4e6eb571c1b 100644 --- a/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py +++ b/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py @@ -63,9 +63,9 @@ ) if sys.version_info >= (3, 12): - from typing import override # type: ignore # pragma: no cover + from typing import override # pragma: no cover else: - from typing_extensions import override # type: ignore # pragma: no cover + from typing_extensions import override # pragma: no cover logger = logging.getLogger(__name__) @@ -325,7 +325,7 @@ def _apply_auto_tools(self, agent: Agent, targets: Sequence[HandoffConfiguration new_tools.append(handoff_tool) if new_tools: - default_options["tools"] = existing_tools + new_tools # type: ignore[operator] + default_options["tools"] = existing_tools + new_tools else: default_options["tools"] = existing_tools diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_magentic.py b/python/packages/orchestrations/agent_framework_orchestrations/_magentic.py index f8cbf88fd73..8d45c21e1a4 100644 --- a/python/packages/orchestrations/agent_framework_orchestrations/_magentic.py +++ b/python/packages/orchestrations/agent_framework_orchestrations/_magentic.py @@ -48,9 +48,9 @@ ) if sys.version_info >= (3, 12): - from typing import override # type: ignore # pragma: no cover + from typing import override # pragma: no cover else: - from typing_extensions import override # type: ignore # pragma: no cover + from typing_extensions import override # pragma: no cover logger = logging.getLogger(__name__) @@ -69,10 +69,10 @@ def _message_to_payload(message: Message) -> Any: if hasattr(message, "to_dict") and callable(getattr(message, "to_dict", None)): with contextlib.suppress(Exception): - return message.to_dict() # type: ignore[attr-defined] + return message.to_dict() if hasattr(message, "to_json") and callable(getattr(message, "to_json", None)): with contextlib.suppress(Exception): - json_payload = message.to_json() # type: ignore[attr-defined] + json_payload = message.to_json() if isinstance(json_payload, str): with contextlib.suppress(Exception): return json.loads(json_payload) @@ -90,7 +90,7 @@ def _message_from_payload(payload: Any) -> Message: return Message.from_dict(payload) # type: ignore[attr-defined,no-any-return] if hasattr(Message, "from_json") and isinstance(payload, str): with contextlib.suppress(Exception): - return Message.from_json(payload) # type: ignore[attr-defined,no-any-return] + return Message.from_json(payload) if isinstance(payload, dict): with contextlib.suppress(Exception): return Message(**payload) # type: ignore[arg-type] @@ -457,7 +457,7 @@ def _coerce_model(model_cls: type[T], data: dict[str, Any]) -> T: # We check with hasattr() first, so this is safe if hasattr(model_cls, "from_dict") and callable(model_cls.from_dict): # type: ignore[attr-defined] return model_cls.from_dict(data) # type: ignore[attr-defined,return-value,no-any-return] - return model_cls(**data) # type: ignore[arg-type,call-arg] + return model_cls(**data) # endregion Utilities diff --git a/python/packages/orchestrations/tests/test_concurrent.py b/python/packages/orchestrations/tests/test_concurrent.py index e6160f39db4..d489a5c9b33 100644 --- a/python/packages/orchestrations/tests/test_concurrent.py +++ b/python/packages/orchestrations/tests/test_concurrent.py @@ -22,7 +22,6 @@ ) from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage from agent_framework.orchestrations import ConcurrentBuilder -from typing_extensions import Never class _FakeAgentExec(Executor): @@ -158,7 +157,7 @@ async def test_concurrent_with_aggregator_executor_instance() -> None: class CustomAggregator(Executor): @handler - async def aggregate(self, results: list[AgentExecutorResponse], ctx: WorkflowContext[Never, str]) -> None: + async def aggregate(self, results: list[AgentExecutorResponse], ctx: WorkflowContext[Any, str]) -> None: texts: list[str] = [] for r in results: msgs: list[Message] = r.agent_response.messages diff --git a/python/packages/orchestrations/tests/test_group_chat.py b/python/packages/orchestrations/tests/test_group_chat.py index 50f58e781a2..935111ccaad 100644 --- a/python/packages/orchestrations/tests/test_group_chat.py +++ b/python/packages/orchestrations/tests/test_group_chat.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. -from collections.abc import AsyncIterable, Awaitable, Callable, Sequence +from collections.abc import AsyncIterable, Callable, Sequence from typing import Any, cast import pytest @@ -21,7 +21,6 @@ from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage from agent_framework.orchestrations import ( AgentRequestInfoResponse, - BaseGroupChatOrchestrator, GroupChatBuilder, GroupChatState, MagenticContext, @@ -30,6 +29,8 @@ MagenticProgressLedgerItem, ) +from agent_framework_orchestrations import BaseGroupChatOrchestrator + class StubAgent(BaseAgent): def __init__(self, agent_name: str, reply_text: str, **kwargs: Any) -> None: @@ -43,12 +44,12 @@ def run( # type: ignore[override] stream: bool = False, session: AgentSession | None = None, **kwargs: Any, - ) -> Awaitable[AgentResponse] | AsyncIterable[AgentResponseUpdate]: + ) -> Any: if stream: return self._run_stream_impl() return self._run_impl() - async def _run_impl(self) -> AgentResponse: + async def _run_impl(self) -> AgentResponse[Any]: response = Message(role="assistant", contents=[self._reply_text], author_name=self.name) return AgentResponse(messages=[response]) @@ -71,21 +72,21 @@ async def get_response( class StubManagerAgent(Agent): def __init__(self) -> None: - super().__init__(client=MockChatClient(), name="manager_agent", description="Stub manager") + super().__init__(client=cast(Any, MockChatClient()), name="manager_agent", description="Stub manager") self._call_count = 0 - async def run( + async def run( # type: ignore[override] # ty: ignore[invalid-method-override] self, messages: str | Content | Message | Sequence[str | Content | Message] | None = None, *, session: AgentSession | None = None, **kwargs: Any, - ) -> AgentResponse: + ) -> AgentResponse[Any]: if self._call_count == 0: self._call_count += 1 # First call: select the agent (using AgentOrchestrationOutput format) payload = {"terminate": False, "reason": "Selecting agent", "next_speaker": "agent", "final_message": None} - return AgentResponse( + return AgentResponse[Any]( messages=[ Message( role="assistant", @@ -108,7 +109,7 @@ async def run( "next_speaker": None, "final_message": "agent manager final", } - return AgentResponse( + return AgentResponse[Any]( messages=[ Message( role="assistant", @@ -129,16 +130,18 @@ class ConcatenatedJsonManagerAgent(Agent): """Manager agent that emits concatenated JSON in a single assistant message.""" def __init__(self) -> None: - super().__init__(client=MockChatClient(), name="concat_manager", description="Concatenated JSON manager") + super().__init__( + client=cast(Any, MockChatClient()), name="concat_manager", description="Concatenated JSON manager" + ) self._call_count = 0 - async def run( + async def run( # type: ignore[override] # ty: ignore[invalid-method-override] self, messages: str | Content | Message | Sequence[str | Content | Message] | None = None, *, session: AgentSession | None = None, **kwargs: Any, - ) -> AgentResponse: + ) -> AgentResponse[Any]: if self._call_count == 0: self._call_count += 1 return AgentResponse( @@ -349,9 +352,7 @@ class AgentWithoutName(BaseAgent): def __init__(self) -> None: super().__init__(name="", description="test") - def run( - self, messages: Any = None, *, stream: bool = False, session: Any = None, **kwargs: Any - ) -> AgentResponse | AsyncIterable[AgentResponseUpdate]: + def run(self, messages: Any = None, *, stream: bool = False, session: Any = None, **kwargs: Any) -> Any: if stream: async def _stream() -> AsyncIterable[AgentResponseUpdate]: @@ -360,7 +361,7 @@ async def _stream() -> AsyncIterable[AgentResponseUpdate]: return _stream() return self._run_impl() - async def _run_impl(self) -> AgentResponse: + async def _run_impl(self) -> AgentResponse[Any]: return AgentResponse(messages=[]) agent = AgentWithoutName() @@ -936,16 +937,16 @@ class DynamicManagerAgent(Agent): """Manager agent that dynamically selects from available participants.""" def __init__(self) -> None: - super().__init__(client=MockChatClient(), name="dynamic_manager", description="Dynamic manager") + super().__init__(client=cast(Any, MockChatClient()), name="dynamic_manager", description="Dynamic manager") self._call_count = 0 - async def run( + async def run( # type: ignore[override] # ty: ignore[invalid-method-override] self, messages: str | Content | Message | Sequence[str | Content | Message] | None = None, *, session: AgentSession | None = None, **kwargs: Any, - ) -> AgentResponse: + ) -> AgentResponse[Any]: if self._call_count == 0: self._call_count += 1 payload = { @@ -954,7 +955,7 @@ async def run( "next_speaker": "alpha", "final_message": None, } - return AgentResponse( + return AgentResponse[Any]( messages=[ Message( role="assistant", @@ -976,7 +977,7 @@ async def run( "next_speaker": None, "final_message": "dynamic manager final", } - return AgentResponse( + return AgentResponse[Any]( messages=[ Message( role="assistant", diff --git a/python/packages/orchestrations/tests/test_handoff.py b/python/packages/orchestrations/tests/test_handoff.py index 33eed344066..21964ebddb4 100644 --- a/python/packages/orchestrations/tests/test_handoff.py +++ b/python/packages/orchestrations/tests/test_handoff.py @@ -11,6 +11,7 @@ Agent, AgentResponse, AgentResponseUpdate, + ChatOptions, ChatResponse, ChatResponseUpdate, Content, @@ -44,6 +45,14 @@ from agent_framework_orchestrations._orchestrator_helpers import clean_conversation_for_handoff +def _as_handoff_agent(agent: Any) -> Agent: + return cast(Agent, agent) + + +def _as_handoff_agents(*agents: Any) -> list[Agent]: + return [_as_handoff_agent(agent) for agent in agents] + + # region unit tests class MockChatClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]): """Mock chat client for testing handoff workflows.""" @@ -97,7 +106,7 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]: def _finalize(updates: Sequence[ChatResponseUpdate]) -> ChatResponse: response_format = options.get("response_format") - output_format_type = response_format if isinstance(response_format, type) else None + output_format_type: Any = response_format if isinstance(response_format, type) else None return ChatResponse.from_updates(updates, output_format_type=output_format_type) return ResponseStream(_stream(), finalizer=_finalize) @@ -229,10 +238,10 @@ async def test_handoff(): # between all agents. workflow = ( HandoffBuilder( - participants=[triage, specialist, escalation], + participants=_as_handoff_agents(triage, specialist, escalation), termination_condition=lambda conv: sum(1 for m in conv if m.role == "user") >= 2, ) - .with_start_agent(triage) + .with_start_agent(_as_handoff_agent(triage)) .build() ) @@ -274,8 +283,8 @@ async def test_resume_keeps_prior_user_context_for_same_agent() -> None: require_per_service_call_history_persistence=True, ) workflow = ( - HandoffBuilder(participants=[refund_agent], termination_condition=lambda _: False) - .with_start_agent(refund_agent) + HandoffBuilder(participants=_as_handoff_agents(refund_agent), termination_condition=lambda _: False) + .with_start_agent(_as_handoff_agent(refund_agent)) .build() ) @@ -372,7 +381,9 @@ async def _get() -> ChatResponse: require_per_service_call_history_persistence=True, ) workflow = ( - HandoffBuilder(participants=[agent], termination_condition=lambda _: False).with_start_agent(agent).build() + HandoffBuilder(participants=_as_handoff_agents(agent), termination_condition=lambda _: False) + .with_start_agent(_as_handoff_agent(agent)) + .build() ) first_events = await _drain(workflow.run("start", stream=True)) @@ -476,7 +487,9 @@ async def _get() -> ChatResponse: require_per_service_call_history_persistence=True, ) workflow = ( - HandoffBuilder(participants=[agent], termination_condition=lambda _: False).with_start_agent(agent).build() + HandoffBuilder(participants=_as_handoff_agents(agent), termination_condition=lambda _: False) + .with_start_agent(_as_handoff_agent(agent)) + .build() ) first_events = await _drain(workflow.run("start", stream=True)) @@ -553,8 +566,8 @@ async def _get() -> ChatResponse: ) workflow = ( - HandoffBuilder(participants=[triage, specialist], termination_condition=lambda _: False) - .with_start_agent(triage) + HandoffBuilder(participants=_as_handoff_agents(triage, specialist), termination_condition=lambda _: False) + .with_start_agent(_as_handoff_agent(triage)) .build() ) @@ -682,8 +695,10 @@ async def _get() -> ChatResponse: require_per_service_call_history_persistence=True, ) workflow = ( - HandoffBuilder(participants=[refund_agent, order_agent], termination_condition=lambda _: False) - .with_start_agent(refund_agent) + HandoffBuilder( + participants=_as_handoff_agents(refund_agent, order_agent), termination_condition=lambda _: False + ) + .with_start_agent(_as_handoff_agent(refund_agent)) .build() ) @@ -719,19 +734,20 @@ async def test_handoff_clone_preserves_per_service_call_history_persistence() -> context_providers=[triage_history], require_per_service_call_history_persistence=True, ) + specialist_options: ChatOptions = {"tool_choice": "none"} specialist = Agent( id="specialist", name="specialist", client=MockChatClient(name="specialist"), - default_options={"tool_choice": "none"}, + default_options=specialist_options, require_per_service_call_history_persistence=True, ) workflow = ( - HandoffBuilder(participants=[triage, specialist], termination_condition=lambda _: False) - .with_start_agent(triage) - .add_handoff(triage, [specialist]) - .add_handoff(specialist, [triage]) + HandoffBuilder(participants=_as_handoff_agents(triage, specialist), termination_condition=lambda _: False) + .with_start_agent(_as_handoff_agent(triage)) + .add_handoff(_as_handoff_agent(triage), _as_handoff_agents(specialist)) + .add_handoff(_as_handoff_agent(specialist), _as_handoff_agents(triage)) .build() ) @@ -739,7 +755,7 @@ async def test_handoff_clone_preserves_per_service_call_history_persistence() -> executor = workflow.executors[resolve_agent_id(triage)] assert isinstance(executor, HandoffAgentExecutor) - assert executor._agent.require_per_service_call_history_persistence is True + assert cast(Agent, executor._agent).require_per_service_call_history_persistence is True provider_state = executor._session.state[triage_history.source_id] stored_messages = await triage_history.get_messages( @@ -766,25 +782,26 @@ async def tracking_middleware(context: FunctionInvocationContext, call_next): middleware=[tracking_middleware], require_per_service_call_history_persistence=True, ) + agent_b_options: ChatOptions = {"tool_choice": "none"} agent_b = Agent( id="agent_b", name="agent_b", client=MockChatClient(name="agent_b"), - default_options={"tool_choice": "none"}, + default_options=agent_b_options, require_per_service_call_history_persistence=True, ) workflow = ( - HandoffBuilder(participants=[agent_a, agent_b], termination_condition=lambda _: False) - .with_start_agent(agent_a) - .add_handoff(agent_a, [agent_b]) - .add_handoff(agent_b, [agent_a]) + HandoffBuilder(participants=_as_handoff_agents(agent_a, agent_b), termination_condition=lambda _: False) + .with_start_agent(_as_handoff_agent(agent_a)) + .add_handoff(_as_handoff_agent(agent_a), _as_handoff_agents(agent_b)) + .add_handoff(_as_handoff_agent(agent_b), _as_handoff_agents(agent_a)) .build() ) executor = workflow.executors[resolve_agent_id(agent_a)] assert isinstance(executor, HandoffAgentExecutor) - cloned_middleware = executor._agent.middleware or [] + cloned_middleware = cast(Agent, executor._agent).middleware or [] assert tracking_middleware in cloned_middleware, "User function middleware should be preserved on cloned agent" @@ -833,7 +850,7 @@ async def test_autonomous_mode_yields_output_without_user_request(): workflow = ( HandoffBuilder( - participants=[triage, specialist], + participants=_as_handoff_agents(triage, specialist), # This termination condition ensures the workflow runs through both agents. # First message is the user message to triage, second is triage's response, which # is a handoff to specialist, third is specialist's response that should not request @@ -841,7 +858,7 @@ async def test_autonomous_mode_yields_output_without_user_request(): # again and will trigger termination. termination_condition=lambda conv: len(conv) >= 4, ) - .with_start_agent(triage) + .with_start_agent(_as_handoff_agent(triage)) # Since specialist has no handoff, the specialist will be generating normal responses. # With autonomous mode, this should continue until the termination condition is met. .with_autonomous_mode( @@ -875,8 +892,8 @@ async def test_autonomous_mode_resumes_user_input_on_turn_limit(): worker = MockHandoffAgent(name="worker") workflow = ( - HandoffBuilder(participants=[triage, worker], termination_condition=lambda conv: False) - .with_start_agent(triage) + HandoffBuilder(participants=_as_handoff_agents(triage, worker), termination_condition=lambda conv: False) + .with_start_agent(_as_handoff_agent(triage)) .with_autonomous_mode(agents=[worker], turn_limits={resolve_agent_id(worker): 2}) .build() ) @@ -893,7 +910,7 @@ def test_build_fails_without_start_agent(): specialist = MockHandoffAgent(name="specialist") with pytest.raises(ValueError, match=r"Must call with_start_agent\(...\) before building the workflow."): - HandoffBuilder(participants=[triage, specialist]).build() + HandoffBuilder(participants=_as_handoff_agents(triage, specialist)).build() def test_build_fails_without_participants(): @@ -916,8 +933,8 @@ async def async_termination(conv: list[Message]) -> bool: worker = MockHandoffAgent(name="worker") workflow = ( - HandoffBuilder(participants=[coordinator, worker], termination_condition=async_termination) - .with_start_agent(coordinator) + HandoffBuilder(participants=_as_handoff_agents(coordinator, worker), termination_condition=async_termination) + .with_start_agent(_as_handoff_agent(coordinator)) .build() ) @@ -977,12 +994,12 @@ async def _get() -> ChatResponse: ) workflow = ( HandoffBuilder( - participants=[agent], + participants=_as_handoff_agents(agent), termination_condition=lambda conv: any( message.role == "assistant" and "case complete." in (message.text or "").lower() for message in conv ), ) - .with_start_agent(agent) + .with_start_agent(_as_handoff_agent(agent)) .build() ) @@ -1064,7 +1081,7 @@ async def before_run(self, **kwargs: Any) -> None: assert context_provider in agent.context_providers, "Original agent should have context provider" # Build handoff workflow - this should clone the agent and preserve context_providers - workflow = HandoffBuilder(participants=[agent]).with_start_agent(agent).build() + workflow = HandoffBuilder(participants=_as_handoff_agents(agent)).with_start_agent(_as_handoff_agent(agent)).build() # Run workflow with a simple message to trigger context provider await _drain(workflow.run("Test message", stream=True)) @@ -1084,9 +1101,9 @@ def test_handoff_builder_accepts_all_instances_in_add_handoff(): # This should work - all instances with participants builder = ( - HandoffBuilder(participants=[triage, specialist_a, specialist_b]) - .with_start_agent(triage) - .add_handoff(triage, [specialist_a, specialist_b]) + HandoffBuilder(participants=_as_handoff_agents(triage, specialist_a, specialist_b)) + .with_start_agent(_as_handoff_agent(triage)) + .add_handoff(_as_handoff_agent(triage), _as_handoff_agents(specialist_a, specialist_b)) ) workflow = builder.build() @@ -1144,7 +1161,9 @@ def test_handoff_builder_rejects_agents_without_per_service_call_history_persist agent_with_flag = MockHandoffAgent(name="has_flag") # MockHandoffAgent sets flag to True with pytest.raises(ValueError, match="require_per_service_call_history_persistence"): - HandoffBuilder(participants=[agent_without_flag, agent_with_flag]).with_start_agent(agent_with_flag).build() + HandoffBuilder(participants=_as_handoff_agents(agent_without_flag, agent_with_flag)).with_start_agent( + _as_handoff_agent(agent_with_flag) + ).build() def test_handoff_builder_rejects_non_agent_supports_agent_run(): @@ -1167,10 +1186,10 @@ def get_session(self, *, service_session_id, **kwargs): return AgentSession(service_session_id=service_session_id) fake = FakeAgentRun("a", "A") - assert isinstance(fake, SupportsAgentRun) + assert isinstance(fake, SupportsAgentRun) # pyrefly: ignore[unsafe-overlap] with pytest.raises(TypeError, match="Participants must be Agent instances"): - HandoffBuilder().participants([fake]) + HandoffBuilder().participants([cast(Any, fake)]) # endregion @@ -1200,7 +1219,7 @@ async def test_simple_handoff_workflow(store: bool) -> None: client = FoundryChatClient( project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"], model=os.environ["FOUNDRY_MODEL"], - credential=AzureCliCredential(), + credential=cast(Any, AzureCliCredential()), ) triage_agent = Agent( @@ -1210,7 +1229,7 @@ async def test_simple_handoff_workflow(store: bool) -> None: "based on the problem described." ), name="triage_agent", - default_options={"store": store}, + default_options=cast(Any, {"store": store}), require_per_service_call_history_persistence=True, ) @@ -1218,19 +1237,19 @@ async def test_simple_handoff_workflow(store: bool) -> None: client=client, instructions="You process refund requests. Ask user the ID of the order they want refunded.", name="refund_agent", - default_options={"store": store}, + default_options=cast(Any, {"store": store}), require_per_service_call_history_persistence=True, ) workflow = ( HandoffBuilder( - participants=[triage_agent, refund_agent], + participants=_as_handoff_agents(triage_agent, refund_agent), termination_condition=lambda conversation: ( # We terminate after triage hands off to refund to test handoff works len(conversation) > 0 and conversation[-1].author_name == refund_agent.name ), ) - .with_start_agent(triage_agent) + .with_start_agent(_as_handoff_agent(triage_agent)) .build() ) @@ -1256,7 +1275,7 @@ async def test_simple_handoff_workflow_with_request_and_response(store: bool) -> client = FoundryChatClient( project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"], model=os.environ["FOUNDRY_MODEL"], - credential=AzureCliCredential(), + credential=cast(Any, AzureCliCredential()), ) triage_agent = Agent( @@ -1266,7 +1285,7 @@ async def test_simple_handoff_workflow_with_request_and_response(store: bool) -> "based on the problem described." ), name="triage_agent", - default_options={"store": store}, + default_options=cast(Any, {"store": store}), require_per_service_call_history_persistence=True, ) @@ -1274,13 +1293,13 @@ async def test_simple_handoff_workflow_with_request_and_response(store: bool) -> client=client, instructions="You process refund requests. Ask user the ID of the order they want refunded.", name="refund_agent", - default_options={"store": store}, + default_options=cast(Any, {"store": store}), require_per_service_call_history_persistence=True, ) workflow = ( HandoffBuilder( - participants=[triage_agent, refund_agent], + participants=_as_handoff_agents(triage_agent, refund_agent), termination_condition=lambda conversation: ( # We terminate after the refund agent request user input and the user provides # a response. There will be two user messages in the conversation at that point @@ -1289,7 +1308,7 @@ async def test_simple_handoff_workflow_with_request_and_response(store: bool) -> len([message for message in conversation if message.role == "user"]) == 2 ), ) - .with_start_agent(triage_agent) + .with_start_agent(_as_handoff_agent(triage_agent)) .build() ) @@ -1333,7 +1352,7 @@ async def test_simple_handoff_workflow_with_approval_request(store: bool) -> Non client = FoundryChatClient( project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"], model=os.environ["FOUNDRY_MODEL"], - credential=AzureCliCredential(), + credential=cast(Any, AzureCliCredential()), ) triage_agent = Agent( @@ -1343,7 +1362,7 @@ async def test_simple_handoff_workflow_with_approval_request(store: bool) -> Non "based on the problem described." ), name="triage_agent", - default_options={"store": store}, + default_options=cast(Any, {"store": store}), require_per_service_call_history_persistence=True, ) @@ -1351,7 +1370,7 @@ async def test_simple_handoff_workflow_with_approval_request(store: bool) -> Non client=client, instructions="You process refund requests. Ask user the ID of the order they want refunded.", name="refund_agent", - default_options={"store": store}, + default_options=cast(Any, {"store": store}), tools=[process_refund], require_per_service_call_history_persistence=True, ) @@ -1359,9 +1378,9 @@ async def test_simple_handoff_workflow_with_approval_request(store: bool) -> Non # This workflow will be terminated manually workflow = ( HandoffBuilder( - participants=[triage_agent, refund_agent], + participants=_as_handoff_agents(triage_agent, refund_agent), ) - .with_start_agent(triage_agent) + .with_start_agent(_as_handoff_agent(triage_agent)) .build() ) diff --git a/python/packages/orchestrations/tests/test_magentic.py b/python/packages/orchestrations/tests/test_magentic.py index 615ba998bcc..af411cacb8d 100644 --- a/python/packages/orchestrations/tests/test_magentic.py +++ b/python/packages/orchestrations/tests/test_magentic.py @@ -2,7 +2,7 @@ import logging import sys -from collections.abc import AsyncIterable, Awaitable, Sequence +from collections.abc import AsyncIterable, Sequence from dataclasses import dataclass from typing import Any, ClassVar, cast @@ -156,11 +156,11 @@ def run( # type: ignore[override] stream: bool = False, session: AgentSession | None = None, **kwargs: Any, - ) -> Awaitable[AgentResponse] | AsyncIterable[AgentResponseUpdate]: + ) -> Any: if stream: return self._run_stream() - async def _run() -> AgentResponse: + async def _run() -> AgentResponse[Any]: response = Message("assistant", [self._reply_text], author_name=self.name) return AgentResponse(messages=[response]) @@ -473,17 +473,17 @@ def run( stream: bool = False, session: Any = None, **kwargs: Any, - ) -> Awaitable[AgentResponse] | AsyncIterable[AgentResponseUpdate]: + ) -> Any: if stream: return self._run_stream() - async def _run() -> AgentResponse: + async def _run() -> AgentResponse[Any]: return AgentResponse(messages=[Message("assistant", ["ok"])]) return _run() async def _run_stream(self) -> AsyncIterable[AgentResponseUpdate]: - yield AgentResponseUpdate(message_deltas=[Message("assistant", ["ok"])]) + yield AgentResponseUpdate(contents=[Content.from_text(text="ok")]) async def test_standard_manager_plan_and_replan_via_complete_monkeypatch(): @@ -496,7 +496,7 @@ async def fake_complete_plan(messages: list[Message], **kwargs: Any) -> Message: return Message("assistant", ["GIVEN OR VERIFIED FACTS\n- fact1"]) # First, patch to produce facts then plan - mgr._complete = fake_complete_plan # type: ignore[attr-defined] + mgr._complete = fake_complete_plan # type: ignore[method-assign] # ty: ignore[invalid-assignment] ctx = MagenticContext(task="T", participant_descriptions={"A": "desc"}) combined = await mgr.plan(ctx.clone()) @@ -511,7 +511,7 @@ async def fake_complete_replan(messages: list[Message], **kwargs: Any) -> Messag return Message("assistant", ["- new step"]) return Message("assistant", ["GIVEN OR VERIFIED FACTS\n- updated"]) - mgr._complete = fake_complete_replan # type: ignore[attr-defined] + mgr._complete = fake_complete_replan # type: ignore[method-assign] # ty: ignore[invalid-assignment] combined2 = await mgr.replan(ctx.clone()) assert "updated" in combined2.text or "new step" in combined2.text @@ -531,7 +531,7 @@ async def fake_complete_ok(messages: list[Message], **kwargs: Any) -> Message: ) return Message("assistant", [json_text]) - mgr._complete = fake_complete_ok # type: ignore[attr-defined] + mgr._complete = fake_complete_ok # type: ignore[method-assign] # ty: ignore[invalid-assignment] ledger = await mgr.create_progress_ledger(ctx.clone()) assert ledger.next_speaker.answer == "alice" @@ -539,7 +539,7 @@ async def fake_complete_ok(messages: list[Message], **kwargs: Any) -> Message: async def fake_complete_bad(messages: list[Message], **kwargs: Any) -> Message: return Message("assistant", ["not-json"]) - mgr._complete = fake_complete_bad # type: ignore[attr-defined] + mgr._complete = fake_complete_bad # type: ignore[method-assign] # ty: ignore[invalid-assignment] with pytest.raises(RuntimeError): await mgr.create_progress_ledger(ctx.clone()) @@ -583,11 +583,11 @@ class StubThreadAgent(BaseAgent): def __init__(self, name: str | None = None) -> None: super().__init__(name=name or "agentA") - def run(self, messages=None, *, stream: bool = False, session=None, **kwargs): # type: ignore[override] + def run(self, messages=None, *, stream: bool = False, session=None, **kwargs) -> Any: # type: ignore[override] if stream: return self._run_stream() - async def _run(): + async def _run() -> AgentResponse[Any]: return AgentResponse(messages=[Message("assistant", ["thread-ok"], author_name=self.name)]) return _run() @@ -611,11 +611,11 @@ def __init__(self) -> None: super().__init__(name="agentA") self.client = StubAssistantsClient() # type name contains 'AssistantsClient' - def run(self, messages=None, *, stream: bool = False, session=None, **kwargs): # type: ignore[override] + def run(self, messages=None, *, stream: bool = False, session=None, **kwargs) -> Any: # type: ignore[override] if stream: return self._run_stream() - async def _run(): + async def _run() -> AgentResponse[Any]: return AgentResponse(messages=[Message("assistant", ["assistants-ok"], author_name=self.name)]) return _run() @@ -1193,10 +1193,10 @@ def run( stream: bool = False, session: Any = None, **kwargs: Any, - ) -> Awaitable[AgentResponse] | AsyncIterable[AgentResponseUpdate]: + ) -> Any: captured_sessions.append(session) - async def _run() -> AgentResponse: + async def _run() -> AgentResponse[Any]: return AgentResponse(messages=[Message("assistant", ["ok"])]) return _run() diff --git a/python/packages/orchestrations/tests/test_orchestration_intermediate_vs_terminal.py b/python/packages/orchestrations/tests/test_orchestration_intermediate_vs_terminal.py index 78086d9d412..ece85161958 100644 --- a/python/packages/orchestrations/tests/test_orchestration_intermediate_vs_terminal.py +++ b/python/packages/orchestrations/tests/test_orchestration_intermediate_vs_terminal.py @@ -14,10 +14,11 @@ from __future__ import annotations from collections.abc import AsyncIterable, Awaitable, Callable -from typing import Any, ClassVar, Literal, overload +from typing import Any, ClassVar, Literal, cast, overload import pytest from agent_framework import ( + Agent, AgentResponse, AgentResponseUpdate, AgentRunInputs, @@ -41,6 +42,14 @@ ) +def _as_handoff_agent(agent: Any) -> Agent: + return cast(Agent, agent) + + +def _as_handoff_agents(*agents: Any) -> list[Agent]: + return [_as_handoff_agent(agent) for agent in agents] + + class _EchoAgent(BaseAgent): """Minimal non-streaming agent that returns a single assistant message.""" @@ -98,7 +107,7 @@ async def test_sequential_default_only_terminator_is_output() -> None: b = _EchoAgent(name="B") c = _EchoAgent(name="C") - workflow = SequentialBuilder(participants=[a, b, c]).build() + workflow = SequentialBuilder(participants=_as_handoff_agents(a, b, c)).build() output_events: list[Any] = [] intermediate_events: list[Any] = [] @@ -122,7 +131,7 @@ async def test_sequential_output_from_designates_workflow_output_participants() b = _EchoAgent(name="B") c = _EchoAgent(name="C") - workflow = SequentialBuilder(participants=[a, b, c], output_from=["A", "B", "C"]).build() + workflow = SequentialBuilder(participants=_as_handoff_agents(a, b, c), output_from=["A", "B", "C"]).build() result = await workflow.run("hello") outputs = result.get_outputs() assert len(outputs) == 3 @@ -134,7 +143,7 @@ async def test_sequential_intermediate_output_from_surface_as_intermediate() -> b = _EchoAgent(name="B") c = _EchoAgent(name="C") - workflow = SequentialBuilder(participants=[a, b, c], intermediate_output_from=[a, "B"]).build() + workflow = SequentialBuilder(participants=_as_handoff_agents(a, b, c), intermediate_output_from=[a, "B"]).build() output_executors: set[str] = set() intermediate_executors: set[str] = set() @@ -162,7 +171,7 @@ async def test_sequential_intermediate_can_demote_default_terminator() -> None: b = _EchoAgent(name="B") c = _EchoAgent(name="C") - workflow = SequentialBuilder(participants=[a, b, c], intermediate_output_from=["C"]).build() + workflow = SequentialBuilder(participants=_as_handoff_agents(a, b, c), intermediate_output_from=["C"]).build() output_executors: set[str] = set() intermediate_executors: set[str] = set() @@ -184,7 +193,7 @@ async def test_sequential_get_outputs_returns_terminator_only() -> None: a = _EchoAgent(name="A") b = _EchoAgent(name="B") - workflow = SequentialBuilder(participants=[a, b]).build() + workflow = SequentialBuilder(participants=_as_handoff_agents(a, b)).build() result = await workflow.run("hi") outputs = result.get_outputs() assert len(outputs) == 1 @@ -201,7 +210,7 @@ async def test_concurrent_default_only_aggregator_is_output() -> None: a = _EchoAgent(name="A") b = _EchoAgent(name="B") - workflow = ConcurrentBuilder(participants=[a, b]).build() + workflow = ConcurrentBuilder(participants=_as_handoff_agents(a, b)).build() output_events: list[Any] = [] intermediate_events: list[Any] = [] @@ -223,7 +232,7 @@ async def test_concurrent_output_from_designates_workflow_output_participants() a = _EchoAgent(name="A") b = _EchoAgent(name="B") - workflow = ConcurrentBuilder(participants=[a, b], output_from=[a, "B"]).build() + workflow = ConcurrentBuilder(participants=_as_handoff_agents(a, b), output_from=[a, "B"]).build() result = await workflow.run("hello") outputs = result.get_outputs() assert len(outputs) == 3 @@ -234,7 +243,7 @@ async def test_concurrent_intermediate_output_from_surface_as_intermediate() -> a = _EchoAgent(name="A") b = _EchoAgent(name="B") - workflow = ConcurrentBuilder(participants=[a, b], intermediate_output_from=["A", b]).build() + workflow = ConcurrentBuilder(participants=_as_handoff_agents(a, b), intermediate_output_from=["A", b]).build() output_executors: set[str] = set() intermediate_executors: set[str] = set() @@ -260,7 +269,7 @@ async def test_sequential_default_as_agent_forwards_original_content_types() -> b = _EchoAgent(name="B") c = _EchoAgent(name="C") - workflow = SequentialBuilder(participants=[a, b, c]).build() + workflow = SequentialBuilder(participants=_as_handoff_agents(a, b, c)).build() agent = workflow.as_agent("seq") response = await agent.run("hi") @@ -268,7 +277,7 @@ async def test_sequential_default_as_agent_forwards_original_content_types() -> text_contents = [c for m in response.messages for c in m.contents if c.type == "text"] reasoning_contents = [c for m in response.messages for c in m.contents if c.type == "text_reasoning"] - assert any("C reply" in c.text for c in text_contents) + assert any("C reply" in (c.text or "") for c in text_contents) assert not reasoning_contents @@ -279,12 +288,12 @@ async def test_sequential_as_agent_output_from_all_text() -> None: b = _EchoAgent(name="B") c = _EchoAgent(name="C") - workflow = SequentialBuilder(participants=[a, b, c], output_from=["A", "B", "C"]).build() + workflow = SequentialBuilder(participants=_as_handoff_agents(a, b, c), output_from=["A", "B", "C"]).build() agent = workflow.as_agent("seq") response = await agent.run("hi") text_contents = [c for m in response.messages for c in m.contents if c.type == "text"] - text = " ".join(c.text for c in text_contents) + text = " ".join(c.text or "" for c in text_contents) assert "A reply" in text assert "B reply" in text assert "C reply" in text @@ -297,16 +306,16 @@ async def test_sequential_as_agent_intermediate_output_from_keeps_text_content() b = _EchoAgent(name="B") c = _EchoAgent(name="C") - workflow = SequentialBuilder(participants=[a, b, c], intermediate_output_from=["A", "B"]).build() + workflow = SequentialBuilder(participants=_as_handoff_agents(a, b, c), intermediate_output_from=["A", "B"]).build() agent = workflow.as_agent("seq") response = await agent.run("hi") text_contents = [c for m in response.messages for c in m.contents if c.type == "text"] reasoning_contents = [c for m in response.messages for c in m.contents if c.type == "text_reasoning"] - assert any("C reply" in c.text for c in text_contents) - assert any("A reply" in c.text for c in text_contents) - assert any("B reply" in c.text for c in text_contents) + assert any("C reply" in (c.text or "") for c in text_contents) + assert any("A reply" in (c.text or "") for c in text_contents) + assert any("B reply" in (c.text or "") for c in text_contents) assert not reasoning_contents @@ -321,7 +330,7 @@ async def test_concurrent_default_as_agent_participants_keep_text_content() -> N a = _EchoAgent(name="A") b = _EchoAgent(name="B") - workflow = ConcurrentBuilder(participants=[a, b]).build() + workflow = ConcurrentBuilder(participants=_as_handoff_agents(a, b)).build() agent = workflow.as_agent("concurrent") response = await agent.run("hi") @@ -329,8 +338,8 @@ async def test_concurrent_default_as_agent_participants_keep_text_content() -> N text_contents = [c for m in response.messages for c in m.contents if c.type == "text"] reasoning_contents = [c for m in response.messages for c in m.contents if c.type == "text_reasoning"] - assert not any("A reply" in c.text for c in reasoning_contents) - assert not any("B reply" in c.text for c in reasoning_contents) + assert not any("A reply" in (c.text or "") for c in reasoning_contents) + assert not any("B reply" in (c.text or "") for c in reasoning_contents) # The aggregator's default-yielded AgentResponse passes through as text content. assert text_contents, "expected at least one terminal text content from the aggregator" @@ -365,7 +374,7 @@ async def test_group_chat_default_only_orchestrator_is_output() -> None: beta = _EchoAgent(name="beta") workflow = GroupChatBuilder( - participants=[alpha, beta], + participants=_as_handoff_agents(alpha, beta), max_rounds=2, selection_func=_two_step_selector(), ).build() @@ -393,7 +402,7 @@ async def test_group_chat_output_from_designates_workflow_output_participants() beta = _EchoAgent(name="beta") workflow = GroupChatBuilder( - participants=[alpha, beta], + participants=_as_handoff_agents(alpha, beta), max_rounds=2, selection_func=_two_step_selector(), output_from=[alpha, "beta"], @@ -413,7 +422,7 @@ async def test_group_chat_intermediate_output_from_surface_as_intermediate() -> beta = _EchoAgent(name="beta") workflow = GroupChatBuilder( - participants=[alpha, beta], + participants=_as_handoff_agents(alpha, beta), max_rounds=2, selection_func=_two_step_selector(), intermediate_output_from=["alpha", beta], @@ -471,7 +480,9 @@ def _inner_get_response(self, **kwargs: Any) -> Any: # pragma: no cover - never require_per_service_call_history_persistence=True, ) - workflow = HandoffBuilder(participants=[alpha, beta]).with_start_agent(alpha).build() + workflow = ( + HandoffBuilder(participants=_as_handoff_agents(alpha, beta)).with_start_agent(_as_handoff_agent(alpha)).build() + ) designated = {ex.id for ex in workflow.get_output_executors()} assert "alpha" in designated, f"alpha must be designated; got {designated}" @@ -506,7 +517,11 @@ def _inner_get_response(self, **kwargs: Any) -> Any: # pragma: no cover - never require_per_service_call_history_persistence=True, ) - workflow = HandoffBuilder(participants=[alpha, beta], output_from=["alpha"]).with_start_agent(alpha).build() + workflow = ( + HandoffBuilder(participants=_as_handoff_agents(alpha, beta), output_from=["alpha"]) + .with_start_agent(_as_handoff_agent(alpha)) + .build() + ) assert {ex.id for ex in workflow.get_output_executors()} == {"alpha"} @@ -539,7 +554,9 @@ def _inner_get_response(self, **kwargs: Any) -> Any: # pragma: no cover - never beta = Agent(name="beta", id="beta", client=_StubClient(), require_per_service_call_history_persistence=True) workflow = ( - HandoffBuilder(participants=[alpha, beta], intermediate_output_from=["alpha"]).with_start_agent(alpha).build() + HandoffBuilder(participants=_as_handoff_agents(alpha, beta), intermediate_output_from=["alpha"]) + .with_start_agent(_as_handoff_agent(alpha)) + .build() ) # alpha is implicitly removed from the default-final set; beta remains final. @@ -592,7 +609,7 @@ def test_magentic_builder_default_only_manager_designated() -> None: manager = _StubMagenticManager() alpha = _EchoAgent(name="alpha") - workflow = MagenticBuilder(participants=[alpha], manager=manager).build() + workflow = MagenticBuilder(participants=_as_handoff_agents(alpha), manager=manager).build() designated = {ex.id for ex in workflow.get_output_executors()} assert "magentic_orchestrator" in designated, f"manager must be designated; got {designated}" @@ -604,7 +621,7 @@ def test_magentic_builder_output_from_designates_workflow_output_participants() manager = _StubMagenticManager() alpha = _EchoAgent(name="alpha") - workflow = MagenticBuilder(participants=[alpha], manager=manager, output_from=["alpha"]).build() + workflow = MagenticBuilder(participants=_as_handoff_agents(alpha), manager=manager, output_from=["alpha"]).build() designated = {ex.id for ex in workflow.get_output_executors()} assert {"magentic_orchestrator", "alpha"}.issubset(designated) @@ -614,7 +631,9 @@ def test_magentic_builder_intermediate_output_from_designates_intermediate_worke manager = _StubMagenticManager() alpha = _EchoAgent(name="alpha") - workflow = MagenticBuilder(participants=[alpha], manager=manager, intermediate_output_from=[alpha]).build() + workflow = MagenticBuilder( + participants=_as_handoff_agents(alpha), manager=manager, intermediate_output_from=[alpha] + ).build() assert {ex.id for ex in workflow.get_output_executors()} == {"magentic_orchestrator"} assert {ex.id for ex in workflow.get_intermediate_executors()} == {"alpha"} @@ -625,7 +644,7 @@ def test_sequential_output_from_all_selects_all_participants() -> None: b = _EchoAgent(name="B") c = _EchoAgent(name="C") - workflow = SequentialBuilder(participants=[a, b, c], output_from="all").build() + workflow = SequentialBuilder(participants=_as_handoff_agents(a, b, c), output_from="all").build() assert {ex.id for ex in workflow.get_output_executors()} == {"A", "B", "C"} @@ -636,7 +655,7 @@ def test_sequential_intermediate_output_from_all_other_selects_non_outputs() -> c = _EchoAgent(name="C") workflow = SequentialBuilder( - participants=[a, b, c], output_from=["C"], intermediate_output_from="all_other" + participants=_as_handoff_agents(a, b, c), output_from=["C"], intermediate_output_from="all_other" ).build() assert {ex.id for ex in workflow.get_output_executors()} == {"C"} @@ -647,7 +666,7 @@ def test_sequential_all_other_with_omitted_output_from_selects_all_intermediate( a = _EchoAgent(name="A") b = _EchoAgent(name="B") - workflow = SequentialBuilder(participants=[a, b], intermediate_output_from="all_other").build() + workflow = SequentialBuilder(participants=_as_handoff_agents(a, b), intermediate_output_from="all_other").build() assert {ex.id for ex in workflow.get_output_executors()} == set() assert {ex.id for ex in workflow.get_intermediate_executors()} == {"A", "B"} @@ -659,16 +678,20 @@ def test_sequential_all_other_with_omitted_output_from_selects_all_intermediate( def _build_sequential_with_designation(**kwargs: Any) -> None: - SequentialBuilder(participants=[_EchoAgent(name="alpha"), _EchoAgent(name="beta")], **kwargs).build() + SequentialBuilder( + participants=_as_handoff_agents(_EchoAgent(name="alpha"), _EchoAgent(name="beta")), **kwargs + ).build() def _build_concurrent_with_designation(**kwargs: Any) -> None: - ConcurrentBuilder(participants=[_EchoAgent(name="alpha"), _EchoAgent(name="beta")], **kwargs).build() + ConcurrentBuilder( + participants=_as_handoff_agents(_EchoAgent(name="alpha"), _EchoAgent(name="beta")), **kwargs + ).build() def _build_group_chat_with_designation(**kwargs: Any) -> None: GroupChatBuilder( - participants=[_EchoAgent(name="alpha"), _EchoAgent(name="beta")], + participants=_as_handoff_agents(_EchoAgent(name="alpha"), _EchoAgent(name="beta")), max_rounds=1, selection_func=_two_step_selector(), **kwargs, @@ -676,7 +699,9 @@ def _build_group_chat_with_designation(**kwargs: Any) -> None: def _build_magentic_with_designation(**kwargs: Any) -> None: - MagenticBuilder(participants=[_EchoAgent(name="alpha")], manager=_StubMagenticManager(), **kwargs).build() + MagenticBuilder( + participants=_as_handoff_agents(_EchoAgent(name="alpha")), manager=_StubMagenticManager(), **kwargs + ).build() def _build_handoff_with_designation(**kwargs: Any) -> None: @@ -706,7 +731,9 @@ def _inner_get_response(self, **kwargs: Any) -> Any: # pragma: no cover - never client=_StubClient(), require_per_service_call_history_persistence=True, ) - HandoffBuilder(participants=[alpha, beta], **kwargs).with_start_agent(alpha).build() + HandoffBuilder(participants=_as_handoff_agents(alpha, beta), **kwargs).with_start_agent( + _as_handoff_agent(alpha) + ).build() @pytest.mark.parametrize( diff --git a/python/packages/orchestrations/tests/test_orchestration_request_info.py b/python/packages/orchestrations/tests/test_orchestration_request_info.py index b7b266073b5..1ab40cd78bc 100644 --- a/python/packages/orchestrations/tests/test_orchestration_request_info.py +++ b/python/packages/orchestrations/tests/test_orchestration_request_info.py @@ -3,7 +3,7 @@ """Unit tests for orchestration request info support.""" from collections.abc import AsyncIterable -from typing import Any +from typing import Any, cast from unittest.mock import AsyncMock, MagicMock import pytest @@ -11,6 +11,7 @@ AgentResponse, AgentResponseUpdate, AgentSession, + Content, Message, SupportsAgentRun, ) @@ -114,7 +115,7 @@ async def test_request_info_handler(self): executor = AgentRequestInfoExecutor(id="test_executor") agent_response = AgentResponse(messages=[Message(role="assistant", contents=["Agent response"])]) - agent_response = AgentExecutorResponse( + executor_response = AgentExecutorResponse( executor_id="test_agent", agent_response=agent_response, full_conversation=agent_response.messages, @@ -123,9 +124,9 @@ async def test_request_info_handler(self): ctx = MagicMock(spec=WorkflowContext) ctx.request_info = AsyncMock() - await executor.request_info(agent_response, ctx) + await executor.request_info(executor_response, ctx) - ctx.request_info.assert_called_once_with(agent_response, AgentRequestInfoResponse) + ctx.request_info.assert_called_once_with(executor_response, AgentRequestInfoResponse) @pytest.mark.asyncio async def test_handle_request_info_response_with_messages(self): @@ -206,21 +207,25 @@ async def run( messages: str | Message | list[str] | list[Message] | None = None, *, stream: bool = False, - thread: AgentSession | None = None, + session: AgentSession | None = None, **kwargs: Any, - ) -> AgentResponse | AsyncIterable[AgentResponseUpdate]: + ) -> Any: """Dummy run method.""" if stream: return self._run_stream_impl() return AgentResponse(messages=[Message(role="assistant", contents=["Test response"])]) async def _run_stream_impl(self) -> AsyncIterable[AgentResponseUpdate]: - yield AgentResponseUpdate(messages=[Message(role="assistant", contents=["Test response stream"])]) + yield AgentResponseUpdate(contents=[Content.from_text(text="Test response stream")]) def create_session(self, **kwargs: Any) -> AgentSession: """Creates a new conversation session for the agent.""" return AgentSession(**kwargs) + def get_session(self, service_session_id: str, *, session_id: str | None = None) -> AgentSession: + """Gets a conversation session for the agent.""" + return AgentSession(service_session_id=service_session_id, session_id=session_id) + class TestAgentApprovalExecutor: """Tests for AgentApprovalExecutor.""" @@ -229,7 +234,7 @@ def test_initialization(self): """Test that AgentApprovalExecutor initializes correctly.""" agent = _TestAgent(id="test_id", name="test_agent", description="Test agent description") - executor = AgentApprovalExecutor(agent) + executor = AgentApprovalExecutor(cast(SupportsAgentRun, agent)) assert executor.id == "test_agent" assert executor.description == "Test agent description" @@ -238,7 +243,7 @@ def test_builds_workflow_with_agent_and_request_info_executors(self): """Test that the internal workflow is created successfully.""" agent = _TestAgent(id="test_id", name="test_agent", description="Test description") - executor = AgentApprovalExecutor(agent) + executor = AgentApprovalExecutor(cast(SupportsAgentRun, agent)) # Verify the executor has a workflow assert executor.workflow is not None @@ -248,6 +253,6 @@ def test_propagate_request_enabled(self): """Test that AgentApprovalExecutor has propagate_request enabled.""" agent = _TestAgent(id="test_id", name="test_agent", description="Test description") - executor = AgentApprovalExecutor(agent) + executor = AgentApprovalExecutor(cast(SupportsAgentRun, agent)) assert executor._propagate_request is True # type: ignore diff --git a/python/packages/orchestrations/tests/test_sequential.py b/python/packages/orchestrations/tests/test_sequential.py index d119a201202..f119fd2c1c0 100644 --- a/python/packages/orchestrations/tests/test_sequential.py +++ b/python/packages/orchestrations/tests/test_sequential.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. -from collections.abc import AsyncIterable, Awaitable +from collections.abc import AsyncIterable, Awaitable, Sequence from typing import Any, Literal, overload import pytest @@ -22,7 +22,6 @@ ) from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage from agent_framework.orchestrations import SequentialBuilder -from typing_extensions import Never class _EchoAgent(BaseAgent): @@ -75,7 +74,7 @@ class _SummarizerTerminator(Executor): async def summarize( self, agent_response: AgentExecutorResponse, - ctx: WorkflowContext[Never, AgentResponse], + ctx: WorkflowContext[Any, AgentResponse], ) -> None: conversation = agent_response.full_conversation or [] user_texts = [m.text for m in conversation if m.role == "user"] @@ -356,7 +355,8 @@ def run( ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: captured: list[Message] = [] if messages: - for m in messages: # type: ignore[union-attr] + message_items = messages if isinstance(messages, Sequence) and not isinstance(messages, str) else [messages] + for m in message_items: if isinstance(m, Message): captured.append(m) elif isinstance(m, str): @@ -459,10 +459,10 @@ async def test_sequential_request_info_last_participant_emits_output() -> None: request_events.append(ev) # Approve each agent in sequence until the workflow completes + output_events: list[Any] = [] while request_events: responses = {req.request_id: AgentRequestInfoResponse.approve() for req in request_events} request_events = [] - output_events: list[Any] = [] async for ev in wf.run(stream=True, responses=responses): if ev.type == "request_info" and isinstance(ev.data, AgentExecutorResponse): request_events.append(ev) diff --git a/python/packages/purview/agent_framework_purview/_client.py b/python/packages/purview/agent_framework_purview/_client.py index 43c8adce4e5..9208d02f557 100644 --- a/python/packages/purview/agent_framework_purview/_client.py +++ b/python/packages/purview/agent_framework_purview/_client.py @@ -72,9 +72,9 @@ async def _get_token(self, *, tenant_id: str | None = None) -> str: # Callable token provider — returns a token string directly if callable(cred) and not isinstance(cred, (TokenCredential, AsyncTokenCredential)): result = cred() - return await result if inspect.isawaitable(result) else result # type: ignore[return-value] + return await result if inspect.isawaitable(result) else result scopes = get_purview_scopes(self._settings) - token = cred.get_token(*scopes, tenant_id=tenant_id) # type: ignore[union-attr] + token = cred.get_token(*scopes, tenant_id=tenant_id) token = await token if inspect.isawaitable(token) else token return token.token @@ -203,7 +203,7 @@ async def _post( raise PurviewAuthenticationError(f"Auth failure {resp.status_code}: {resp.text}") if resp.status_code == 402: if self._settings.get("ignore_payment_required", False): - return response_type() # type: ignore[call-arg] + return response_type() raise PurviewPaymentRequiredError(f"Payment required {resp.status_code}: {resp.text}") if resp.status_code == 429: raise PurviewRateLimitError(f"Rate limited {resp.status_code}: {resp.text}") @@ -217,7 +217,7 @@ async def _post( try: # Prefer pydantic-style model_validate if present, else fall back to constructor. model_validate = getattr(response_type, "model_validate", None) - response_obj = model_validate(data) if callable(model_validate) else response_type(**data) # type: ignore[call-arg] + response_obj = model_validate(data) if callable(model_validate) else response_type(**data) # Extract correlation_id from response headers if response object supports it if "client-request-id" in resp.headers and hasattr(response_obj, "correlation_id"): diff --git a/python/packages/purview/agent_framework_purview/_middleware.py b/python/packages/purview/agent_framework_purview/_middleware.py index cf98294f333..b605f4c4b2e 100644 --- a/python/packages/purview/agent_framework_purview/_middleware.py +++ b/python/packages/purview/agent_framework_purview/_middleware.py @@ -70,7 +70,7 @@ async def process( self, context: AgentContext, call_next: Callable[[], Awaitable[None]], - ) -> None: # type: ignore[override] + ) -> None: resolved_user_id: str | None = None session_id: str | None = None try: @@ -182,7 +182,7 @@ async def process( self, context: ChatContext, call_next: Callable[[], Awaitable[None]], - ) -> None: # type: ignore[override] + ) -> None: resolved_user_id: str | None = None session_id: str | None = None try: diff --git a/python/packages/purview/agent_framework_purview/_models.py b/python/packages/purview/agent_framework_purview/_models.py index 503871deef7..40df9d78cfe 100644 --- a/python/packages/purview/agent_framework_purview/_models.py +++ b/python/packages/purview/agent_framework_purview/_models.py @@ -122,7 +122,7 @@ def deserialize_flag( if flag_value == enum_cls(0): none_member = mapping.get("none") if none_member is not None: - return none_member # type: ignore[return-value,index] + return none_member return flag_value @@ -250,13 +250,13 @@ def model_dump_json(self, *, by_alias: bool = True, exclude_none: bool = True, * return json.dumps(self.model_dump(by_alias=by_alias, exclude_none=exclude_none, **kwargs)) @classmethod - def model_validate(cls: type[AliasSerializableT], value: MutableMapping[str, Any]) -> AliasSerializableT: # type: ignore[name-defined] + def model_validate(cls: type[AliasSerializableT], value: MutableMapping[str, Any]) -> AliasSerializableT: return cls(**value) # ------------------------------------------------------------------ # Override to handle alias emission # ------------------------------------------------------------------ - def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]: # type: ignore[override] + def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]: base = SerializationMixin.to_dict(self, exclude=exclude, exclude_none=exclude_none) # For Graph API models, remove the auto-generated 'type' field if it's in DEFAULT_EXCLUDE @@ -390,7 +390,7 @@ def __init__( super().__init__(**kwargs) self.name = name self.version = version - self.application_location = application_location # type: ignore[assignment] + self.application_location = application_location class DlpActionInfo(_AliasSerializable): @@ -507,7 +507,7 @@ def __init__(self, data: bytes, data_type: str = "microsoft.graph.binaryContent" super().__init__(data_type=data_type, **kwargs) self.data = data - def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]: # type: ignore[override] + def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]: import base64 base = super().to_dict(exclude=exclude, exclude_none=exclude_none) @@ -567,9 +567,9 @@ def __init__( # determine by type? fall back to text content c_type = content.get("@odata.type") or content.get("data_type") if c_type and "binary" in str(c_type): - content = PurviewBinaryContent(**content) # type: ignore[arg-type] + content = PurviewBinaryContent(**content) else: - content = PurviewTextContent(**content) # type: ignore[arg-type] + content = PurviewTextContent(**content) accessed_list: list[AccessedResourceDetails] | None = None if accessed_resources: accessed_list = [ @@ -586,7 +586,7 @@ def __init__( # Call parent without explicit params with aliases super().__init__(data_type=data_type, **kwargs) self.identifier = identifier - self.content = content # type: ignore[assignment] + self.content = content self.name = name self.correlation_id = correlation_id self.sequence_number = sequence_number @@ -647,10 +647,10 @@ def __init__( # Call parent without explicit params with aliases super().__init__(**kwargs) self.content_entries = entries - self.activity_metadata = activity_metadata # type: ignore[assignment] - self.device_metadata = device_metadata # type: ignore[assignment] - self.integrated_app_metadata = integrated_app_metadata # type: ignore[assignment] - self.protected_app_metadata = protected_app_metadata # type: ignore[assignment] + self.activity_metadata = activity_metadata + self.device_metadata = device_metadata + self.integrated_app_metadata = integrated_app_metadata + self.protected_app_metadata = protected_app_metadata # -------------------------------------------------------------------------------------- @@ -684,7 +684,7 @@ def __init__( # Call parent without explicit params with aliases super().__init__(**kwargs) - self.content_to_process = content_to_process # type: ignore[assignment] + self.content_to_process = content_to_process self.user_id = user_id self.tenant_id = tenant_id self.correlation_id = correlation_id @@ -737,7 +737,7 @@ def __init__( super().__init__(**kwargs) self.user_id = user_id self.tenant_id = tenant_id - self.activities = activities # type: ignore[assignment] + self.activities = activities self.locations = locations self.pivot_on = pivot_on self.device_metadata = device_metadata @@ -745,7 +745,7 @@ def __init__( self.correlation_id = correlation_id self.scope_identifier = scope_identifier - def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]: # type: ignore[override] + def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]: # Get base dict (activities will be missing because Flag isn't JSON-serializable) base = super().to_dict(exclude=exclude, exclude_none=exclude_none) @@ -793,7 +793,7 @@ def __init__( super().__init__(**kwargs) self.id = id or str(uuid4()) self.user_id = user_id - self.content_to_process = content_to_process # type: ignore[assignment] + self.content_to_process = content_to_process self.tenant_id = tenant_id self.scope_identifier = scope_identifier self.correlation_id = correlation_id @@ -913,12 +913,12 @@ def __init__( # Call parent without explicit params with aliases super().__init__(**kwargs) - self.activities = activities # type: ignore[assignment] + self.activities = activities self.locations = converted_locations self.policy_actions = converted_policy_actions self.execution_mode = execution_mode - def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]: # type: ignore[override] + def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]: # Get base dict (activities will be missing because Flag isn't JSON-serializable) base = super().to_dict(exclude=exclude, exclude_none=exclude_none) @@ -991,7 +991,7 @@ def __init__( error = ErrorDetails(**error) super().__init__(status_code=status_code, error=error, correlation_id=correlation_id, **kwargs) self.status_code = status_code - self.error = error # type: ignore[assignment] + self.error = error self.correlation_id = correlation_id diff --git a/python/packages/purview/agent_framework_purview/_settings.py b/python/packages/purview/agent_framework_purview/_settings.py index 9581d041fbc..dae0d2f53bc 100644 --- a/python/packages/purview/agent_framework_purview/_settings.py +++ b/python/packages/purview/agent_framework_purview/_settings.py @@ -8,7 +8,7 @@ if sys.version_info >= (3, 11): from typing import TypedDict # pragma: no cover else: - from typing_extensions import TypedDict # type: ignore # pragma: no cover + from typing_extensions import TypedDict # pragma: no cover class PurviewLocationType(str, Enum): diff --git a/python/packages/purview/tests/purview/test_chat_middleware.py b/python/packages/purview/tests/purview/test_chat_middleware.py index 4134b2f0f2b..3c6056d3443 100644 --- a/python/packages/purview/tests/purview/test_chat_middleware.py +++ b/python/packages/purview/tests/purview/test_chat_middleware.py @@ -2,10 +2,11 @@ """Tests for Purview chat middleware.""" from dataclasses import dataclass +from typing import Any, cast from unittest.mock import AsyncMock, MagicMock, patch import pytest -from agent_framework import ChatContext, Message, MiddlewareTermination +from agent_framework import ChatContext, ChatResponse, Message, MiddlewareTermination from azure.core.credentials import AccessToken from agent_framework_purview import PurviewChatPolicyMiddleware, PurviewSettings @@ -37,7 +38,9 @@ def chat_context(self) -> ChatContext: client = DummyChatClient() chat_options = MagicMock() chat_options.model = "test-model" - return ChatContext(client=client, messages=[Message(role="user", contents=["Hello"])], options=chat_options) + return ChatContext( + client=cast(Any, client), messages=[Message(role="user", contents=["Hello"])], options=chat_options + ) async def test_initialization(self, middleware: PurviewChatPolicyMiddleware) -> None: assert middleware._client is not None @@ -53,16 +56,13 @@ async def mock_next() -> None: nonlocal next_called next_called = True - class Result: - def __init__(self): - self.messages = [Message(role="assistant", contents=["Hi there"])] - - chat_context.result = Result() + chat_context.result = ChatResponse(messages=[Message(role="assistant", contents=["Hi there"])]) await middleware.process(chat_context, mock_next) assert next_called assert mock_proc.call_count == 2 - assert chat_context.result.messages[0].role == "assistant" + result = cast(ChatResponse[Any], chat_context.result) + assert result.messages[0].role == "assistant" async def test_blocks_prompt(self, middleware: PurviewChatPolicyMiddleware, chat_context: ChatContext) -> None: with patch.object(middleware._processor, "process_messages", return_value=(True, "user-123")): @@ -74,7 +74,8 @@ async def mock_next() -> None: # should not run await middleware.process(chat_context, mock_next) assert chat_context.result assert hasattr(chat_context.result, "messages") - msg = chat_context.result.messages[0] + result = cast(ChatResponse[Any], chat_context.result) + msg = result.messages[0] assert msg.role in ("system", "system") assert "blocked" in msg.text.lower() @@ -89,15 +90,12 @@ async def side_effect(messages, activity, session_id=None, user_id=None): with patch.object(middleware._processor, "process_messages", side_effect=side_effect): async def mock_next() -> None: - class Result: - def __init__(self): - self.messages = [Message(role="assistant", contents=["Sensitive output"])] # pragma: no cover - - chat_context.result = Result() + chat_context.result = ChatResponse(messages=[Message(role="assistant", contents=["Sensitive output"])]) await middleware.process(chat_context, mock_next) assert call_state["count"] == 2 - msgs = getattr(chat_context.result, "messages", None) or chat_context.result + result = cast(ChatResponse[Any], chat_context.result) + msgs = result.messages first_msg = msgs[0] assert first_msg.role in ("system", "system") assert "blocked" in first_msg.text.lower() @@ -107,7 +105,7 @@ async def test_streaming_skips_post_check(self, middleware: PurviewChatPolicyMid chat_options = MagicMock() chat_options.model = "test-model" streaming_context = ChatContext( - client=client, + client=cast(Any, client), messages=[Message(role="user", contents=["Hello"])], options=chat_options, stream=True, @@ -154,7 +152,7 @@ async def test_chat_middleware_uses_consistent_user_id( self, middleware: PurviewChatPolicyMiddleware, chat_context: ChatContext ) -> None: """Test that the same user_id from pre-check is used in post-check.""" - captured_user_ids = [] + captured_user_ids: list[str | None] = [] async def mock_process_messages(messages, activity, session_id=None, user_id=None): captured_user_ids.append(user_id) @@ -187,7 +185,9 @@ async def test_chat_middleware_handles_payment_required_pre_check(self, mock_cre client = DummyChatClient() chat_options = MagicMock() chat_options.model = "test-model" - context = ChatContext(client=client, messages=[Message(role="user", contents=["Hello"])], options=chat_options) + context = ChatContext( + client=cast(Any, client), messages=[Message(role="user", contents=["Hello"])], options=chat_options + ) async def mock_process_messages(*args, **kwargs): raise PurviewPaymentRequiredError("Payment required") @@ -211,7 +211,9 @@ async def test_chat_middleware_handles_payment_required_post_check(self, mock_cr client = DummyChatClient() chat_options = MagicMock() chat_options.model = "test-model" - context = ChatContext(client=client, messages=[Message(role="user", contents=["Hello"])], options=chat_options) + context = ChatContext( + client=cast(Any, client), messages=[Message(role="user", contents=["Hello"])], options=chat_options + ) call_count = 0 @@ -242,7 +244,9 @@ async def test_chat_middleware_ignores_payment_required_when_configured(self, mo client = DummyChatClient() chat_options = MagicMock() chat_options.model = "test-model" - context = ChatContext(client=client, messages=[Message(role="user", contents=["Hello"])], options=chat_options) + context = ChatContext( + client=cast(Any, client), messages=[Message(role="user", contents=["Hello"])], options=chat_options + ) async def mock_process_messages(*args, **kwargs): raise PurviewPaymentRequiredError("Payment required") @@ -267,7 +271,7 @@ async def test_chat_middleware_handles_result_without_messages_attribute( async def mock_next() -> None: # Set result to something without messages attribute - chat_context.result = "Some string result" + chat_context.result = cast(Any, "Some string result") await middleware.process(chat_context, mock_next) @@ -282,7 +286,9 @@ async def test_chat_middleware_with_ignore_exceptions(self, mock_credential: Asy client = DummyChatClient() chat_options = MagicMock() chat_options.model = "test-model" - context = ChatContext(client=client, messages=[Message(role="user", contents=["Hello"])], options=chat_options) + context = ChatContext( + client=cast(Any, client), messages=[Message(role="user", contents=["Hello"])], options=chat_options + ) async def mock_process_messages(*args, **kwargs): raise ValueError("Some error") @@ -309,7 +315,9 @@ async def test_chat_middleware_raises_on_pre_check_exception_when_ignore_excepti client = DummyChatClient() chat_options = MagicMock() chat_options.model = "test-model" - context = ChatContext(client=client, messages=[Message(role="user", contents=["Hello"])], options=chat_options) + context = ChatContext( + client=cast(Any, client), messages=[Message(role="user", contents=["Hello"])], options=chat_options + ) with patch.object(middleware._processor, "process_messages", side_effect=ValueError("boom")): @@ -329,7 +337,9 @@ async def test_chat_middleware_raises_on_post_check_exception_when_ignore_except client = DummyChatClient() chat_options = MagicMock() chat_options.model = "test-model" - context = ChatContext(client=client, messages=[Message(role="user", contents=["Hello"])], options=chat_options) + context = ChatContext( + client=cast(Any, client), messages=[Message(role="user", contents=["Hello"])], options=chat_options + ) call_count = 0 @@ -357,7 +367,7 @@ async def test_chat_middleware_uses_conversation_id_from_options( chat_client = DummyChatClient() messages = [Message(role="user", contents=["Hello"])] options = {"conversation_id": "conv-123", "model": "test-model"} - context = ChatContext(client=chat_client, messages=messages, options=options) + context = ChatContext(client=cast(Any, chat_client), messages=messages, options=options) with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc: @@ -378,7 +388,7 @@ async def test_chat_middleware_passes_none_session_id_when_options_missing( """Test that session_id is None when options don't contain conversation_id.""" chat_client = DummyChatClient() messages = [Message(role="user", contents=["Hello"])] - context = ChatContext(client=chat_client, messages=messages, options=None) + context = ChatContext(client=cast(Any, chat_client), messages=messages, options=None) with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc: @@ -397,7 +407,7 @@ async def test_chat_middleware_session_id_used_in_post_check(self, middleware: P chat_client = DummyChatClient() messages = [Message(role="user", contents=["Hello"])] options = {"conversation_id": "conv-999"} - context = ChatContext(client=chat_client, messages=messages, options=options) + context = ChatContext(client=cast(Any, chat_client), messages=messages, options=options) with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc: diff --git a/python/packages/purview/tests/purview/test_middleware.py b/python/packages/purview/tests/purview/test_middleware.py index 274e4878b15..5b624367696 100644 --- a/python/packages/purview/tests/purview/test_middleware.py +++ b/python/packages/purview/tests/purview/test_middleware.py @@ -2,6 +2,7 @@ """Tests for Purview middleware.""" +from typing import Any, cast from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -83,9 +84,10 @@ async def mock_next() -> None: assert not next_called assert context.result is not None - assert len(context.result.messages) == 1 - assert context.result.messages[0].role == "system" - assert "blocked by policy" in context.result.messages[0].text.lower() + result = cast(AgentResponse[Any], context.result) + assert len(result.messages) == 1 + assert result.messages[0].role == "system" + assert "blocked by policy" in result.messages[0].text.lower() async def test_middleware_checks_response(self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock) -> None: """Test middleware checks agent response for policy violations.""" @@ -110,9 +112,10 @@ async def mock_next() -> None: assert call_count == 2 assert context.result is not None - assert len(context.result.messages) == 1 - assert context.result.messages[0].role == "system" - assert "blocked by policy" in context.result.messages[0].text.lower() + result = cast(AgentResponse[Any], context.result) + assert len(result.messages) == 1 + assert result.messages[0].role == "system" + assert "blocked by policy" in result.messages[0].text.lower() async def test_middleware_handles_result_without_messages( self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock @@ -126,7 +129,7 @@ async def test_middleware_handles_result_without_messages( with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")): async def mock_next() -> None: - context.result = "Some non-standard result" + context.result = cast(Any, "Some non-standard result") await middleware.process(context, mock_next) diff --git a/python/packages/purview/tests/purview/test_processor.py b/python/packages/purview/tests/purview/test_processor.py index 0cc9d7a8a99..45ba3c5a277 100644 --- a/python/packages/purview/tests/purview/test_processor.py +++ b/python/packages/purview/tests/purview/test_processor.py @@ -3,6 +3,7 @@ """Tests for Purview processor.""" import asyncio +from typing import Any, cast from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -14,6 +15,7 @@ DlpAction, DlpActionInfo, ProcessContentResponse, + ProtectionScopeState, RestrictionAction, ) from agent_framework_purview._processor import ScopedContentProcessor, _is_valid_guid @@ -103,9 +105,11 @@ async def test_process_messages_blocks_content( mock_request = process_content_request_factory("Sensitive content") - mock_response = ProcessContentResponse(**{ - "policyActions": [DlpActionInfo(action=DlpAction.BLOCK_ACCESS, restrictionAction=RestrictionAction.BLOCK)] - }) + mock_response = ProcessContentResponse( + policy_actions=cast( + Any, [DlpActionInfo(action=DlpAction.BLOCK_ACCESS, restrictionAction=RestrictionAction.BLOCK)] + ) + ) with ( patch.object(processor, "_map_messages", return_value=([mock_request], "user-123")), @@ -169,7 +173,7 @@ async def test_check_applicable_scopes_no_scopes( from agent_framework_purview._models import ProtectionScopesResponse request = process_content_request_factory() - response = ProtectionScopesResponse(**{"value": None}) + response = ProtectionScopesResponse(scopes=None) should_process, actions, execution_mode = processor._check_applicable_scopes(request, response) @@ -194,12 +198,17 @@ async def test_check_applicable_scopes_with_block_action( "@odata.type": "microsoft.graph.policyLocationApplication", "value": "app-id", }) - scope = PolicyScope(**{ - "policyActions": [block_action], - "activities": ProtectionScopeActivities.UPLOAD_TEXT, - "locations": [scope_location], - }) - response = ProtectionScopesResponse(**{"value": [scope]}) + scope = PolicyScope( + **cast( + Any, + { + "policyActions": [block_action], + "activities": ProtectionScopeActivities.UPLOAD_TEXT, + "locations": [scope_location], + }, + ) + ) + response = ProtectionScopesResponse(scopes=cast(Any, [scope])) should_process, actions, execution_mode = processor._check_applicable_scopes(request, response) @@ -257,9 +266,11 @@ async def test_process_with_scopes_calls_client_methods( request = process_content_request_factory() - mock_client.get_protection_scopes = AsyncMock(return_value=ProtectionScopesResponse(**{"value": []})) + mock_client.get_protection_scopes = AsyncMock(return_value=ProtectionScopesResponse(scopes=[])) mock_client.process_content = AsyncMock( - return_value=ProcessContentResponse(**{"id": "response-123", "protectionScopeState": "notModified"}) + return_value=ProcessContentResponse( + id="response-123", protection_scope_state=ProtectionScopeState.NOT_MODIFIED + ) ) mock_client.send_content_activities = AsyncMock(return_value=ContentActivitiesResponse(**{"error": None})) @@ -350,13 +361,29 @@ async def test_process_with_scopes_ignores_unexpected_cached_value_type( request = process_content_request_factory() mock_client.get_protection_scopes = AsyncMock(return_value=ProtectionScopesResponse(**{"value": []})) + # Return a valid, inline scope so we stay on the normal (non-background) path. + scope_location = PolicyLocation(**{ + "@odata.type": "microsoft.graph.policyLocationApplication", + "value": "app-id", + }) + scope = PolicyScope( + **cast( + Any, + { + "activities": ProtectionScopeActivities.UPLOAD_TEXT, + "locations": [scope_location], + "execution_mode": ExecutionMode.EVALUATE_INLINE, + }, + ) + ) + mock_client.get_protection_scopes = AsyncMock(return_value=ProtectionScopesResponse(scopes=cast(Any, [scope]))) mock_client.process_content = AsyncMock( - return_value=ProcessContentResponse(**{"id": "ok", "protectionScopeState": "notModified"}) + return_value=ProcessContentResponse(id="ok", protection_scope_state=ProtectionScopeState.NOT_MODIFIED) ) # First cache read is the tenant payment key (None). Second is the scopes cache (corrupt value). - processor._cache.get = AsyncMock(side_effect=[None, "corrupt-value"]) # type: ignore[method-assign] - processor._cache.set = AsyncMock() # type: ignore[method-assign] + cast(Any, processor._cache).get = AsyncMock(side_effect=[None, "corrupt-value"]) + cast(Any, processor._cache).set = AsyncMock() response = await processor._process_with_scopes(request) @@ -373,7 +400,7 @@ async def test_process_with_scopes_uses_tenant_payment_exception_cache( request = process_content_request_factory() - processor._cache.get = AsyncMock(return_value=PurviewPaymentRequiredError("Payment required")) # type: ignore[method-assign] + cast(Any, processor._cache).get = AsyncMock(return_value=PurviewPaymentRequiredError("Payment required")) with pytest.raises(PurviewPaymentRequiredError): await processor._process_with_scopes(request) @@ -389,15 +416,15 @@ async def test_process_content_background_retries_on_modified_state( mock_client.process_content = AsyncMock( side_effect=[ - ProcessContentResponse(**{"id": "r1", "protectionScopeState": "modified"}), - ProcessContentResponse(**{"id": "r2", "protectionScopeState": "notModified"}), + ProcessContentResponse(id="r1", protection_scope_state=ProtectionScopeState.MODIFIED), + ProcessContentResponse(id="r2", protection_scope_state=ProtectionScopeState.NOT_MODIFIED), ] ) - processor._cache.remove = AsyncMock() # type: ignore[method-assign] + cast(Any, processor._cache).remove = AsyncMock() await processor._process_content_background(request, cache_key="purview:protection_scopes:abc") - processor._cache.remove.assert_called_once_with("purview:protection_scopes:abc") + cast(Any, processor._cache).remove.assert_called_once_with("purview:protection_scopes:abc") assert mock_client.process_content.call_count == 2 async def test_background_scope_refresh_caches_payment_required( @@ -819,4 +846,5 @@ async def test_custom_cache_provider_used(self, mock_client: AsyncMock, settings processor = ScopedContentProcessor(mock_client, settings, cache_provider=custom_cache) assert processor._cache is custom_cache + assert isinstance(processor._cache, InMemoryCacheProvider) assert processor._cache._default_ttl == 60 diff --git a/python/packages/purview/tests/purview/test_purview_client.py b/python/packages/purview/tests/purview/test_purview_client.py index a61724d3f0d..133194f1824 100644 --- a/python/packages/purview/tests/purview/test_purview_client.py +++ b/python/packages/purview/tests/purview/test_purview_client.py @@ -3,6 +3,7 @@ """Tests for Purview client.""" from collections.abc import AsyncGenerator +from typing import Any, cast from unittest.mock import AsyncMock, MagicMock, patch import httpx @@ -199,7 +200,7 @@ async def test_get_protection_scopes_uses_etag_header_when_present(self, client: user_id="user-123", tenant_id="tenant-456", locations=[location], correlation_id="corr-789" ) - response_obj = ProtectionScopesResponse(**{"scopeIdentifier": "scope-from-body", "value": []}) + response_obj = ProtectionScopesResponse(scope_identifier="scope-from-body", scopes=[]) with patch.object( client, @@ -219,7 +220,7 @@ async def test_post_402_returns_empty_response_when_ignore_payment_required_enab settings = PurviewSettings(app_name="Test App", ignore_payment_required=True) client = PurviewClient(mock_credential, settings) - request = ProcessContentRequest(user_id="user-123", tenant_id="tenant-456", content_to_process=[]) + request = ProcessContentRequest(user_id="user-123", tenant_id="tenant-456", content_to_process=cast(Any, [])) resp = httpx.Response(402, text="Payment required", request=httpx.Request("POST", "http://test")) @@ -234,7 +235,7 @@ async def test_post_sets_request_and_response_correlation_id(self, client: Purvi from agent_framework_purview._models import ProcessContentResponse # correlation_id is optional and should be auto-generated when empty - request = ProcessContentRequest(user_id="user-123", tenant_id="tenant-456", content_to_process=[]) + request = ProcessContentRequest(user_id="user-123", tenant_id="tenant-456", content_to_process=cast(Any, [])) request.correlation_id = "" # force auto-generation branch captured_headers: dict[str, str] = {} @@ -270,7 +271,7 @@ async def test_process_content_402_returns_empty_when_ignored(self, mock_credent settings = PurviewSettings(app_name="Test App", ignore_payment_required=True) client = PurviewClient(mock_credential, settings) - req = ProcessContentRequest(user_id="user-123", tenant_id="tenant-456", content_to_process=[]) + req = ProcessContentRequest(user_id="user-123", tenant_id="tenant-456", content_to_process=cast(Any, [])) mock_response = MagicMock(spec=httpx.Response) mock_response.status_code = 402 @@ -286,7 +287,7 @@ async def test_post_sets_correlation_id_attribute_on_recording_span(self, client """Test that correlation_id is added to the active span when recording is enabled.""" from agent_framework_purview._models import ProcessContentResponse - request = ProcessContentRequest(user_id="user-123", tenant_id="tenant-456", content_to_process=[]) + request = ProcessContentRequest(user_id="user-123", tenant_id="tenant-456", content_to_process=cast(Any, [])) request.correlation_id = "corr-123" class RecordingSpan: @@ -325,7 +326,7 @@ class DummyResponse: def __init__(self, **data): self.data = data - request = ProcessContentRequest(user_id="user-123", tenant_id="tenant-456", content_to_process=[]) + request = ProcessContentRequest(user_id="user-123", tenant_id="tenant-456", content_to_process=cast(Any, [])) request.correlation_id = "corr-123" with patch.object( @@ -364,7 +365,7 @@ async def test_send_content_activities_success(self, client: PurviewClient, cont async def test_post_handles_invalid_json_response_body(self, client: PurviewClient) -> None: """Test that invalid JSON bodies fall back to an empty dict.""" - request = ProcessContentRequest(user_id="user-123", tenant_id="tenant-456", content_to_process=[]) + request = ProcessContentRequest(user_id="user-123", tenant_id="tenant-456", content_to_process=cast(Any, [])) request.correlation_id = "corr-123" mock_response = MagicMock(spec=httpx.Response) @@ -385,7 +386,7 @@ class BadResponseType: def model_validate(cls, value): raise RuntimeError("boom") - request = ProcessContentRequest(user_id="user-123", tenant_id="tenant-456", content_to_process=[]) + request = ProcessContentRequest(user_id="user-123", tenant_id="tenant-456", content_to_process=cast(Any, [])) request.correlation_id = "corr-123" mock_response = MagicMock(spec=httpx.Response) @@ -417,7 +418,7 @@ async def test_rate_limit_error(self, client: PurviewClient) -> None: request = ProcessContentRequest( user_id="test-user", tenant_id="test-tenant", - content_to_process=[], + content_to_process=cast(Any, []), correlation_id="test-correlation-id", ) @@ -437,7 +438,7 @@ async def test_generic_request_error(self, client: PurviewClient) -> None: request = ProcessContentRequest( user_id="test-user", tenant_id="test-tenant", - content_to_process=[], + content_to_process=cast(Any, []), correlation_id="test-correlation-id", ) @@ -466,7 +467,7 @@ async def test_prefer_header_sent_when_process_inline_true( process_inline=True, ) - posted_headers = {} + posted_headers: dict[str, str] = {} mock_response = MagicMock(spec=httpx.Response) mock_response.status_code = 200 mock_response.headers = {} @@ -494,7 +495,7 @@ async def test_prefer_header_not_sent_when_process_inline_false( process_inline=False, ) - posted_headers = {} + posted_headers: dict[str, str] = {} mock_response = MagicMock(spec=httpx.Response) mock_response.status_code = 200 mock_response.headers = {} @@ -521,7 +522,7 @@ async def test_prefer_header_not_sent_when_process_inline_none( process_inline=None, ) - posted_headers = {} + posted_headers: dict[str, str] = {} mock_response = MagicMock(spec=httpx.Response) mock_response.status_code = 200 mock_response.headers = {} @@ -561,7 +562,7 @@ async def test_scope_identifier_sent_as_if_none_match_header( scope_identifier="test-scope-id", ) - posted_headers = {} + posted_headers: dict[str, str] = {} mock_response = MagicMock(spec=httpx.Response) mock_response.status_code = 200 mock_response.headers = {} diff --git a/python/packages/purview/tests/purview/test_purview_models.py b/python/packages/purview/tests/purview/test_purview_models.py index d0c968df5cb..4af581014d9 100644 --- a/python/packages/purview/tests/purview/test_purview_models.py +++ b/python/packages/purview/tests/purview/test_purview_models.py @@ -100,7 +100,9 @@ def test_content_to_process_with_nested_structures(self) -> None: assert len(content.content_entries) == 1 assert content.activity_metadata.activity == Activity.UPLOAD_TEXT - assert content.device_metadata.operating_system_specifications.operating_system_platform == "Windows" + os_specs = content.device_metadata.operating_system_specifications + assert os_specs is not None + assert os_specs.operating_system_platform == "Windows" assert content.integrated_app_metadata.name == "App" assert content.protected_app_metadata.name == "Protected" @@ -162,6 +164,7 @@ def test_process_content_response_deserialization(self) -> None: assert response.id == "response-123" assert response.protection_scope_state == "blocked" + assert response.policy_actions is not None assert len(response.policy_actions) == 1 def test_content_serialization_uses_aliases(self) -> None: diff --git a/python/packages/purview/tests/purview/test_settings.py b/python/packages/purview/tests/purview/test_settings.py index 42e03a2be31..bc0a39d2f55 100644 --- a/python/packages/purview/tests/purview/test_settings.py +++ b/python/packages/purview/tests/purview/test_settings.py @@ -32,6 +32,7 @@ def test_settings_with_custom_values(self) -> None: assert settings["graph_base_uri"] == "https://graph.microsoft-ppe.com" assert settings["tenant_id"] == "test-tenant-id" + assert settings["purview_app_location"] is not None assert settings["purview_app_location"].location_value == "app-123" @pytest.mark.parametrize( diff --git a/python/packages/redis/agent_framework_redis/_context_provider.py b/python/packages/redis/agent_framework_redis/_context_provider.py index 4c102f01876..2b5cf5b30b9 100644 --- a/python/packages/redis/agent_framework_redis/_context_provider.py +++ b/python/packages/redis/agent_framework_redis/_context_provider.py @@ -33,9 +33,9 @@ from typing_extensions import Self # pragma: no cover if sys.version_info >= (3, 12): - from typing import override # type: ignore # pragma: no cover + from typing import override # pragma: no cover else: - from typing_extensions import override # type: ignore[import] # pragma: no cover + from typing_extensions import override # pragma: no cover if TYPE_CHECKING: from agent_framework._agents import SupportsAgentRun @@ -379,12 +379,12 @@ async def _redis_search( text_scorer=text_scorer, filter_expression=combined_filter, alpha=alpha, - dtype=self.redis_vectorizer.dtype, # pyright: ignore[reportUnknownMemberType] + dtype=self.redis_vectorizer.dtype, num_results=num_results, return_fields=return_fields, stopwords=None, ) - return await self.redis_index.query(query) # type: ignore[no-any-return] + return await self.redis_index.query(query) query = TextQuery( text=q, text_field_name="content", @@ -394,7 +394,7 @@ async def _redis_search( return_fields=return_fields, stopwords=None, ) - return await self.redis_index.query(query) # type: ignore[no-any-return] + return await self.redis_index.query(query) except Exception as exc: # pragma: no cover raise IntegrationInvalidRequestException(f"Redis text search failed: {exc}") from exc diff --git a/python/packages/redis/agent_framework_redis/_history_provider.py b/python/packages/redis/agent_framework_redis/_history_provider.py index dbdc358a93c..ef22511faa8 100644 --- a/python/packages/redis/agent_framework_redis/_history_provider.py +++ b/python/packages/redis/agent_framework_redis/_history_provider.py @@ -176,7 +176,7 @@ def _deserialize_json(data: str) -> dict[str, Any]: """Deserialize a JSON string from Redis to a dict.""" import json - return json.loads(data) # type: ignore[no-any-return] + return json.loads(data) async def clear(self, session_id: str | None) -> None: """Clear all messages for a session. @@ -188,7 +188,7 @@ async def clear(self, session_id: str | None) -> None: async def aclose(self) -> None: """Close the Redis connection.""" - await self._redis_client.aclose() # type: ignore[misc] + await self._redis_client.aclose() __all__ = ["RedisHistoryProvider"] diff --git a/python/packages/redis/tests/test_providers.py b/python/packages/redis/tests/test_providers.py index 54587a55e19..74ed328b404 100644 --- a/python/packages/redis/tests/test_providers.py +++ b/python/packages/redis/tests/test_providers.py @@ -5,6 +5,7 @@ from __future__ import annotations import json +from typing import Any, cast from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -101,7 +102,7 @@ def test_invalid_vectorizer_raises(self, patch_index_from_dict: MagicMock): # n from agent_framework.exceptions import AgentException with pytest.raises(AgentException, match="not a valid type"): - RedisContextProvider(source_id="ctx", user_id="u1", redis_vectorizer="bad") # type: ignore[arg-type] + RedisContextProvider(source_id="ctx", user_id="u1", redis_vectorizer="bad") # type: ignore[arg-type] # ty: ignore[invalid-argument-type] class TestRedisContextProviderValidateFilters: @@ -112,7 +113,7 @@ def test_no_filters_raises(self, patch_index_from_dict: MagicMock): # noqa: ARG def test_any_single_filter_ok(self, patch_index_from_dict: MagicMock): # noqa: ARG002 for kwargs in [{"user_id": "u"}, {"agent_id": "a"}, {"application_id": "app"}]: - provider = RedisContextProvider(source_id="ctx", **kwargs) + provider = RedisContextProvider(source_id="ctx", **cast(Any, kwargs)) provider._validate_filters() # should not raise @@ -144,7 +145,10 @@ async def test_search_results_added_to_context( ctx = SessionContext(input_messages=[Message(role="user", contents=["test query"])], session_id="s1") await provider.before_run( - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), + session=session, + context=ctx, + state=session.state.setdefault(provider.source_id, {}), ) # type: ignore[arg-type] assert "ctx" in ctx.context_messages @@ -163,7 +167,10 @@ async def test_empty_input_no_search( ctx = SessionContext(input_messages=[Message(role="user", contents=[" "])], session_id="s1") await provider.before_run( - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), + session=session, + context=ctx, + state=session.state.setdefault(provider.source_id, {}), ) # type: ignore[arg-type] mock_index.query.assert_not_called() @@ -182,7 +189,10 @@ async def test_before_run_searches_without_session_id( with patch.object(provider, "_redis_search", wraps=provider._redis_search) as spy: await provider.before_run( - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), + session=session, + context=ctx, + state=session.state.setdefault(provider.source_id, {}), ) # type: ignore[arg-type] spy.assert_called_once() @@ -200,7 +210,10 @@ async def test_empty_results_no_messages( ctx = SessionContext(input_messages=[Message(role="user", contents=["hello"])], session_id="s1") await provider.before_run( - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), + session=session, + context=ctx, + state=session.state.setdefault(provider.source_id, {}), ) # type: ignore[arg-type] assert "ctx" not in ctx.context_messages @@ -219,7 +232,10 @@ async def test_stores_messages( ctx._response = response await provider.after_run( - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), + session=session, + context=ctx, + state=session.state.setdefault(provider.source_id, {}), ) # type: ignore[arg-type] mock_index.load.assert_called_once() @@ -238,7 +254,10 @@ async def test_skips_empty_conversations( ctx = SessionContext(input_messages=[Message(role="user", contents=[" "])], session_id="s1") await provider.after_run( - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), + session=session, + context=ctx, + state=session.state.setdefault(provider.source_id, {}), ) # type: ignore[arg-type] mock_index.load.assert_not_called() @@ -253,7 +272,10 @@ async def test_stores_partition_fields( ctx = SessionContext(input_messages=[Message(role="user", contents=["hello"])], session_id="s1") await provider.after_run( - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), + session=session, + context=ctx, + state=session.state.setdefault(provider.source_id, {}), ) # type: ignore[arg-type] loaded = mock_index.load.call_args[0][0] @@ -489,7 +511,10 @@ async def test_before_run_loads_history(self, mock_redis_client: MagicMock): ctx = SessionContext(input_messages=[Message(role="user", contents=["new msg"])], session_id="s1") await provider.before_run( - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), + session=session, + context=ctx, + state=session.state.setdefault(provider.source_id, {}), ) # type: ignore[arg-type] assert "mem" in ctx.context_messages @@ -506,7 +531,10 @@ async def test_after_run_stores_input_and_response(self, mock_redis_client: Magi ctx._response = AgentResponse(messages=[Message(role="assistant", contents=["hello"])]) await provider.after_run( - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), + session=session, + context=ctx, + state=session.state.setdefault(provider.source_id, {}), ) # type: ignore[arg-type] pipeline = mock_redis_client.pipeline.return_value.__aenter__.return_value @@ -524,7 +552,10 @@ async def test_after_run_skips_when_no_messages(self, mock_redis_client: MagicMo ctx = SessionContext(input_messages=[Message(role="user", contents=["hi"])], session_id="s1") await provider.after_run( - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) + agent=cast(Any, None), + session=session, + context=ctx, + state=session.state.setdefault(provider.source_id, {}), ) # type: ignore[arg-type] mock_redis_client.pipeline.assert_not_called() diff --git a/python/packages/tools/agent_framework_tools/shell/_killtree.py b/python/packages/tools/agent_framework_tools/shell/_killtree.py index f47dadf7772..6e73bc967b7 100644 --- a/python/packages/tools/agent_framework_tools/shell/_killtree.py +++ b/python/packages/tools/agent_framework_tools/shell/_killtree.py @@ -25,12 +25,12 @@ import sys try: # pragma: no cover - importable on every platform we ship - import psutil # type: ignore[import-untyped] + import psutil _has_psutil = True except ImportError: # pragma: no cover _has_psutil = False - psutil = None # type: ignore[assignment] + psutil = None _taskkill_path: str | None = None diff --git a/python/packages/tools/tests/test_docker_shell_tool.py b/python/packages/tools/tests/test_docker_shell_tool.py index 81db60e4512..1eba1c3c370 100644 --- a/python/packages/tools/tests/test_docker_shell_tool.py +++ b/python/packages/tools/tests/test_docker_shell_tool.py @@ -185,7 +185,7 @@ def test_build_exec_argv_non_interactive_appends_dash_c(): def test_docker_shell_tool_validates_mode(): with pytest.raises(ValueError, match="mode must be"): - DockerShellTool(mode="bogus") # type: ignore[arg-type] + DockerShellTool(mode="bogus") # type: ignore[arg-type] # ty: ignore[invalid-argument-type] def test_docker_shell_tool_does_not_require_acknowledge_unsafe(): diff --git a/python/packages/tools/tests/test_shell_environment_provider.py b/python/packages/tools/tests/test_shell_environment_provider.py index 1759db153e9..f2d2ef6936d 100644 --- a/python/packages/tools/tests/test_shell_environment_provider.py +++ b/python/packages/tools/tests/test_shell_environment_provider.py @@ -294,9 +294,9 @@ def extend_instructions(self, source_id: str, instructions: Any) -> None: received.append((source_id, instructions)) await provider.before_run( - agent=None, # type: ignore[arg-type] - session=None, # type: ignore[arg-type] - context=FakeContext(), # type: ignore[arg-type] + agent=None, # type: ignore[arg-type] # ty: ignore[invalid-argument-type] + session=None, # type: ignore[arg-type] # ty: ignore[invalid-argument-type] + context=FakeContext(), # type: ignore[arg-type] # ty: ignore[invalid-argument-type] state={}, ) @@ -346,9 +346,9 @@ def extend_instructions(self, source_id: str, instructions: Any) -> None: received.append((source_id, instructions)) await provider.before_run( - agent=None, # type: ignore[arg-type] - session=None, # type: ignore[arg-type] - context=FakeContext(), # type: ignore[arg-type] + agent=None, # type: ignore[arg-type] # ty: ignore[invalid-argument-type] + session=None, # type: ignore[arg-type] # ty: ignore[invalid-argument-type] + context=FakeContext(), # type: ignore[arg-type] # ty: ignore[invalid-argument-type] state={}, ) diff --git a/python/pyproject.toml b/python/pyproject.toml index 0a4e6f34a93..c6ed714c34d 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -39,6 +39,9 @@ dev = [ "pytest-retry==1.7.0", "mypy==1.20.0", "pyright==1.1.408", + "pyrefly==1.0.0", + "ty==0.0.46", + "zuban==0.8.2", "mcp[ws]==1.27.2", "opentelemetry-sdk==1.40.0", "azure-monitor-opentelemetry==1.8.8", @@ -185,56 +188,40 @@ omit = [ ] [tool.pyright] -exclude = ["**/tests/**", "**/.venv/**", "packages/devui/frontend/**"] +# Pyright is the sole SOURCE-code type checker. Tests + samples are covered by mypy, +# pyrefly, ty (and zuban) instead, so Pyright excludes them entirely -- no per-package +# test executionEnvironments are needed anymore. +exclude = ["**/tests/**", "**/ag_ui_tests/**", "samples/**", "**/.venv/**", "packages/devui/frontend/**"] typeCheckingMode = "strict" reportUnnecessaryIsInstance = false reportMissingTypeStubs = false reportUnnecessaryCast = "error" -# Tests intentionally probe internal implementation details. -executionEnvironments = [ - { root = "packages/a2a/tests", reportPrivateUsage = "none" }, - { root = "packages/ag-ui/tests", reportPrivateUsage = "none" }, - { root = "packages/anthropic/tests", reportPrivateUsage = "none" }, - { root = "packages/azure-ai-search/tests", reportPrivateUsage = "none" }, - { root = "packages/azure-contentunderstanding/tests", reportPrivateUsage = "none" }, - { root = "packages/azure-cosmos/tests", reportPrivateUsage = "none" }, - { root = "packages/azurefunctions/tests", reportPrivateUsage = "none" }, - { root = "packages/bedrock/tests", reportPrivateUsage = "none" }, - { root = "packages/chatkit/tests", reportPrivateUsage = "none" }, - { root = "packages/claude/tests", reportPrivateUsage = "none" }, - { root = "packages/copilotstudio/tests", reportPrivateUsage = "none" }, - { root = "packages/core/tests", reportPrivateUsage = "none" }, - { root = "packages/declarative/tests", reportPrivateUsage = "none" }, - { root = "packages/devui/tests", reportPrivateUsage = "none" }, - { root = "packages/durabletask/tests", reportPrivateUsage = "none" }, - { root = "packages/foundry/tests", reportPrivateUsage = "none" }, - { root = "packages/foundry_local/tests", reportPrivateUsage = "none" }, - { root = "packages/github_copilot/tests", reportPrivateUsage = "none" }, - { root = "packages/lab/gaia/tests", reportPrivateUsage = "none" }, - { root = "packages/lab/lightning/tests", reportPrivateUsage = "none" }, - { root = "packages/lab/tau2/tests", reportPrivateUsage = "none" }, - { root = "packages/mem0/tests", reportPrivateUsage = "none" }, - { root = "packages/mistral/tests", reportPrivateUsage = "none" }, - { root = "packages/ollama/tests", reportPrivateUsage = "none" }, - { root = "packages/orchestrations/tests", reportPrivateUsage = "none" }, - { root = "packages/purview/tests", reportPrivateUsage = "none" }, - { root = "packages/redis/tests", reportPrivateUsage = "none" }, - { root = "tests", reportPrivateUsage = "none" }, -] - +# With MyPy off source, mypy-only ``# type: ignore`` comments are now dead weight. Flag +# them so they get removed and do not creep back in. +reportUnnecessaryTypeIgnoreComment = "error" + +# MyPy (and zuban, which reads this [mypy]-compatible config) no longer run on source +# code -- Pyright is the sole source-code type checker. These checkers run over the +# tests + samples instead, in a deliberately relaxed mode: real type errors in how the +# public API is exercised are caught, but test/sample authors are not burdened with +# annotating every function. See docs/skills/python-code-quality. [tool.mypy] plugins = ['pydantic.mypy'] -strict = true python_version = "3.10" ignore_missing_imports = true -disallow_untyped_defs = true -no_implicit_optional = true check_untyped_defs = true -warn_return_any = true -show_error_codes = true +disallow_untyped_defs = false +disallow_incomplete_defs = false +disallow_untyped_decorators = false +no_implicit_optional = true +warn_return_any = false warn_unused_ignores = false -disallow_incomplete_defs = true -disallow_untyped_decorators = true +show_error_codes = true + +# ty (preview) over tests + samples. Relaxed: ty's defaults are gradual, and a few +# categories are too noisy on test/mock-heavy code to gate on yet. +[tool.ty.rules] +unused-type-ignore-comment = "ignore" [tool.bandit] targets = ["agent_framework"] @@ -292,11 +279,15 @@ help = "Run Pyright for -P/--package packages, use -A/--all for one aggregate sw cmd = "python scripts/workspace_poe_tasks.py pyright" [tool.poe.tasks.mypy] -help = "Run MyPy for -P/--package packages, or use -A/--all for one aggregate sweep." +help = "Run MyPy over -P/--package test suites (alias for `test-typing --checker mypy`)." cmd = "python scripts/workspace_poe_tasks.py mypy" +[tool.poe.tasks.test-typing] +help = "Run the tests/samples type checkers (mypy, pyrefly, ty, zuban) for -P/--package, -A/--all, or -S/--samples. Narrow with `--checker NAME` (repeatable)." +cmd = "python scripts/workspace_poe_tasks.py test-typing" + [tool.poe.tasks.typing] -help = "Run both MyPy and Pyright for -P/--package packages, or use -A/--all for aggregate mode." +help = "Run Pyright over source and the tests/samples checkers for -P/--package packages, or use -A/--all." cmd = "python scripts/workspace_poe_tasks.py typing" [tool.poe.tasks.samples-syntax] diff --git a/python/pyrefly.samples.toml b/python/pyrefly.samples.toml new file mode 100644 index 00000000000..7d7f49d755d --- /dev/null +++ b/python/pyrefly.samples.toml @@ -0,0 +1,27 @@ +# Basic-mode Pyrefly profile for samples. Samples are teaching code: we want to catch +# real mistakes (bad imports, wrong attribute/module access) without forcing readers to +# wade through casts and overload gymnastics just to satisfy third-party SDK stubs. +project-includes = ["samples"] +project-excludes = [ + "**/autogen-migration/**", + "**/semantic-kernel-migration/**", + "**/autogen/**", + "**/demos/**", + "**/_to_delete/**", + "**/05-end-to-end/**", + "**/harness/**", +] + +[errors] +# Signature/cast/overload noise driven mostly by third-party SDK stubs -- off for samples. +bad-argument-type = false +no-matching-overload = false +bad-return = false +bad-override-mutable-attribute = false +unexpected-keyword = false +unsupported-operation = false +missing-argument = false +bad-assignment = false +not-async = false +# Kept on (real bugs a reader would hit): missing-import, missing-attribute, +# missing-module-attribute, bad-index, bad-typed-dict-key, not-iterable. diff --git a/python/pyrefly.toml b/python/pyrefly.toml new file mode 100644 index 00000000000..355f15f2f75 --- /dev/null +++ b/python/pyrefly.toml @@ -0,0 +1,17 @@ +# Pyrefly runs over the tests + samples (not source -- Pyright owns source). +# Relaxed profile: surface real public-API type errors without forcing test/sample +# authors to fully annotate their code. Paths are passed explicitly by the workspace +# task runner; this config supplies the shared rule tuning. +project-includes = ["packages/*/tests", "packages/*/ag_ui_tests", "samples"] + +# The lab sub-packages use an editable namespace layout (agent_framework_lab_*) that +# pyrefly's import resolver does not pick up from the venv finder hook; add their roots. +search-path = [ + "packages/lab/gaia", + "packages/lab/lightning", + "packages/lab/tau2", +] + +[errors] +# Test/mock-heavy code legitimately accesses attributes dynamically. +missing-attribute = false diff --git a/python/pyrightconfig.samples.json b/python/pyrightconfig.samples.json index c2d1274e055..3d0c80919ee 100644 --- a/python/pyrightconfig.samples.json +++ b/python/pyrightconfig.samples.json @@ -11,7 +11,13 @@ "**/agent_with_foundry_tracing.py", "**/azure_responses_client_with_foundry.py" ], - "typeCheckingMode": "off", + "typeCheckingMode": "basic", "reportMissingImports": "error", + "reportMissingModuleSource": false, + "reportMissingTypeStubs": false, + "reportPrivateImportUsage": false, + "reportPrivateUsage": false, + "reportTypedDictNotRequiredAccess": false, + "reportInvalidTypeVarUse": false, "reportAttributeAccessIssue": "error" } diff --git a/python/pyrightconfig.samples.py310.json b/python/pyrightconfig.samples.py310.json index 581f856c364..48e9f0a14a7 100644 --- a/python/pyrightconfig.samples.py310.json +++ b/python/pyrightconfig.samples.py310.json @@ -12,7 +12,13 @@ "**/azure_responses_client_with_foundry.py", "**/github_copilot/**" ], - "typeCheckingMode": "off", + "typeCheckingMode": "basic", "reportMissingImports": "error", + "reportMissingModuleSource": false, + "reportMissingTypeStubs": false, + "reportPrivateImportUsage": false, + "reportPrivateUsage": false, + "reportTypedDictNotRequiredAccess": false, + "reportInvalidTypeVarUse": false, "reportAttributeAccessIssue": "error" } diff --git a/python/pyrightconfig.tests.json b/python/pyrightconfig.tests.json new file mode 100644 index 00000000000..44836976e9f --- /dev/null +++ b/python/pyrightconfig.tests.json @@ -0,0 +1,9 @@ +{ + "typeCheckingMode": "basic", + "reportMissingModuleSource": false, + "reportMissingTypeStubs": false, + "reportPrivateImportUsage": false, + "reportPrivateUsage": false, + "reportTypedDictNotRequiredAccess": false, + "reportInvalidTypeVarUse": false +} diff --git a/python/samples/02-agents/a2a/a2a_stream_reconnection.py b/python/samples/02-agents/a2a/a2a_stream_reconnection.py index c9fd0a88917..3cb254a83dc 100644 --- a/python/samples/02-agents/a2a/a2a_stream_reconnection.py +++ b/python/samples/02-agents/a2a/a2a_stream_reconnection.py @@ -2,10 +2,12 @@ import asyncio import os +from typing import cast import httpx from a2a.client import A2ACardResolver from agent_framework.a2a import A2AAgent +from agent_framework_a2a import A2AContinuationToken from dotenv import load_dotenv load_dotenv() @@ -82,9 +84,9 @@ async def main() -> None: # 4. Reconnect using the saved continuation token. # background=True is required so that in-progress task updates # surface continuation tokens (matching the A2AAgent contract). - print(f"Reconnecting with continuation token (task_id={saved_token['task_id']})...") + print("Reconnecting with continuation token...") resumed_stream = agent.run( - continuation_token=saved_token, + continuation_token=cast(A2AContinuationToken, saved_token), stream=True, background=True, ) diff --git a/python/samples/02-agents/background_responses.py b/python/samples/02-agents/background_responses.py index f3f3d7126a8..f6adc30ea3a 100644 --- a/python/samples/02-agents/background_responses.py +++ b/python/samples/02-agents/background_responses.py @@ -1,9 +1,10 @@ # Copyright (c) Microsoft. All rights reserved. import asyncio +from typing import cast from agent_framework import Agent -from agent_framework.openai import OpenAIChatClient +from agent_framework.openai import OpenAIChatClient, OpenAIChatOptions, OpenAIContinuationToken from dotenv import load_dotenv # Load environment variables from .env file @@ -43,7 +44,7 @@ async def non_streaming_polling() -> None: response = await agent.run( messages="Briefly explain the theory of relativity in two sentences.", session=session, - options={"background": True}, + options=OpenAIChatOptions(background=True), ) print(f"Initial status: continuation_token={'set' if response.continuation_token else 'None'}") @@ -55,7 +56,7 @@ async def non_streaming_polling() -> None: await asyncio.sleep(2) response = await agent.run( session=session, - options={"continuation_token": response.continuation_token}, + options=OpenAIChatOptions(continuation_token=cast(OpenAIContinuationToken, response.continuation_token)), ) print(f" Poll {poll_count}: continuation_token={'set' if response.continuation_token else 'None'}") @@ -75,7 +76,7 @@ async def streaming_with_resumption() -> None: messages="Briefly list three benefits of exercise.", stream=True, session=session, - options={"background": True}, + options=OpenAIChatOptions(background=True), ) # 3. Read some chunks, then simulate an interruption. @@ -96,7 +97,7 @@ async def streaming_with_resumption() -> None: stream = agent.run( stream=True, session=session, - options={"continuation_token": last_token}, + options=OpenAIChatOptions(continuation_token=cast(OpenAIContinuationToken, last_token)), ) async for update in stream: if update.text: diff --git a/python/samples/02-agents/chat_client/chat_response_cancellation.py b/python/samples/02-agents/chat_client/chat_response_cancellation.py index cd82be26020..2a4e17432b1 100644 --- a/python/samples/02-agents/chat_client/chat_response_cancellation.py +++ b/python/samples/02-agents/chat_client/chat_response_cancellation.py @@ -29,10 +29,11 @@ async def main() -> None: """ client = FoundryChatClient(credential=AzureCliCredential()) + async def get_story_response() -> None: + await client.get_response(messages=[Message(role="user", contents=["Tell me a fantasy story."])]) + try: - task = asyncio.create_task( - client.get_response(messages=[Message(role="user", contents=["Tell me a fantasy story."])]) - ) + task = asyncio.create_task(get_story_response()) await asyncio.sleep(1) task.cancel() await task diff --git a/python/samples/02-agents/context_providers/azure_ai_search/search_context_agentic.py b/python/samples/02-agents/context_providers/azure_ai_search/search_context_agentic.py index 2d57a2906be..07c69ecc3f2 100644 --- a/python/samples/02-agents/context_providers/azure_ai_search/search_context_agentic.py +++ b/python/samples/02-agents/context_providers/azure_ai_search/search_context_agentic.py @@ -100,7 +100,7 @@ async def main() -> None: credential=AzureCliCredential() if not search_key else None, mode="agentic", azure_openai_resource_url=azure_openai_resource_url, - model_deployment_name=model_deployment, + model=model_deployment, # Optional: Configure retrieval behavior knowledge_base_output_mode="extractive_data", # or "answer_synthesis" retrieval_reasoning_effort="minimal", # or "medium", "low" diff --git a/python/samples/02-agents/context_providers/redis/redis_basics.py b/python/samples/02-agents/context_providers/redis/redis_basics.py index bf9e163a49a..4a12ad1e856 100644 --- a/python/samples/02-agents/context_providers/redis/redis_basics.py +++ b/python/samples/02-agents/context_providers/redis/redis_basics.py @@ -156,18 +156,20 @@ async def main() -> None: # Use the provider's before_run/after_run API to store and retrieve messages. # In practice, the agent handles this automatically; this shows the low-level API. - from agent_framework import AgentSession, SessionContext + from typing import cast + + from agent_framework import AgentSession, SessionContext, SupportsAgentRun session = AgentSession(session_id="runA") context = SessionContext(input_messages=messages) state = session.state # Store messages via after_run - await provider.after_run(agent=None, session=session, context=context, state=state) + await provider.after_run(agent=cast(SupportsAgentRun, None), session=session, context=context, state=state) # Retrieve relevant memories via before_run query_context = SessionContext(input_messages=[Message("system", ["B: Assistant Message"])]) - await provider.before_run(agent=None, session=session, context=query_context, state=state) + await provider.before_run(agent=cast(SupportsAgentRun, None), session=session, context=query_context, state=state) # Inspect retrieved memories that would be injected into instructions # (Debug-only output so you can verify retrieval works as expected.) diff --git a/python/samples/02-agents/context_providers/simple_context_provider.py b/python/samples/02-agents/context_providers/simple_context_provider.py index dd8da8cbe6b..10bc5e2243a 100644 --- a/python/samples/02-agents/context_providers/simple_context_provider.py +++ b/python/samples/02-agents/context_providers/simple_context_provider.py @@ -60,10 +60,13 @@ async def after_run( # Update user info with extracted data with suppress(Exception): extracted = result.value - if state["user_info"].name is None and extracted.name: - state["user_info"].name = extracted.name - if state["user_info"].age is None and extracted.age: - state["user_info"].age = extracted.age + user_info = state["user_info"] + if not isinstance(extracted, UserInfo) or not isinstance(user_info, UserInfo): + return + if user_info.name is None and extracted.name: + user_info.name = extracted.name + if user_info.age is None and extracted.age: + user_info.age = extracted.age async def before_run( self, diff --git a/python/samples/02-agents/declarative/azure_openai_responses_agent.py b/python/samples/02-agents/declarative/azure_openai_responses_agent.py index 7a02c3a53da..c102bcc0364 100644 --- a/python/samples/02-agents/declarative/azure_openai_responses_agent.py +++ b/python/samples/02-agents/declarative/azure_openai_responses_agent.py @@ -31,7 +31,11 @@ async def main(): # Use response.value with try/except for safe parsing try: parsed = response.value - print("Agent response:", parsed.model_dump_json(indent=2)) + model_dump_json = getattr(parsed, "model_dump_json", None) + if callable(model_dump_json): + print("Agent response:", model_dump_json(indent=2)) + else: + print("Agent response:", response.text) except Exception: print("Agent response:", response.text) diff --git a/python/samples/02-agents/devui/agent_foundry/__init__.py b/python/samples/02-agents/devui/agent_foundry/__init__.py index 0ecbfc3802b..4ba2b2ba05b 100644 --- a/python/samples/02-agents/devui/agent_foundry/__init__.py +++ b/python/samples/02-agents/devui/agent_foundry/__init__.py @@ -2,6 +2,6 @@ """Weather agent sample for DevUI testing.""" -from .agent import agent +from .agent import agent # ty: ignore[unresolved-import] __all__ = ["agent"] diff --git a/python/samples/02-agents/devui/agent_weather/__init__.py b/python/samples/02-agents/devui/agent_weather/__init__.py index 0ecbfc3802b..4ba2b2ba05b 100644 --- a/python/samples/02-agents/devui/agent_weather/__init__.py +++ b/python/samples/02-agents/devui/agent_weather/__init__.py @@ -2,6 +2,6 @@ """Weather agent sample for DevUI testing.""" -from .agent import agent +from .agent import agent # ty: ignore[unresolved-import] __all__ = ["agent"] diff --git a/python/samples/02-agents/devui/workflow_spam/__init__.py b/python/samples/02-agents/devui/workflow_spam/__init__.py index 9801f7433a0..373b575a9b4 100644 --- a/python/samples/02-agents/devui/workflow_spam/__init__.py +++ b/python/samples/02-agents/devui/workflow_spam/__init__.py @@ -2,6 +2,6 @@ """Spam detection workflow sample for DevUI testing.""" -from .workflow import workflow +from .workflow import workflow # ty: ignore[unresolved-import] __all__ = ["workflow"] diff --git a/python/samples/02-agents/devui/workflow_with_agents/__init__.py b/python/samples/02-agents/devui/workflow_with_agents/__init__.py index 67fc70ac2f3..c8bf6639748 100644 --- a/python/samples/02-agents/devui/workflow_with_agents/__init__.py +++ b/python/samples/02-agents/devui/workflow_with_agents/__init__.py @@ -2,6 +2,6 @@ """Sequential Agents Workflow - Writer → Reviewer.""" -from .workflow import workflow +from .workflow import workflow # ty: ignore[unresolved-import] __all__ = ["workflow"] diff --git a/python/samples/02-agents/devui/workflow_with_agents/workflow.py b/python/samples/02-agents/devui/workflow_with_agents/workflow.py index 708c85a63c0..464cfb5d627 100644 --- a/python/samples/02-agents/devui/workflow_with_agents/workflow.py +++ b/python/samples/02-agents/devui/workflow_with_agents/workflow.py @@ -18,7 +18,7 @@ from typing import Any from agent_framework import Agent, AgentExecutorResponse, WorkflowBuilder -from agent_framework.openai import OpenAIChatClient +from agent_framework.openai import OpenAIChatClient, OpenAIChatOptions from azure.identity import AzureCliCredential from dotenv import load_dotenv from pydantic import BaseModel @@ -98,7 +98,7 @@ def is_approved(message: Any) -> bool: "- feedback: concise, actionable feedback\n" "- clarity, completeness, accuracy, structure: individual scores (0-100)" ), - default_options={"response_format": ReviewResult}, + default_options=OpenAIChatOptions[Any](response_format=ReviewResult), ) # Create Editor agent - improves content based on feedback diff --git a/python/samples/02-agents/evaluation/evaluate_agent.py b/python/samples/02-agents/evaluation/evaluate_agent.py index d9cf9527432..4711b6bc68c 100644 --- a/python/samples/02-agents/evaluation/evaluate_agent.py +++ b/python/samples/02-agents/evaluation/evaluate_agent.py @@ -69,7 +69,7 @@ async def main() -> None: for r in results: print(f"{r.provider}: {r.passed}/{r.total} passed") for item in r.items: - print(f" [{item.status}] Q: {item.input_text[:50]} A: {item.output_text[:50]}...") + print(f" [{item.status}] Q: {(item.input_text or '')[:50]} A: {(item.output_text or '')[:50]}...") for score in item.scores: print(f" {'PASS' if score.passed else 'FAIL'} {score.name}") diff --git a/python/samples/02-agents/evaluation/evaluate_multimodal.py b/python/samples/02-agents/evaluation/evaluate_multimodal.py index f51bfc77e7d..d060c7edea3 100644 --- a/python/samples/02-agents/evaluation/evaluate_multimodal.py +++ b/python/samples/02-agents/evaluation/evaluate_multimodal.py @@ -111,7 +111,7 @@ async def main() -> None: print(f"\n{results.provider}: {results.passed}/{results.total} passed") for item in results.items: - print(f"\n [{item.status}] Q: {item.input_text[:60]}...") + print(f"\n [{item.status}] Q: {(item.input_text or '')[:60]}...") for score in item.scores: symbol = "PASS" if score.passed else "FAIL" print(f" {symbol} {score.name}: {score.score}") diff --git a/python/samples/02-agents/evaluation/evaluate_with_expected.py b/python/samples/02-agents/evaluation/evaluate_with_expected.py index a165593c184..04efcd8c7a6 100644 --- a/python/samples/02-agents/evaluation/evaluate_with_expected.py +++ b/python/samples/02-agents/evaluation/evaluate_with_expected.py @@ -66,7 +66,7 @@ async def main() -> None: for r in results: print(f"{r.provider}: {r.passed}/{r.total} passed") for item in r.items: - print(f" [{item.status}] {item.input_text} -> {item.output_text[:80]}") + print(f" [{item.status}] {item.input_text} -> {(item.output_text or '')[:80]}") if __name__ == "__main__": diff --git a/python/samples/02-agents/harness/console/agent_runner.py b/python/samples/02-agents/harness/console/agent_runner.py index 3b7c685dbd6..743ad1e1328 100644 --- a/python/samples/02-agents/harness/console/agent_runner.py +++ b/python/samples/02-agents/harness/console/agent_runner.py @@ -313,9 +313,7 @@ async def _collect_follow_up_actions( """ actions: list[FollowUpAction] = [] for observer in self._observers: - observer_actions = await observer.on_stream_complete( - self._ux, self._agent, session - ) + observer_actions = await observer.on_stream_complete(self._ux, self._agent, session) if observer_actions: actions.extend(observer_actions) return actions diff --git a/python/samples/02-agents/harness/console/app.py b/python/samples/02-agents/harness/console/app.py index c56360c661b..e2260eb2005 100644 --- a/python/samples/02-agents/harness/console/app.py +++ b/python/samples/02-agents/harness/console/app.py @@ -182,18 +182,12 @@ def __init__( if command_handlers is None: from .commands import build_default_command_handlers - self._command_handlers = build_default_command_handlers( - agent, mode_colors=mode_colors - ) + self._command_handlers = build_default_command_handlers(agent, mode_colors=mode_colors) else: self._command_handlers = command_handlers # Compute help text from command handlers - help_parts = [ - h.get_help_text() - for h in self._command_handlers - if h.get_help_text() is not None - ] + help_parts = [h.get_help_text() for h in self._command_handlers if h.get_help_text() is not None] help_text = ", ".join(help_parts) if help_parts else None # State and driver diff --git a/python/samples/02-agents/harness/console/commands/todo_handler.py b/python/samples/02-agents/harness/console/commands/todo_handler.py index 73703e6db34..e32ffd3f6a4 100644 --- a/python/samples/02-agents/harness/console/commands/todo_handler.py +++ b/python/samples/02-agents/harness/console/commands/todo_handler.py @@ -45,9 +45,7 @@ async def try_handle( ux.append_info_line("TodoProvider is not available.") return True - todos = await self._todo_provider.store.load_items( - session, source_id=self._todo_provider.source_id - ) + todos = await self._todo_provider.store.load_items(session, source_id=self._todo_provider.source_id) if not todos: ux.append_info_line("No todos yet.") diff --git a/python/samples/02-agents/harness/console/components/scroll_panel.py b/python/samples/02-agents/harness/console/components/scroll_panel.py index a9cf15a7749..35b478b54bc 100644 --- a/python/samples/02-agents/harness/console/components/scroll_panel.py +++ b/python/samples/02-agents/harness/console/components/scroll_panel.py @@ -72,7 +72,7 @@ def set_streaming_entry(self, entry: OutputEntry) -> None: # Truncate lines back to where streaming started if len(self.lines) > self._streaming_line_start: - del self.lines[self._streaming_line_start:] + del self.lines[self._streaming_line_start :] from textual.geometry import Size self.virtual_size = Size(self._widest_line_width, len(self.lines)) diff --git a/python/samples/02-agents/harness/console/observers/planning_models.py b/python/samples/02-agents/harness/console/observers/planning_models.py index 9b4a92e5757..d4c425b0783 100644 --- a/python/samples/02-agents/harness/console/observers/planning_models.py +++ b/python/samples/02-agents/harness/console/observers/planning_models.py @@ -41,8 +41,7 @@ class PlanningQuestion(BaseModel): choices: list[str] | None = Field( default=None, description=( - "For clarifications, this has a list of options that the user can " - "choose from. null for approvals." + "For clarifications, this has a list of options that the user can choose from. null for approvals." ), ) diff --git a/python/samples/02-agents/middleware/chat_middleware.py b/python/samples/02-agents/middleware/chat_middleware.py index e5604bbd6ba..acb3338e352 100644 --- a/python/samples/02-agents/middleware/chat_middleware.py +++ b/python/samples/02-agents/middleware/chat_middleware.py @@ -95,7 +95,7 @@ async def process( modified_messages.append(message) # Replace messages in context - context.messages[:] = modified_messages + context.messages = modified_messages # Continue to next middleware or AI execution await call_next() diff --git a/python/samples/02-agents/providers/amazon/bedrock_chat_client.py b/python/samples/02-agents/providers/amazon/bedrock_chat_client.py index 2c1f0b8845f..8e62073fdb7 100644 --- a/python/samples/02-agents/providers/amazon/bedrock_chat_client.py +++ b/python/samples/02-agents/providers/amazon/bedrock_chat_client.py @@ -4,7 +4,7 @@ from typing import Annotated from agent_framework import Agent, tool -from agent_framework.amazon import BedrockChatClient +from agent_framework.amazon import BedrockChatClient, BedrockChatOptions from dotenv import load_dotenv from pydantic import Field @@ -43,8 +43,8 @@ async def main() -> None: client=BedrockChatClient(), instructions="You are a concise travel assistant.", name="BedrockWeatherAgent", - tool_choice="auto", tools=[get_weather], + default_options=BedrockChatOptions(tool_choice="auto"), ) # 2. Run a query that uses the weather tool. diff --git a/python/samples/02-agents/providers/anthropic/anthropic_skills.py b/python/samples/02-agents/providers/anthropic/anthropic_skills.py index 5f4d1d40b96..f5945474bad 100644 --- a/python/samples/02-agents/providers/anthropic/anthropic_skills.py +++ b/python/samples/02-agents/providers/anthropic/anthropic_skills.py @@ -85,6 +85,8 @@ async def main() -> None: # Since I'm using the pptx skill, the files will be PowerPoint presentations print("Generated files:") for idx, file in enumerate(files): + if file.file_id is None: + continue file_content = await client.anthropic_client.beta.files.download( # type: ignore file_id=file.file_id, betas=["files-api-2025-04-14"] ) diff --git a/python/samples/02-agents/providers/anthropic/anthropic_with_shell.py b/python/samples/02-agents/providers/anthropic/anthropic_with_shell.py index 40c6aedc430..9d74a554548 100644 --- a/python/samples/02-agents/providers/anthropic/anthropic_with_shell.py +++ b/python/samples/02-agents/providers/anthropic/anthropic_with_shell.py @@ -68,7 +68,7 @@ async def main() -> None: print(f"Result: {result}\n") -async def run_with_approvals(query: str, agent: Agent) -> Any: +async def run_with_approvals(query: str, agent: Agent[Any]) -> Any: """Run the agent and handle shell approvals outside tool execution.""" current_input: str | list[Any] = query while True: @@ -79,6 +79,8 @@ async def run_with_approvals(query: str, agent: Agent) -> Any: next_input: list[Any] = [query] rejected = False for user_input_needed in result.user_input_requests: + if user_input_needed.function_call is None: + continue print( f"\nShell request: {user_input_needed.function_call.name}" f"\nArguments: {user_input_needed.function_call.arguments}" diff --git a/python/samples/02-agents/providers/custom/custom_agent.py b/python/samples/02-agents/providers/custom/custom_agent.py index 1957f2e086a..ad87bd14e73 100644 --- a/python/samples/02-agents/providers/custom/custom_agent.py +++ b/python/samples/02-agents/providers/custom/custom_agent.py @@ -1,8 +1,8 @@ # Copyright (c) Microsoft. All rights reserved. import asyncio -from collections.abc import AsyncIterable -from typing import Any +from collections.abc import AsyncIterable, Awaitable +from typing import Any, Literal, overload from agent_framework import ( AgentResponse, @@ -55,6 +55,26 @@ def __init__( ) self.echo_prefix = echo_prefix + @overload + def run( + self, + messages: str | Message | list[str] | list[Message] | None = None, + *, + stream: Literal[False] = False, + session: AgentSession | None = None, + **kwargs: Any, + ) -> asyncio.Future[AgentResponse]: ... + + @overload + def run( + self, + messages: str | Message | list[str] | list[Message] | None = None, + *, + stream: Literal[True], + session: AgentSession | None = None, + **kwargs: Any, + ) -> AsyncIterable[AgentResponseUpdate]: ... + def run( self, messages: str | Message | list[str] | list[Message] | None = None, @@ -62,7 +82,7 @@ def run( stream: bool = False, session: AgentSession | None = None, **kwargs: Any, - ) -> "AsyncIterable[AgentResponseUpdate] | asyncio.Future[AgentResponse]": + ) -> "AsyncIterable[AgentResponseUpdate] | Awaitable[AgentResponse]": """Execute the agent and return a response. Args: @@ -181,7 +201,9 @@ async def main() -> None: query2 = "This is a streaming test" print(f"\nUser: {query2}") print("Agent: ", end="", flush=True) - async for chunk in echo_agent.run(query2, stream=True): + stream = echo_agent.run(query2, stream=True) + assert isinstance(stream, AsyncIterable) + async for chunk in stream: if chunk.text: print(chunk.text, end="", flush=True) print() diff --git a/python/samples/02-agents/providers/foundry/foundry_chat_client_code_interpreter_files.py b/python/samples/02-agents/providers/foundry/foundry_chat_client_code_interpreter_files.py index 6fff2c30690..cb10c36a4c6 100644 --- a/python/samples/02-agents/providers/foundry/foundry_chat_client_code_interpreter_files.py +++ b/python/samples/02-agents/providers/foundry/foundry_chat_client_code_interpreter_files.py @@ -75,7 +75,7 @@ async def main() -> None: credential=AzureCliCredential(), ) # use the openai client from the foundry client to upload files for the code interpreter tool - openai_client = client.project_client.get_openai_client() + openai_client = getattr(client.project_client, "get_openai_client")() # noqa: B009 temp_file_path, file_id = await create_sample_file_and_upload(openai_client) # Create agent with code interpreter tool with file access agent = Agent( diff --git a/python/samples/02-agents/providers/foundry/foundry_chat_client_with_hosted_mcp.py b/python/samples/02-agents/providers/foundry/foundry_chat_client_with_hosted_mcp.py index 3089fd18562..dfb0db4b8e7 100644 --- a/python/samples/02-agents/providers/foundry/foundry_chat_client_with_hosted_mcp.py +++ b/python/samples/02-agents/providers/foundry/foundry_chat_client_with_hosted_mcp.py @@ -19,10 +19,10 @@ """ if TYPE_CHECKING: - from agent_framework import AgentSession, SupportsAgentRun + from agent_framework import AgentSession -async def handle_approvals_without_session(query: str, agent: "SupportsAgentRun"): +async def handle_approvals_without_session(query: str, agent: Agent[Any]): """When we don't have a session, we need to ensure we return with the input, approval request and approval.""" from agent_framework import Message @@ -30,6 +30,8 @@ async def handle_approvals_without_session(query: str, agent: "SupportsAgentRun" while len(result.user_input_requests) > 0: new_inputs: list[Any] = [query] for user_input_needed in result.user_input_requests: + if user_input_needed.function_call is None: + continue print( f"User Input Request for function from {agent.name}: {user_input_needed.function_call.name}" f" with arguments: {user_input_needed.function_call.arguments}" @@ -47,14 +49,16 @@ async def handle_approvals_without_session(query: str, agent: "SupportsAgentRun" return result -async def handle_approvals_with_session(query: str, agent: "SupportsAgentRun", session: "AgentSession"): +async def handle_approvals_with_session(query: str, agent: Agent[Any], session: "AgentSession"): """Here we let the session deal with the previous responses, and we just rerun with the approval.""" - from agent_framework import Message + from agent_framework import ChatOptions, Message - result = await agent.run(query, session=session, options={"store": True}) + result = await agent.run(query, session=session, options=ChatOptions(store=True)) while len(result.user_input_requests) > 0: new_input: list[Any] = [] for user_input_needed in result.user_input_requests: + if user_input_needed.function_call is None: + continue print( f"User Input Request for function from {agent.name}: {user_input_needed.function_call.name}" f" with arguments: {user_input_needed.function_call.arguments}" @@ -66,23 +70,25 @@ async def handle_approvals_with_session(query: str, agent: "SupportsAgentRun", s contents=[user_input_needed.to_function_approval_response(user_approval.lower() == "y")], ) ) - result = await agent.run(new_input, session=session, options={"store": True}) + result = await agent.run(new_input, session=session, options=ChatOptions(store=True)) return result -async def handle_approvals_with_session_streaming(query: str, agent: "SupportsAgentRun", session: "AgentSession"): +async def handle_approvals_with_session_streaming(query: str, agent: Agent[Any], session: "AgentSession"): """Here we let the session deal with the previous responses, and we just rerun with the approval.""" - from agent_framework import Message + from agent_framework import ChatOptions, Message new_input: list[Message | str] = [query] new_input_added = True while new_input_added: new_input_added = False - async for update in agent.run(new_input, session=session, options={"store": True}, stream=True): + async for update in agent.run(new_input, session=session, options=ChatOptions(store=True), stream=True): if update.user_input_requests: # Reset input to only contain new approval responses for the next iteration new_input = [] for user_input_needed in update.user_input_requests: + if user_input_needed.function_call is None: + continue print( f"User Input Request for function from {agent.name}: {user_input_needed.function_call.name}" f" with arguments: {user_input_needed.function_call.arguments}" diff --git a/python/samples/02-agents/providers/foundry/foundry_local_agent.py b/python/samples/02-agents/providers/foundry/foundry_local_agent.py index a758c759142..95d4e06d21e 100644 --- a/python/samples/02-agents/providers/foundry/foundry_local_agent.py +++ b/python/samples/02-agents/providers/foundry/foundry_local_agent.py @@ -5,7 +5,7 @@ import asyncio from random import randint -from typing import Annotated +from typing import Annotated, Any from agent_framework import Agent from agent_framework.foundry import FoundryLocalClient @@ -31,7 +31,7 @@ def get_weather( return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C." -async def non_streaming_example(agent: Agent) -> None: +async def non_streaming_example(agent: Agent[Any]) -> None: """Example of non-streaming response (get the complete result at once).""" print("=== Non-streaming Response Example ===") @@ -41,7 +41,7 @@ async def non_streaming_example(agent: Agent) -> None: print(f"Agent: {result}\n") -async def streaming_example(agent: Agent) -> None: +async def streaming_example(agent: Agent[Any]) -> None: """Example of streaming response (get results as they are generated).""" print("=== Streaming Response Example ===") diff --git a/python/samples/02-agents/providers/foundry/foundry_prompt_agents.py b/python/samples/02-agents/providers/foundry/foundry_prompt_agents.py index 6d53891aaaa..65fb18cec47 100644 --- a/python/samples/02-agents/providers/foundry/foundry_prompt_agents.py +++ b/python/samples/02-agents/providers/foundry/foundry_prompt_agents.py @@ -114,6 +114,8 @@ async def main() -> None: # 3) Convert and publish. The version returned by Foundry includes the version label # we need when connecting back to that specific deployment. + if agent.name is None: + raise ValueError("Agent name is required to create a prompt agent version.") created = await project_client.agents.create_version( agent_name=agent.name, # note this line: diff --git a/python/samples/02-agents/providers/github_copilot/github_copilot_basic.py b/python/samples/02-agents/providers/github_copilot/github_copilot_basic.py index 3cbfe01795b..ca90985c0a2 100644 --- a/python/samples/02-agents/providers/github_copilot/github_copilot_basic.py +++ b/python/samples/02-agents/providers/github_copilot/github_copilot_basic.py @@ -18,7 +18,7 @@ from typing import Annotated from agent_framework import tool -from agent_framework.github import GitHubCopilotAgent +from agent_framework.github import GitHubCopilotAgent, GitHubCopilotOptions from copilot.session import PermissionHandler from dotenv import load_dotenv from pydantic import Field @@ -43,10 +43,10 @@ async def non_streaming_example() -> None: """Example of non-streaming response (get the complete result at once).""" print("=== Non-streaming Response Example ===") - agent = GitHubCopilotAgent( + agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent( instructions="You are a helpful weather agent.", tools=[get_weather], - default_options={"on_permission_request": PermissionHandler.approve_all}, + default_options=GitHubCopilotOptions(on_permission_request=PermissionHandler.approve_all), ) async with agent: @@ -60,10 +60,10 @@ async def streaming_example() -> None: """Example of streaming response (get results as they are generated).""" print("=== Streaming Response Example ===") - agent = GitHubCopilotAgent( + agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent( instructions="You are a helpful weather agent.", tools=[get_weather], - default_options={"on_permission_request": PermissionHandler.approve_all}, + default_options=GitHubCopilotOptions(on_permission_request=PermissionHandler.approve_all), ) async with agent: @@ -80,10 +80,10 @@ async def runtime_options_example() -> None: """Example of overriding system message at runtime.""" print("=== Runtime Options Example ===") - agent = GitHubCopilotAgent( + agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent( instructions="Always respond in exactly 3 words.", tools=[get_weather], - default_options={"on_permission_request": PermissionHandler.approve_all}, + default_options=GitHubCopilotOptions(on_permission_request=PermissionHandler.approve_all), ) async with agent: @@ -98,15 +98,15 @@ async def runtime_options_example() -> None: # Second call overrides with runtime system_message in replace mode print("Using runtime system_message with replace mode (detailed response):") print(f"User: {query}") - result2 = await agent.run( + result2 = await agent.run( # pyright: ignore[reportCallIssue] query, - options={ - "system_message": { + options=GitHubCopilotOptions( # pyright: ignore[reportArgumentType] + system_message={ "mode": "replace", "content": "You are a weather expert. Provide detailed weather information " "with temperature, and recommendations.", } - }, + ), ) print(f"Agent: {result2}\n") diff --git a/python/samples/02-agents/providers/github_copilot/github_copilot_with_file_operations.py b/python/samples/02-agents/providers/github_copilot/github_copilot_with_file_operations.py index 67336259d08..7f363add9a4 100644 --- a/python/samples/02-agents/providers/github_copilot/github_copilot_with_file_operations.py +++ b/python/samples/02-agents/providers/github_copilot/github_copilot_with_file_operations.py @@ -13,7 +13,7 @@ import asyncio -from agent_framework.github import GitHubCopilotAgent +from agent_framework.github import GitHubCopilotAgent, GitHubCopilotOptions from copilot.generated.rpc import PermissionDecisionDeniedInteractivelyByUser from copilot.session import PermissionHandler, PermissionRequestResult from copilot.session_events import PermissionRequest @@ -31,9 +31,9 @@ async def prompt_permission(request: PermissionRequest, context: dict[str, str]) async def main() -> None: print("=== GitHub Copilot Agent with File Operation Permissions ===\n") - agent = GitHubCopilotAgent( + agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent( instructions="You are a helpful assistant that can read and write files.", - default_options={"on_permission_request": prompt_permission}, + default_options=GitHubCopilotOptions(on_permission_request=prompt_permission), ) async with agent: diff --git a/python/samples/02-agents/providers/github_copilot/github_copilot_with_function_approval.py b/python/samples/02-agents/providers/github_copilot/github_copilot_with_function_approval.py index 3348dbf1a56..272aa2acb9a 100644 --- a/python/samples/02-agents/providers/github_copilot/github_copilot_with_function_approval.py +++ b/python/samples/02-agents/providers/github_copilot/github_copilot_with_function_approval.py @@ -31,7 +31,7 @@ from typing import Annotated from agent_framework import Content, tool -from agent_framework.github import GitHubCopilotAgent +from agent_framework.github import GitHubCopilotAgent, GitHubCopilotOptions from copilot.session import PermissionHandler from dotenv import load_dotenv @@ -78,13 +78,13 @@ def auto_approve(call: Content) -> bool: async def run_with_interactive_callback() -> None: """Demonstrates an interactive approval prompt before tool execution.""" print("\n=== GitHub Copilot Agent: interactive approval callback ===") - agent = GitHubCopilotAgent( + agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent( instructions="You are a helpful weather assistant.", tools=[get_weather_detail], - default_options={ - "on_function_approval": prompt_for_approval, - "on_permission_request": PermissionHandler.approve_all, - }, + default_options=GitHubCopilotOptions( + on_function_approval=prompt_for_approval, + on_permission_request=PermissionHandler.approve_all, + ), ) async with agent: query = "Give me the detailed weather for Seattle." @@ -96,13 +96,13 @@ async def run_with_interactive_callback() -> None: async def run_with_auto_approve_callback() -> None: """Demonstrates a synchronous callback that always approves.""" print("\n=== GitHub Copilot Agent: synchronous auto-approve callback ===") - agent = GitHubCopilotAgent( + agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent( instructions="You are a helpful weather assistant.", tools=[get_weather_detail], - default_options={ - "on_function_approval": auto_approve, - "on_permission_request": PermissionHandler.approve_all, - }, + default_options=GitHubCopilotOptions( + on_function_approval=auto_approve, + on_permission_request=PermissionHandler.approve_all, + ), ) async with agent: query = "Give me the detailed weather for Tokyo." @@ -119,10 +119,10 @@ async def run_without_callback() -> None: or try a different approach instead of silently failing. """ print("\n=== GitHub Copilot Agent: no callback configured (deny by default) ===") - agent = GitHubCopilotAgent( + agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent( instructions="You are a helpful weather assistant.", tools=[get_weather_detail], - default_options={"on_permission_request": PermissionHandler.approve_all}, + default_options=GitHubCopilotOptions(on_permission_request=PermissionHandler.approve_all), ) async with agent: query = "Give me the detailed weather for Paris." diff --git a/python/samples/02-agents/providers/github_copilot/github_copilot_with_instruction_directories.py b/python/samples/02-agents/providers/github_copilot/github_copilot_with_instruction_directories.py index 4c7ae2c1a42..1016e09a3d1 100644 --- a/python/samples/02-agents/providers/github_copilot/github_copilot_with_instruction_directories.py +++ b/python/samples/02-agents/providers/github_copilot/github_copilot_with_instruction_directories.py @@ -21,7 +21,7 @@ import asyncio from pathlib import Path -from agent_framework.github import GitHubCopilotAgent +from agent_framework.github import GitHubCopilotAgent, GitHubCopilotOptions from copilot.session import PermissionHandler from dotenv import load_dotenv @@ -44,12 +44,12 @@ async def default_instructions_example() -> None: # 2. Create the agent with instruction directories in default_options. # These directories apply to every session created by this agent. - agent = GitHubCopilotAgent( + agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent( instructions="You are a helpful coding assistant.", - default_options={ - "on_permission_request": PermissionHandler.approve_all, - "instruction_directories": instruction_dirs, - }, + default_options=GitHubCopilotOptions( + on_permission_request=PermissionHandler.approve_all, + instruction_directories=instruction_dirs, + ), ) # 3. Run the agent — instruction files from those directories are loaded @@ -65,12 +65,12 @@ async def runtime_override_example() -> None: """Example of overriding instruction directories at runtime.""" print("=== Instruction Directories (Runtime Override) ===\n") - agent = GitHubCopilotAgent( + agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent( instructions="You are a helpful assistant.", - default_options={ - "on_permission_request": PermissionHandler.approve_all, - "instruction_directories": ["/team/shared/instructions"], - }, + default_options=GitHubCopilotOptions( + on_permission_request=PermissionHandler.approve_all, + instruction_directories=["/team/shared/instructions"], + ), ) async with agent: @@ -85,11 +85,11 @@ async def runtime_override_example() -> None: print("Overriding with project-specific instructions...\n") query2 = "Now what instructions are you following?" print(f"User: {query2}") - result2 = await agent.run( + result2 = await agent.run( # pyright: ignore[reportCallIssue] query2, - options={ - "instruction_directories": ["/project/specific/instructions"], - }, + options=GitHubCopilotOptions( # pyright: ignore[reportArgumentType] + instruction_directories=["/project/specific/instructions"], + ), ) print(f"Agent: {result2}\n") diff --git a/python/samples/02-agents/providers/github_copilot/github_copilot_with_mcp.py b/python/samples/02-agents/providers/github_copilot/github_copilot_with_mcp.py index 60e4704c078..10096ea02d2 100644 --- a/python/samples/02-agents/providers/github_copilot/github_copilot_with_mcp.py +++ b/python/samples/02-agents/providers/github_copilot/github_copilot_with_mcp.py @@ -14,7 +14,7 @@ import asyncio -from agent_framework.github import GitHubCopilotAgent +from agent_framework.github import GitHubCopilotAgent, GitHubCopilotOptions from copilot.session import MCPServerConfig, PermissionHandler from dotenv import load_dotenv @@ -42,12 +42,12 @@ async def main() -> None: }, } - agent = GitHubCopilotAgent( + agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent( instructions="You are a helpful assistant with access to the local filesystem and Microsoft Learn.", - default_options={ - "on_permission_request": PermissionHandler.approve_all, - "mcp_servers": mcp_servers, - }, + default_options=GitHubCopilotOptions( + on_permission_request=PermissionHandler.approve_all, + mcp_servers=mcp_servers, + ), ) async with agent: @@ -61,7 +61,10 @@ async def main() -> None: # Remote MCP calls may take longer, so increase the timeout query2 = "Search Microsoft Learn for 'Azure Functions Python' and summarize the top result" print(f"User: {query2}") - result2 = await agent.run(query2, options={"timeout": 120}) + result2 = await agent.run( # pyright: ignore[reportCallIssue] + query2, + options=GitHubCopilotOptions(timeout=120), # pyright: ignore[reportArgumentType] + ) print(f"Agent: {result2}\n") diff --git a/python/samples/02-agents/providers/github_copilot/github_copilot_with_multiple_permissions.py b/python/samples/02-agents/providers/github_copilot/github_copilot_with_multiple_permissions.py index 5da43b32744..4e375e181a4 100644 --- a/python/samples/02-agents/providers/github_copilot/github_copilot_with_multiple_permissions.py +++ b/python/samples/02-agents/providers/github_copilot/github_copilot_with_multiple_permissions.py @@ -19,7 +19,7 @@ import asyncio -from agent_framework.github import GitHubCopilotAgent +from agent_framework.github import GitHubCopilotAgent, GitHubCopilotOptions from copilot.session import PermissionHandler, PermissionRequestResult from copilot.session_events import PermissionRequest @@ -33,9 +33,9 @@ def approve_and_log(request: PermissionRequest, context: dict[str, str]) -> Perm async def main() -> None: print("=== GitHub Copilot Agent with Multiple Permissions ===\n") - agent = GitHubCopilotAgent( + agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent( instructions="You are a helpful development assistant that can read, write files and run commands.", - default_options={"on_permission_request": approve_and_log}, + default_options=GitHubCopilotOptions(on_permission_request=approve_and_log), ) async with agent: diff --git a/python/samples/02-agents/providers/github_copilot/github_copilot_with_session.py b/python/samples/02-agents/providers/github_copilot/github_copilot_with_session.py index 70a0f3c8260..96ac0d7ca5c 100644 --- a/python/samples/02-agents/providers/github_copilot/github_copilot_with_session.py +++ b/python/samples/02-agents/providers/github_copilot/github_copilot_with_session.py @@ -13,7 +13,7 @@ from typing import Annotated from agent_framework import tool -from agent_framework.github import GitHubCopilotAgent +from agent_framework.github import GitHubCopilotAgent, GitHubCopilotOptions from copilot.session import PermissionHandler from pydantic import Field @@ -34,10 +34,10 @@ async def example_with_automatic_session_creation() -> None: """Each run() without thread creates a new session.""" print("=== Automatic Session Creation Example ===") - agent = GitHubCopilotAgent( + agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent( instructions="You are a helpful weather agent.", tools=[get_weather], - default_options={"on_permission_request": PermissionHandler.approve_all}, + default_options=GitHubCopilotOptions(on_permission_request=PermissionHandler.approve_all), ) async with agent: @@ -59,10 +59,10 @@ async def example_with_session_persistence() -> None: """Reuse session via thread object for multi-turn conversations.""" print("=== Session Persistence Example ===") - agent = GitHubCopilotAgent( + agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent( instructions="You are a helpful weather agent.", tools=[get_weather], - default_options={"on_permission_request": PermissionHandler.approve_all}, + default_options=GitHubCopilotOptions(on_permission_request=PermissionHandler.approve_all), ) async with agent: @@ -99,7 +99,7 @@ async def example_with_existing_session_id() -> None: agent1 = GitHubCopilotAgent( instructions="You are a helpful weather agent.", tools=[get_weather], - default_options={"on_permission_request": PermissionHandler.approve_all}, + default_options=GitHubCopilotOptions(on_permission_request=PermissionHandler.approve_all), ) async with agent1: @@ -121,7 +121,7 @@ async def example_with_existing_session_id() -> None: agent2 = GitHubCopilotAgent( instructions="You are a helpful weather agent.", tools=[get_weather], - default_options={"on_permission_request": PermissionHandler.approve_all}, + default_options=GitHubCopilotOptions(on_permission_request=PermissionHandler.approve_all), ) async with agent2: diff --git a/python/samples/02-agents/providers/github_copilot/github_copilot_with_shell.py b/python/samples/02-agents/providers/github_copilot/github_copilot_with_shell.py index 66ead3bf992..0cd6ba3728a 100644 --- a/python/samples/02-agents/providers/github_copilot/github_copilot_with_shell.py +++ b/python/samples/02-agents/providers/github_copilot/github_copilot_with_shell.py @@ -13,7 +13,7 @@ import asyncio -from agent_framework.github import GitHubCopilotAgent +from agent_framework.github import GitHubCopilotAgent, GitHubCopilotOptions from copilot.generated.rpc import PermissionDecisionUserNotAvailable from copilot.session import PermissionHandler, PermissionRequestResult from copilot.session_events import PermissionRequest @@ -33,9 +33,9 @@ def approve_and_log(request: PermissionRequest, context: dict[str, str]) -> Perm async def main() -> None: print("=== GitHub Copilot Agent with Shell Permissions ===\n") - agent = GitHubCopilotAgent( + agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent( instructions="You are a helpful assistant that can execute shell commands.", - default_options={"on_permission_request": approve_and_log}, + default_options=GitHubCopilotOptions(on_permission_request=approve_and_log), ) async with agent: diff --git a/python/samples/02-agents/providers/github_copilot/github_copilot_with_url.py b/python/samples/02-agents/providers/github_copilot/github_copilot_with_url.py index 61fa90cd5d4..eb3edc5296b 100644 --- a/python/samples/02-agents/providers/github_copilot/github_copilot_with_url.py +++ b/python/samples/02-agents/providers/github_copilot/github_copilot_with_url.py @@ -13,7 +13,7 @@ import asyncio -from agent_framework.github import GitHubCopilotAgent +from agent_framework.github import GitHubCopilotAgent, GitHubCopilotOptions from copilot.generated.rpc import PermissionDecisionUserNotAvailable from copilot.session import PermissionHandler, PermissionRequestResult from copilot.session_events import PermissionRequest @@ -33,9 +33,9 @@ def approve_and_log(request: PermissionRequest, context: dict[str, str]) -> Perm async def main() -> None: print("=== GitHub Copilot Agent with URL Fetching ===\n") - agent = GitHubCopilotAgent( + agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent( instructions="You are a helpful assistant that can fetch and summarize web content.", - default_options={"on_permission_request": approve_and_log}, + default_options=GitHubCopilotOptions(on_permission_request=approve_and_log), ) async with agent: diff --git a/python/samples/02-agents/providers/ollama/ollama_chat_client.py b/python/samples/02-agents/providers/ollama/ollama_chat_client.py index adf2d3818e3..869bd309492 100644 --- a/python/samples/02-agents/providers/ollama/ollama_chat_client.py +++ b/python/samples/02-agents/providers/ollama/ollama_chat_client.py @@ -3,7 +3,7 @@ import asyncio from datetime import datetime -from agent_framework import Message, tool +from agent_framework import ChatOptions, Message, tool from agent_framework.ollama import OllamaChatClient from dotenv import load_dotenv @@ -40,12 +40,12 @@ async def main() -> None: print(f"User: {message}") if stream: print("Assistant: ", end="") - async for chunk in client.get_response(messages, tools=get_time, stream=True): + async for chunk in client.get_response(messages=messages, options=ChatOptions(tools=get_time), stream=True): if str(chunk): print(str(chunk), end="") print("") else: - response = await client.get_response(messages, tools=get_time) + response = await client.get_response(messages=messages, options=ChatOptions(tools=get_time)) print(f"Assistant: {response}") diff --git a/python/samples/02-agents/providers/openai/client_streaming_image_generation.py b/python/samples/02-agents/providers/openai/client_streaming_image_generation.py index 412e6f8e6a6..5b3f07742a9 100644 --- a/python/samples/02-agents/providers/openai/client_streaming_image_generation.py +++ b/python/samples/02-agents/providers/openai/client_streaming_image_generation.py @@ -78,7 +78,8 @@ async def main(): extension = image_output.media_type.split("/")[-1] # Save images with correct extension filename = output_dir / f"image{image_count}.{extension}" - await save_image_from_data_uri(image_output.uri, str(filename)) + if image_output.uri is not None: + await save_image_from_data_uri(image_output.uri, str(filename)) image_count += 1 # Summary print("\n Summary:") diff --git a/python/samples/02-agents/providers/openai/client_with_hosted_mcp.py b/python/samples/02-agents/providers/openai/client_with_hosted_mcp.py index f9cc0c71489..f48a72e1b07 100644 --- a/python/samples/02-agents/providers/openai/client_with_hosted_mcp.py +++ b/python/samples/02-agents/providers/openai/client_with_hosted_mcp.py @@ -8,7 +8,7 @@ from dotenv import load_dotenv if TYPE_CHECKING: - from agent_framework import AgentSession, SupportsAgentRun + from agent_framework import AgentSession # Load environment variables from .env file load_dotenv() @@ -21,7 +21,7 @@ """ -async def handle_approvals_without_session(query: str, agent: "SupportsAgentRun"): +async def handle_approvals_without_session(query: str, agent: Agent[Any]): """When we don't have a session, we need to ensure we return with the input, approval request and approval.""" from agent_framework import Message @@ -29,6 +29,8 @@ async def handle_approvals_without_session(query: str, agent: "SupportsAgentRun" while len(result.user_input_requests) > 0: new_inputs: list[Any] = [query] for user_input_needed in result.user_input_requests: + if user_input_needed.function_call is None: + continue print( f"User Input Request for function from {agent.name}: {user_input_needed.function_call.name}" f" with arguments: {user_input_needed.function_call.arguments}" @@ -46,14 +48,16 @@ async def handle_approvals_without_session(query: str, agent: "SupportsAgentRun" return result -async def handle_approvals_with_session(query: str, agent: "SupportsAgentRun", session: "AgentSession"): +async def handle_approvals_with_session(query: str, agent: Agent[Any], session: "AgentSession"): """Here we let the session deal with the previous responses, and we just rerun with the approval.""" - from agent_framework import Message + from agent_framework import ChatOptions, Message - result = await agent.run(query, session=session, options={"store": True}) + result = await agent.run(query, session=session, options=ChatOptions(store=True)) while len(result.user_input_requests) > 0: new_input: list[Any] = [] for user_input_needed in result.user_input_requests: + if user_input_needed.function_call is None: + continue print( f"User Input Request for function from {agent.name}: {user_input_needed.function_call.name}" f" with arguments: {user_input_needed.function_call.arguments}" @@ -65,23 +69,25 @@ async def handle_approvals_with_session(query: str, agent: "SupportsAgentRun", s contents=[user_input_needed.to_function_approval_response(user_approval.lower() == "y")], ) ) - result = await agent.run(new_input, session=session, options={"store": True}) + result = await agent.run(new_input, session=session, options=ChatOptions(store=True)) return result -async def handle_approvals_with_session_streaming(query: str, agent: "SupportsAgentRun", session: "AgentSession"): +async def handle_approvals_with_session_streaming(query: str, agent: Agent[Any], session: "AgentSession"): """Here we let the session deal with the previous responses, and we just rerun with the approval.""" - from agent_framework import Message + from agent_framework import ChatOptions, Message new_input: list[Message | str] = [query] new_input_added = True while new_input_added: new_input_added = False - async for update in agent.run(new_input, session=session, stream=True, options={"store": True}): + async for update in agent.run(new_input, session=session, stream=True, options=ChatOptions(store=True)): if update.user_input_requests: # Reset input to only contain new approval responses for the next iteration new_input = [] for user_input_needed in update.user_input_requests: + if user_input_needed.function_call is None: + continue print( f"User Input Request for function from {agent.name}: {user_input_needed.function_call.name}" f" with arguments: {user_input_needed.function_call.arguments}" diff --git a/python/samples/02-agents/providers/openai/client_with_local_shell.py b/python/samples/02-agents/providers/openai/client_with_local_shell.py index bd59d6bc8aa..be2698806bd 100644 --- a/python/samples/02-agents/providers/openai/client_with_local_shell.py +++ b/python/samples/02-agents/providers/openai/client_with_local_shell.py @@ -74,6 +74,8 @@ async def run_with_approvals(query: str, agent: Agent) -> Any: next_input: list[Any] = [query] rejected = False for user_input_needed in result.user_input_requests: + if user_input_needed.function_call is None: + continue print( f"\nShell request: {user_input_needed.function_call.name}" f"\nArguments: {user_input_needed.function_call.arguments}" diff --git a/python/samples/02-agents/response_stream.py b/python/samples/02-agents/response_stream.py index ad6d3df3ba9..365e4c2bcda 100644 --- a/python/samples/02-agents/response_stream.py +++ b/python/samples/02-agents/response_stream.py @@ -200,12 +200,12 @@ def uppercase_hook(update: ChatResponseUpdate) -> ChatResponseUpdate: """Hook that converts text to uppercase.""" if update.text: return ChatResponseUpdate( - contents=[Content.from_text(update.text.upper())], role=update.role, response_id=update.response_id + contents=[Content.from_text(update.text.upper())], role=None, response_id=update.response_id ) return update # Pass transform_hooks directly to constructor - stream3 = ResponseStream( + stream3: ResponseStream[ChatResponseUpdate, ChatResponse] = ResponseStream( generate_updates(), finalizer=combine_updates, transform_hooks=[counting_hook, uppercase_hook], # First counts, then uppercases @@ -262,7 +262,7 @@ def wrap_in_quotes_hook(response: ChatResponse) -> ChatResponse: return response # Finalizer converts updates to response, then result hooks transform it - stream5 = ResponseStream( + stream5: ResponseStream[ChatResponseUpdate, ChatResponse] = ResponseStream( generate_updates(), finalizer=combine_updates, result_hooks=[add_metadata_hook, wrap_in_quotes_hook], # First adds metadata, then wraps in quotes @@ -285,7 +285,7 @@ def to_agent_format(update: ChatResponseUpdate) -> ChatResponseUpdate: """Map ChatResponseUpdate to agent format (simulated transformation).""" # In real code, this would convert to AgentResponseUpdate return ChatResponseUpdate( - contents=[Content.from_text(f"[AGENT] {update.text}")], role=update.role, response_id=update.response_id + contents=[Content.from_text(f"[AGENT] {update.text}")], role=None, response_id=update.response_id ) def to_agent_response(updates: Sequence[ChatResponseUpdate]) -> ChatResponse: @@ -337,7 +337,7 @@ def add_stats_to_response(response: ChatResponse) -> ChatResponse: return response # All hooks can be passed via constructor - full_stream = ResponseStream( + full_stream: ResponseStream[ChatResponseUpdate, ChatResponse] = ResponseStream( generate_updates(), finalizer=combine_updates, transform_hooks=[track_stats], diff --git a/python/samples/02-agents/security/repo_confidentiality_example.py b/python/samples/02-agents/security/repo_confidentiality_example.py index d81bd47a18c..83a991f1844 100644 --- a/python/samples/02-agents/security/repo_confidentiality_example.py +++ b/python/samples/02-agents/security/repo_confidentiality_example.py @@ -62,7 +62,7 @@ # Simulated Repository Data # ============================================================================= -REPOSITORIES = { +REPOSITORIES: dict[str, Any] = { "public-docs": { "visibility": "public", "files": { diff --git a/python/samples/02-agents/skills/file_based_skill/file_based_skill.py b/python/samples/02-agents/skills/file_based_skill/file_based_skill.py index 35a9af9d660..43e98dcfec6 100644 --- a/python/samples/02-agents/skills/file_based_skill/file_based_skill.py +++ b/python/samples/02-agents/skills/file_based_skill/file_based_skill.py @@ -20,7 +20,7 @@ if _SKILLS_ROOT not in sys.path: sys.path.insert(0, _SKILLS_ROOT) -from subprocess_script_runner import subprocess_script_runner # noqa: E402 +from subprocess_script_runner import subprocess_script_runner # pyrefly: ignore[missing-import] # noqa: E402 """ File-Based Agent Skills diff --git a/python/samples/02-agents/skills/mixed_skills/mixed_skills.py b/python/samples/02-agents/skills/mixed_skills/mixed_skills.py index 2f89074cbd5..82a203d72ed 100644 --- a/python/samples/02-agents/skills/mixed_skills/mixed_skills.py +++ b/python/samples/02-agents/skills/mixed_skills/mixed_skills.py @@ -33,7 +33,7 @@ if _SKILLS_ROOT not in sys.path: sys.path.insert(0, _SKILLS_ROOT) -from subprocess_script_runner import subprocess_script_runner # noqa: E402 +from subprocess_script_runner import subprocess_script_runner # pyrefly: ignore[missing-import] # noqa: E402 """ Mixed Skills — Code, class, and file skills in a single agent diff --git a/python/samples/02-agents/skills/script_approval/script_approval.py b/python/samples/02-agents/skills/script_approval/script_approval.py index 8687bf68677..179b8008182 100644 --- a/python/samples/02-agents/skills/script_approval/script_approval.py +++ b/python/samples/02-agents/skills/script_approval/script_approval.py @@ -96,9 +96,11 @@ async def main() -> None: # maintained automatically — just send the approval response) while result.user_input_requests: for request in result.user_input_requests: + if request.function_call is None: + continue print("\nApproval needed:") - print(f" Function: {request.function_call.name}") # type: ignore[union-attr] - print(f" Arguments: {request.function_call.arguments}") # type: ignore[union-attr] + print(f" Function: {request.function_call.name}") + print(f" Arguments: {request.function_call.arguments}") # In a real application, prompt the user here approved = True # Change to False to see rejection diff --git a/python/samples/02-agents/skills/skill_filtering/skill_filtering.py b/python/samples/02-agents/skills/skill_filtering/skill_filtering.py index 55eea099d65..120f179e12f 100644 --- a/python/samples/02-agents/skills/skill_filtering/skill_filtering.py +++ b/python/samples/02-agents/skills/skill_filtering/skill_filtering.py @@ -21,7 +21,7 @@ if _SKILLS_ROOT not in sys.path: sys.path.insert(0, _SKILLS_ROOT) -from subprocess_script_runner import subprocess_script_runner # noqa: E402 +from subprocess_script_runner import subprocess_script_runner # pyrefly: ignore[missing-import] # noqa: E402 """ Skill Filtering — Using FilteringSkillsSource with file-based skills diff --git a/python/samples/02-agents/tools/function_tool_with_approval.py b/python/samples/02-agents/tools/function_tool_with_approval.py index 42f1da19ea5..ccb6dc504b4 100644 --- a/python/samples/02-agents/tools/function_tool_with_approval.py +++ b/python/samples/02-agents/tools/function_tool_with_approval.py @@ -58,6 +58,8 @@ async def handle_approvals(query: str, agent: "SupportsAgentRun") -> AgentRespon new_inputs: list[Any] = [query] for user_input_needed in result.user_input_requests: + if user_input_needed.function_call is None: + continue print( f"\nUser Input Request for function from {agent.name}:" f"\n Function: {user_input_needed.function_call.name}" @@ -108,6 +110,8 @@ async def handle_approvals_streaming(query: str, agent: "SupportsAgentRun") -> N new_inputs: list[Any] = [query] for user_input_needed in user_input_requests: + if user_input_needed.function_call is None: + continue print( f"\n\nUser Input Request for function from {agent.name}:" f"\n Function: {user_input_needed.function_call.name}" diff --git a/python/samples/02-agents/tools/function_tool_with_approval_and_sessions.py b/python/samples/02-agents/tools/function_tool_with_approval_and_sessions.py index 2f09420a4a4..31957d4709a 100644 --- a/python/samples/02-agents/tools/function_tool_with_approval_and_sessions.py +++ b/python/samples/02-agents/tools/function_tool_with_approval_and_sessions.py @@ -48,6 +48,8 @@ async def approval_example() -> None: # Check for approval requests if result.user_input_requests: for request in result.user_input_requests: + if request.function_call is None: + continue print("\nApproval needed:") print(f" Function: {request.function_call.name}") print(f" Arguments: {request.function_call.arguments}") @@ -82,6 +84,8 @@ async def rejection_example() -> None: if result.user_input_requests: for request in result.user_input_requests: + if request.function_call is None: + continue print("\nApproval needed:") print(f" Function: {request.function_call.name}") print(f" Arguments: {request.function_call.arguments}") diff --git a/python/samples/03-workflows/agents/azure_ai_agents_with_shared_session.py b/python/samples/03-workflows/agents/azure_ai_agents_with_shared_session.py index 3600c8ce339..8c28c254f3b 100644 --- a/python/samples/03-workflows/agents/azure_ai_agents_with_shared_session.py +++ b/python/samples/03-workflows/agents/azure_ai_agents_with_shared_session.py @@ -92,9 +92,6 @@ async def main() -> None: result = await workflow.run( "Write a tagline for a budget-friendly eBike.", - # Keyword arguments will be passed to each agent call. - # Setting store=False to avoid storing messages in the service for this example. - options={"store": False}, ) # The final state should be IDLE since the workflow no longer has messages to diff --git a/python/samples/03-workflows/agents/workflow_as_agent_human_in_the_loop.py b/python/samples/03-workflows/agents/workflow_as_agent_human_in_the_loop.py index 1622491d12a..f6e8fc3a2c4 100644 --- a/python/samples/03-workflows/agents/workflow_as_agent_human_in_the_loop.py +++ b/python/samples/03-workflows/agents/workflow_as_agent_human_in_the_loop.py @@ -29,7 +29,7 @@ handler, response_handler, ) -from workflow_as_agent_reflection_pattern import ( # noqa: E402 +from workflow_as_agent_reflection_pattern import ( # pyrefly: ignore[missing-import] # noqa: E402 ReviewRequest, ReviewResponse, Worker, diff --git a/python/samples/03-workflows/agents/workflow_as_agent_kwargs.py b/python/samples/03-workflows/agents/workflow_as_agent_kwargs.py index 7c30106cc80..fe3e1e0e396 100644 --- a/python/samples/03-workflows/agents/workflow_as_agent_kwargs.py +++ b/python/samples/03-workflows/agents/workflow_as_agent_kwargs.py @@ -140,7 +140,7 @@ async def main() -> None: print("\n===== Streaming Response =====") async for update in workflow_agent.run( "Please get my user data and then call the users API endpoint.", - additional_function_arguments={"custom_data": custom_data, "user_token": user_token}, + function_invocation_kwargs={"custom_data": custom_data, "user_token": user_token}, stream=True, ): if update.text: diff --git a/python/samples/03-workflows/composition/sub_workflow_kwargs.py b/python/samples/03-workflows/composition/sub_workflow_kwargs.py index d3991e0218e..e5a1c1d4def 100644 --- a/python/samples/03-workflows/composition/sub_workflow_kwargs.py +++ b/python/samples/03-workflows/composition/sub_workflow_kwargs.py @@ -140,8 +140,7 @@ async def main() -> None: async for event in outer_workflow.run( "Please fetch my profile data and then call the users service.", stream=True, - user_token=user_token, - service_config=service_config, + function_invocation_kwargs={"user_token": user_token, "service_config": service_config}, ): if event.type == "output": output_data = event.data diff --git a/python/samples/03-workflows/control-flow/edge_condition.py b/python/samples/03-workflows/control-flow/edge_condition.py index 8cc9d1990ea..a57da1a126c 100644 --- a/python/samples/03-workflows/control-flow/edge_condition.py +++ b/python/samples/03-workflows/control-flow/edge_condition.py @@ -14,7 +14,8 @@ WorkflowContext, # Per-run context and event bus executor, # Decorator to declare a Python function as a workflow executor ) -from agent_framework.foundry import FoundryChatClient # Thin client wrapper for Azure OpenAI chat models +from agent_framework.foundry import FoundryChatClient +from agent_framework.openai import OpenAIChatOptions # Thin client wrapper for Azure OpenAI chat models from azure.identity import AzureCliCredential # Uses your az CLI login for credentials from dotenv import load_dotenv from pydantic import BaseModel # Structured outputs for safer parsing @@ -148,7 +149,7 @@ def create_spam_detector_agent() -> Agent: "Include the original email content in email_content." ), name="spam_detection_agent", - default_options={"response_format": DetectionResult}, + default_options=OpenAIChatOptions[Any](response_format=DetectionResult), ) @@ -167,7 +168,7 @@ def create_email_assistant_agent() -> Agent: "Return JSON with a single field 'response' containing the drafted reply." ), name="email_assistant_agent", - default_options={"response_format": EmailResponse}, + default_options=OpenAIChatOptions[Any](response_format=EmailResponse), ) diff --git a/python/samples/03-workflows/control-flow/intermediate_vs_terminal_outputs.py b/python/samples/03-workflows/control-flow/intermediate_vs_terminal_outputs.py index 520c17a07bf..84d9c1e8556 100644 --- a/python/samples/03-workflows/control-flow/intermediate_vs_terminal_outputs.py +++ b/python/samples/03-workflows/control-flow/intermediate_vs_terminal_outputs.py @@ -107,7 +107,9 @@ async def main() -> None: agent = workflow.as_agent("planner-agent") response = await agent.run("life, the universe, and everything") print(f" response.text (Workflow Output only): {response.text!r}") - reasoning = " | ".join(c.text for m in response.messages for c in m.contents if c.type == "text_reasoning") + reasoning = " | ".join( + c.text for m in response.messages for c in m.contents if c.type == "text_reasoning" and c.text is not None + ) print(f" reasoning content (intermediates): {reasoning!r}") # Embed the same workflow as a node inside a larger workflow via WorkflowExecutor. diff --git a/python/samples/03-workflows/control-flow/multi_selection_edge_group.py b/python/samples/03-workflows/control-flow/multi_selection_edge_group.py index ebf222d184c..3fb684a5b50 100644 --- a/python/samples/03-workflows/control-flow/multi_selection_edge_group.py +++ b/python/samples/03-workflows/control-flow/multi_selection_edge_group.py @@ -5,7 +5,7 @@ import asyncio import os from dataclasses import dataclass -from typing import Literal +from typing import Any, Literal from uuid import uuid4 from agent_framework import ( @@ -21,6 +21,7 @@ executor, ) from agent_framework.foundry import FoundryChatClient +from agent_framework.openai import OpenAIChatOptions from azure.identity import AzureCliCredential from dotenv import load_dotenv from pydantic import BaseModel @@ -200,7 +201,7 @@ def create_email_analysis_agent() -> Agent: "and 'reason' (string)." ), name="email_analysis_agent", - default_options={"response_format": AnalysisResultAgent}, + default_options=OpenAIChatOptions[Any](response_format=AnalysisResultAgent), ) @@ -214,7 +215,7 @@ def create_email_assistant_agent() -> Agent: ), instructions=("You are an email assistant that helps users draft responses to emails with professionalism."), name="email_assistant_agent", - default_options={"response_format": EmailResponse}, + default_options=OpenAIChatOptions[Any](response_format=EmailResponse), ) @@ -228,7 +229,7 @@ def create_email_summary_agent() -> Agent: ), instructions=("You are an assistant that helps users summarize emails."), name="email_summary_agent", - default_options={"response_format": EmailSummaryModel}, + default_options=OpenAIChatOptions[Any](response_format=EmailSummaryModel), ) diff --git a/python/samples/03-workflows/control-flow/switch_case_edge_group.py b/python/samples/03-workflows/control-flow/switch_case_edge_group.py index ed183a51439..8f85e075760 100644 --- a/python/samples/03-workflows/control-flow/switch_case_edge_group.py +++ b/python/samples/03-workflows/control-flow/switch_case_edge_group.py @@ -18,7 +18,8 @@ WorkflowContext, # Per-run context and event bus executor, # Decorator to turn a function into a workflow executor ) -from agent_framework.foundry import FoundryChatClient # Thin client for Azure OpenAI chat models +from agent_framework.foundry import FoundryChatClient +from agent_framework.openai import OpenAIChatOptions # Thin client for Azure OpenAI chat models from azure.identity import AzureCliCredential # Uses your az CLI login for credentials from dotenv import load_dotenv from pydantic import BaseModel # Structured outputs with validation @@ -172,7 +173,7 @@ def create_spam_detection_agent() -> Agent: "and 'reason' (string)." ), name="spam_detection_agent", - default_options={"response_format": DetectionResultAgent}, + default_options=OpenAIChatOptions[Any](response_format=DetectionResultAgent), ) @@ -186,7 +187,7 @@ def create_email_assistant_agent() -> Agent: ), instructions=("You are an email assistant that helps users draft responses to emails with professionalism."), name="email_assistant_agent", - default_options={"response_format": EmailResponse}, + default_options=OpenAIChatOptions[Any](response_format=EmailResponse), ) diff --git a/python/samples/03-workflows/declarative/agent_to_function_tool/main.py b/python/samples/03-workflows/declarative/agent_to_function_tool/main.py index 54e393ee14b..7cbad9dbf7c 100644 --- a/python/samples/03-workflows/declarative/agent_to_function_tool/main.py +++ b/python/samples/03-workflows/declarative/agent_to_function_tool/main.py @@ -25,6 +25,7 @@ from agent_framework import Agent from agent_framework.declarative import WorkflowFactory from agent_framework.foundry import FoundryChatClient +from agent_framework.openai import OpenAIChatOptions from azure.identity import AzureCliCredential from pydantic import BaseModel, Field @@ -213,7 +214,7 @@ async def main(): client=chat_client, name="OrderAnalysisAgent", instructions=ORDER_ANALYSIS_INSTRUCTIONS, - default_options={"response_format": OrderAnalysis}, + default_options=OpenAIChatOptions[Any](response_format=OrderAnalysis), ) # Agent registry diff --git a/python/samples/03-workflows/declarative/customer_support/main.py b/python/samples/03-workflows/declarative/customer_support/main.py index d67adebf1bc..27f5dc5f094 100644 --- a/python/samples/03-workflows/declarative/customer_support/main.py +++ b/python/samples/03-workflows/declarative/customer_support/main.py @@ -26,6 +26,7 @@ import os import uuid from pathlib import Path +from typing import Any from agent_framework import Agent from agent_framework.declarative import ( @@ -34,10 +35,11 @@ WorkflowFactory, ) from agent_framework.foundry import FoundryChatClient +from agent_framework.openai import OpenAIChatOptions from azure.identity import AzureCliCredential from dotenv import load_dotenv from pydantic import BaseModel, Field -from ticketing_plugin import TicketingPlugin +from ticketing_plugin import TicketingPlugin # ty: ignore[unresolved-import] # pyrefly: ignore[missing-import] logging.basicConfig(level=logging.ERROR) @@ -182,7 +184,7 @@ async def main() -> None: client=client, name="SelfServiceAgent", instructions=SELF_SERVICE_INSTRUCTIONS, - default_options={"response_format": SelfServiceResponse}, + default_options=OpenAIChatOptions[Any](response_format=SelfServiceResponse), ) ticketing_agent = Agent( @@ -190,7 +192,7 @@ async def main() -> None: name="TicketingAgent", instructions=TICKETING_INSTRUCTIONS, tools=plugin.get_functions(), - default_options={"response_format": TicketingResponse}, + default_options=OpenAIChatOptions[Any](response_format=TicketingResponse), ) routing_agent = Agent( @@ -198,7 +200,7 @@ async def main() -> None: name="TicketRoutingAgent", instructions=TICKET_ROUTING_INSTRUCTIONS, tools=[plugin.get_ticket], - default_options={"response_format": RoutingResponse}, + default_options=OpenAIChatOptions[Any](response_format=RoutingResponse), ) windows_support_agent = Agent( @@ -206,7 +208,7 @@ async def main() -> None: name="WindowsSupportAgent", instructions=WINDOWS_SUPPORT_INSTRUCTIONS, tools=[plugin.get_ticket], - default_options={"response_format": SupportResponse}, + default_options=OpenAIChatOptions[Any](response_format=SupportResponse), ) resolution_agent = Agent( @@ -221,7 +223,7 @@ async def main() -> None: name="TicketEscalationAgent", instructions=ESCALATION_INSTRUCTIONS, tools=[plugin.get_ticket, plugin.send_notification], - default_options={"response_format": EscalationResponse}, + default_options=OpenAIChatOptions[Any](response_format=EscalationResponse), ) # Agent registry for lookup diff --git a/python/samples/03-workflows/declarative/deep_research/main.py b/python/samples/03-workflows/declarative/deep_research/main.py index 2eccb44da0a..b67a17dea0a 100644 --- a/python/samples/03-workflows/declarative/deep_research/main.py +++ b/python/samples/03-workflows/declarative/deep_research/main.py @@ -24,10 +24,12 @@ import asyncio import os from pathlib import Path +from typing import Any from agent_framework import Agent from agent_framework.declarative import WorkflowFactory from agent_framework.foundry import FoundryChatClient +from agent_framework.openai import OpenAIChatOptions from azure.identity import AzureCliCredential from dotenv import load_dotenv from pydantic import BaseModel, Field @@ -148,7 +150,7 @@ async def main() -> None: client=client, name="ManagerAgent", instructions=MANAGER_INSTRUCTIONS, - default_options={"response_format": ManagerResponse}, + default_options=OpenAIChatOptions[Any](response_format=ManagerResponse), ) summary_agent = Agent( diff --git a/python/samples/03-workflows/declarative/invoke_foundry_toolbox_mcp/main.py b/python/samples/03-workflows/declarative/invoke_foundry_toolbox_mcp/main.py index 9592919e33b..64c6a965675 100644 --- a/python/samples/03-workflows/declarative/invoke_foundry_toolbox_mcp/main.py +++ b/python/samples/03-workflows/declarative/invoke_foundry_toolbox_mcp/main.py @@ -29,7 +29,10 @@ from agent_framework.foundry import FoundryChatClient from azure.core.credentials import TokenCredential from azure.identity import AzureCliCredential, get_bearer_token_provider -from toolbox_provisioning import FOUNDRY_FEATURES_HEADERS, create_sample_toolbox +from toolbox_provisioning import ( # ty: ignore[unresolved-import] # pyrefly: ignore[missing-import] + FOUNDRY_FEATURES_HEADERS, + create_sample_toolbox, +) AGENT_NAME = "FoundryToolboxMcpAgent" TOOLBOX_NAME = "declarative_foundry_toolbox_mcp" diff --git a/python/samples/03-workflows/human-in-the-loop/guessing_game_with_human_input.py b/python/samples/03-workflows/human-in-the-loop/guessing_game_with_human_input.py index 0abba476fda..f72d9e68049 100644 --- a/python/samples/03-workflows/human-in-the-loop/guessing_game_with_human_input.py +++ b/python/samples/03-workflows/human-in-the-loop/guessing_game_with_human_input.py @@ -4,6 +4,7 @@ import os from collections.abc import AsyncIterable from dataclasses import dataclass +from typing import Any from agent_framework import ( Agent, @@ -19,6 +20,7 @@ response_handler, ) from agent_framework.foundry import FoundryChatClient +from agent_framework.openai import OpenAIChatOptions from azure.identity import AzureCliCredential from dotenv import load_dotenv from pydantic import BaseModel @@ -211,7 +213,7 @@ async def main() -> None: "No explanations or additional text." ), # response_format enforces that the model produces JSON compatible with GuessOutput. - default_options={"response_format": GuessOutput}, + default_options=OpenAIChatOptions[Any](response_format=GuessOutput), ) turn_manager = TurnManager(id="turn_manager") diff --git a/python/samples/03-workflows/state-management/state_with_agents.py b/python/samples/03-workflows/state-management/state_with_agents.py index 5a0d96aabbb..f79d92715e4 100644 --- a/python/samples/03-workflows/state-management/state_with_agents.py +++ b/python/samples/03-workflows/state-management/state_with_agents.py @@ -17,6 +17,7 @@ executor, ) from agent_framework.foundry import FoundryChatClient +from agent_framework.openai import OpenAIChatOptions from azure.identity import AzureCliCredential from dotenv import load_dotenv from pydantic import BaseModel @@ -172,7 +173,7 @@ def create_spam_detection_agent() -> Agent: "You are a spam detection assistant that identifies spam emails. " "Always return JSON with fields is_spam (bool) and reason (string)." ), - default_options={"response_format": DetectionResultAgent}, + default_options=OpenAIChatOptions[Any](response_format=DetectionResultAgent), # response_format enforces structured JSON from each agent. name="spam_detection_agent", ) @@ -191,7 +192,7 @@ def create_email_assistant_agent() -> Agent: "Return JSON with a single field 'response' containing the drafted reply." ), # response_format enforces structured JSON from each agent. - default_options={"response_format": EmailResponse}, + default_options=OpenAIChatOptions[Any](response_format=EmailResponse), name="email_assistant_agent", ) diff --git a/python/samples/04-hosting/a2a/a2a_server.py b/python/samples/04-hosting/a2a/a2a_server.py index 185d9da0481..a03eea2a288 100644 --- a/python/samples/04-hosting/a2a/a2a_server.py +++ b/python/samples/04-hosting/a2a/a2a_server.py @@ -8,7 +8,7 @@ from a2a.server.request_handlers import DefaultRequestHandler from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes from a2a.server.tasks import InMemoryTaskStore -from agent_definitions import AGENT_CARD_FACTORIES, AGENT_FACTORIES +from agent_definitions import AGENT_CARD_FACTORIES, AGENT_FACTORIES # pyrefly: ignore[missing-import] from agent_framework.a2a import A2AExecutor from agent_framework.foundry import FoundryChatClient from azure.identity import AzureCliCredential diff --git a/python/samples/04-hosting/a2a/agent_definitions.py b/python/samples/04-hosting/a2a/agent_definitions.py index 79f1f1b1ee6..32861fed77f 100644 --- a/python/samples/04-hosting/a2a/agent_definitions.py +++ b/python/samples/04-hosting/a2a/agent_definitions.py @@ -11,7 +11,7 @@ from a2a.types import AgentCapabilities, AgentCard, AgentInterface, AgentSkill from agent_framework import Agent from agent_framework.foundry import FoundryChatClient -from invoice_data import query_by_invoice_id, query_by_transaction_id, query_invoices +from invoice_data import query_by_invoice_id, query_by_transaction_id, query_invoices # pyrefly: ignore[missing-import] # --------------------------------------------------------------------------- # Agent instructions diff --git a/python/samples/04-hosting/azure_functions/03_reliable_streaming/function_app.py b/python/samples/04-hosting/azure_functions/03_reliable_streaming/function_app.py index f74f081aadb..6b2bdf0bb33 100644 --- a/python/samples/04-hosting/azure_functions/03_reliable_streaming/function_app.py +++ b/python/samples/04-hosting/azure_functions/03_reliable_streaming/function_app.py @@ -31,8 +31,8 @@ from agent_framework.foundry import FoundryChatClient from azure.identity.aio import AzureCliCredential from dotenv import load_dotenv -from redis_stream_response_handler import RedisStreamResponseHandler, StreamChunk -from tools import get_local_events, get_weather_forecast +from redis_stream_response_handler import RedisStreamResponseHandler, StreamChunk # pyrefly: ignore[missing-import] +from tools import get_local_events, get_weather_forecast # pyrefly: ignore[missing-import] # Load environment variables from .env file load_dotenv() @@ -298,8 +298,8 @@ async def _stream_to_client( def _format_chunk(chunk: StreamChunk, use_sse_format: bool) -> str: """Format a text chunk.""" if use_sse_format: - return _format_sse_event("message", chunk.text, chunk.entry_id) - return chunk.text + return _format_sse_event("message", chunk.text or "", chunk.entry_id) + return chunk.text or "" def _format_end_of_stream(entry_id: str, use_sse_format: bool) -> str: diff --git a/python/samples/04-hosting/azure_functions/07_single_agent_orchestration_hitl/function_app.py b/python/samples/04-hosting/azure_functions/07_single_agent_orchestration_hitl/function_app.py index f3e77db390f..1833277ed98 100644 --- a/python/samples/04-hosting/azure_functions/07_single_agent_orchestration_hitl/function_app.py +++ b/python/samples/04-hosting/azure_functions/07_single_agent_orchestration_hitl/function_app.py @@ -133,7 +133,7 @@ def content_generation_hitl_orchestration(context: DurableOrchestrationContext) winner = yield context.task_any([approval_task, timeout_task]) if winner == approval_task: - timeout_task.cancel() # type: ignore[attr-defined] + timeout_task.cancel() # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] approval_payload = _parse_human_approval(approval_task.result) if approval_payload.approved: diff --git a/python/samples/04-hosting/azure_functions/09_workflow_shared_state/function_app.py b/python/samples/04-hosting/azure_functions/09_workflow_shared_state/function_app.py index e4c5110ac19..d8f2a8b5931 100644 --- a/python/samples/04-hosting/azure_functions/09_workflow_shared_state/function_app.py +++ b/python/samples/04-hosting/azure_functions/09_workflow_shared_state/function_app.py @@ -35,6 +35,7 @@ executor, ) from agent_framework.foundry import FoundryChatClient +from agent_framework.openai import OpenAIChatOptions from agent_framework_azurefunctions import AgentFunctionApp from azure.identity.aio import AzureCliCredential from pydantic import BaseModel, ValidationError @@ -199,7 +200,7 @@ def _create_workflow() -> Workflow: "You are a spam detection assistant that identifies spam emails. " "Always return JSON with fields is_spam (bool) and reason (string)." ), - default_options={"response_format": DetectionResultAgent}, + default_options=OpenAIChatOptions[Any](response_format=DetectionResultAgent), name="spam_detection_agent", ) @@ -209,7 +210,7 @@ def _create_workflow() -> Workflow: "You are an email assistant that helps users draft responses to emails with professionalism. " "Return JSON with a single field 'response' containing the drafted reply." ), - default_options={"response_format": EmailResponse}, + default_options=OpenAIChatOptions[Any](response_format=EmailResponse), name="email_assistant_agent", ) diff --git a/python/samples/04-hosting/azure_functions/10_workflow_no_shared_state/function_app.py b/python/samples/04-hosting/azure_functions/10_workflow_no_shared_state/function_app.py index 435502e7539..d451bd9f6cc 100644 --- a/python/samples/04-hosting/azure_functions/10_workflow_no_shared_state/function_app.py +++ b/python/samples/04-hosting/azure_functions/10_workflow_no_shared_state/function_app.py @@ -38,6 +38,7 @@ handler, ) from agent_framework.foundry import FoundryChatClient +from agent_framework.openai import OpenAIChatOptions from agent_framework_azurefunctions import AgentFunctionApp from azure.identity.aio import AzureCliCredential from pydantic import BaseModel, ValidationError @@ -123,7 +124,7 @@ async def handle_spam_result( try: spam_result = SpamDetectionResult.model_validate_json(text) except ValidationError: - spam_result = SpamDetectionResult(is_spam=True, reason="Invalid JSON from agent") + spam_result = SpamDetectionResult(is_spam=True, reason="Invalid JSON from agent", confidence=0.0) message = f"Email marked as spam: {spam_result.reason}" await ctx.yield_output(message) @@ -170,14 +171,14 @@ def _create_workflow() -> Workflow: client=chat_client, name=SPAM_AGENT_NAME, instructions=SPAM_DETECTION_INSTRUCTIONS, - default_options={"response_format": SpamDetectionResult}, + default_options=OpenAIChatOptions[Any](response_format=SpamDetectionResult), ) email_agent = Agent( client=chat_client, name=EMAIL_AGENT_NAME, instructions=EMAIL_ASSISTANT_INSTRUCTIONS, - default_options={"response_format": EmailResponse}, + default_options=OpenAIChatOptions[Any](response_format=EmailResponse), ) # Executors diff --git a/python/samples/04-hosting/azure_functions/11_workflow_parallel/function_app.py b/python/samples/04-hosting/azure_functions/11_workflow_parallel/function_app.py index 0669d95e7b1..b093b188331 100644 --- a/python/samples/04-hosting/azure_functions/11_workflow_parallel/function_app.py +++ b/python/samples/04-hosting/azure_functions/11_workflow_parallel/function_app.py @@ -43,7 +43,7 @@ handler, ) from agent_framework.azure import AgentFunctionApp -from agent_framework.openai import OpenAIChatCompletionClient +from agent_framework.openai import OpenAIChatCompletionClient, OpenAIChatCompletionOptions from azure.identity.aio import AzureCliCredential, get_bearer_token_provider from pydantic import BaseModel from typing_extensions import Never @@ -375,7 +375,7 @@ def _create_workflow() -> Workflow: "Return JSON with fields: sentiment (positive/negative/neutral), " "confidence (0.0-1.0), and explanation (brief reasoning)." ), - default_options={"response_format": SentimentResult}, + default_options=OpenAIChatCompletionOptions[Any](response_format=SentimentResult), ) keyword_agent = Agent( @@ -386,7 +386,7 @@ def _create_workflow() -> Workflow: "from the given text. Return JSON with fields: keywords (list of strings), " "and categories (list of topic categories)." ), - default_options={"response_format": KeywordResult}, + default_options=OpenAIChatCompletionOptions[Any](response_format=KeywordResult), ) # Create summary agent for Pattern 3 (mixed parallel) @@ -398,7 +398,7 @@ def _create_workflow() -> Workflow: "provide a concise summary. Return JSON with fields: summary (brief text), " "and key_points (list of main takeaways)." ), - default_options={"response_format": SummaryResult}, + default_options=OpenAIChatCompletionOptions[Any](response_format=SummaryResult), ) # Create executor instances diff --git a/python/samples/04-hosting/azure_functions/12_workflow_hitl/function_app.py b/python/samples/04-hosting/azure_functions/12_workflow_hitl/function_app.py index e1f9389a6da..6542784b320 100644 --- a/python/samples/04-hosting/azure_functions/12_workflow_hitl/function_app.py +++ b/python/samples/04-hosting/azure_functions/12_workflow_hitl/function_app.py @@ -42,6 +42,7 @@ response_handler, ) from agent_framework.foundry import FoundryChatClient +from agent_framework.openai import OpenAIChatOptions from agent_framework_azurefunctions import AgentFunctionApp from azure.identity.aio import AzureCliCredential from pydantic import BaseModel, ValidationError @@ -379,7 +380,7 @@ def _create_workflow() -> Workflow: client=chat_client, name=CONTENT_ANALYZER_AGENT_NAME, instructions=CONTENT_ANALYZER_INSTRUCTIONS, - default_options={"response_format": ContentAnalysisResult}, + default_options=OpenAIChatOptions[Any](response_format=ContentAnalysisResult), ) # Create executors diff --git a/python/samples/04-hosting/durabletask/01_single_agent/sample.py b/python/samples/04-hosting/durabletask/01_single_agent/sample.py index 86a74c73e29..7ba2727dd86 100644 --- a/python/samples/04-hosting/durabletask/01_single_agent/sample.py +++ b/python/samples/04-hosting/durabletask/01_single_agent/sample.py @@ -15,9 +15,9 @@ import logging # Import helper functions from worker and client modules -from client import get_client, run_client +from client import get_client, run_client # pyrefly: ignore[missing-import] from dotenv import load_dotenv -from worker import get_worker, setup_worker +from worker import get_worker, setup_worker # pyrefly: ignore[missing-import] # Configure logging (must be after imports to override their basicConfig) logging.basicConfig(level=logging.INFO, force=True) diff --git a/python/samples/04-hosting/durabletask/02_multi_agent/sample.py b/python/samples/04-hosting/durabletask/02_multi_agent/sample.py index e847abbd645..c5128771141 100644 --- a/python/samples/04-hosting/durabletask/02_multi_agent/sample.py +++ b/python/samples/04-hosting/durabletask/02_multi_agent/sample.py @@ -15,9 +15,9 @@ import logging # Import helper functions from worker and client modules -from client import get_client, run_client +from client import get_client, run_client # pyrefly: ignore[missing-import] from dotenv import load_dotenv -from worker import get_worker, setup_worker +from worker import get_worker, setup_worker # pyrefly: ignore[missing-import] # Configure logging logging.basicConfig(level=logging.INFO, force=True) diff --git a/python/samples/04-hosting/durabletask/03_single_agent_streaming/client.py b/python/samples/04-hosting/durabletask/03_single_agent_streaming/client.py index caded2a17fd..2232ba3cc6d 100644 --- a/python/samples/04-hosting/durabletask/03_single_agent_streaming/client.py +++ b/python/samples/04-hosting/durabletask/03_single_agent_streaming/client.py @@ -24,7 +24,7 @@ from azure.identity import AzureCliCredential from dotenv import load_dotenv from durabletask.azuremanaged.client import DurableTaskSchedulerClient -from redis_stream_response_handler import RedisStreamResponseHandler +from redis_stream_response_handler import RedisStreamResponseHandler # pyrefly: ignore[missing-import] # Load environment variables from .env file load_dotenv() diff --git a/python/samples/04-hosting/durabletask/03_single_agent_streaming/sample.py b/python/samples/04-hosting/durabletask/03_single_agent_streaming/sample.py index 0c6865b0125..d547bf3e4d9 100644 --- a/python/samples/04-hosting/durabletask/03_single_agent_streaming/sample.py +++ b/python/samples/04-hosting/durabletask/03_single_agent_streaming/sample.py @@ -17,9 +17,9 @@ import logging # Import helper functions from worker and client modules -from client import get_client, run_client +from client import get_client, run_client # pyrefly: ignore[missing-import] from dotenv import load_dotenv -from worker import get_worker, setup_worker +from worker import get_worker, setup_worker # pyrefly: ignore[missing-import] # Configure logging (must be after imports to override their basicConfig) logging.basicConfig(level=logging.INFO, force=True) diff --git a/python/samples/04-hosting/durabletask/03_single_agent_streaming/worker.py b/python/samples/04-hosting/durabletask/03_single_agent_streaming/worker.py index fdb4e1b0c1d..b07d1318f3e 100644 --- a/python/samples/04-hosting/durabletask/03_single_agent_streaming/worker.py +++ b/python/samples/04-hosting/durabletask/03_single_agent_streaming/worker.py @@ -29,8 +29,8 @@ from azure.identity.aio import AzureCliCredential as AsyncAzureCliCredential from dotenv import load_dotenv from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker -from redis_stream_response_handler import RedisStreamResponseHandler -from tools import get_local_events, get_weather_forecast +from redis_stream_response_handler import RedisStreamResponseHandler # pyrefly: ignore[missing-import] +from tools import get_local_events, get_weather_forecast # pyrefly: ignore[missing-import] # Load environment variables from .env file load_dotenv() diff --git a/python/samples/04-hosting/durabletask/04_single_agent_orchestration_chaining/sample.py b/python/samples/04-hosting/durabletask/04_single_agent_orchestration_chaining/sample.py index 529052315da..3afcc454bea 100644 --- a/python/samples/04-hosting/durabletask/04_single_agent_orchestration_chaining/sample.py +++ b/python/samples/04-hosting/durabletask/04_single_agent_orchestration_chaining/sample.py @@ -20,9 +20,9 @@ import logging # Import helper functions from worker and client modules -from client import get_client, run_client +from client import get_client, run_client # pyrefly: ignore[missing-import] from dotenv import load_dotenv -from worker import get_worker, setup_worker +from worker import get_worker, setup_worker # pyrefly: ignore[missing-import] # Configure logging logging.basicConfig(level=logging.INFO, force=True) diff --git a/python/samples/04-hosting/durabletask/05_multi_agent_orchestration_concurrency/sample.py b/python/samples/04-hosting/durabletask/05_multi_agent_orchestration_concurrency/sample.py index 666c967bc72..bd482a10159 100644 --- a/python/samples/04-hosting/durabletask/05_multi_agent_orchestration_concurrency/sample.py +++ b/python/samples/04-hosting/durabletask/05_multi_agent_orchestration_concurrency/sample.py @@ -17,9 +17,9 @@ import logging # Import helper functions from worker and client modules -from client import get_client, run_client +from client import get_client, run_client # pyrefly: ignore[missing-import] from dotenv import load_dotenv -from worker import get_worker, setup_worker +from worker import get_worker, setup_worker # pyrefly: ignore[missing-import] # Configure logging logging.basicConfig(level=logging.INFO, force=True) diff --git a/python/samples/04-hosting/durabletask/06_multi_agent_orchestration_conditionals/sample.py b/python/samples/04-hosting/durabletask/06_multi_agent_orchestration_conditionals/sample.py index f243b2ccb7b..1a493f2aac4 100644 --- a/python/samples/04-hosting/durabletask/06_multi_agent_orchestration_conditionals/sample.py +++ b/python/samples/04-hosting/durabletask/06_multi_agent_orchestration_conditionals/sample.py @@ -21,9 +21,9 @@ import logging # Import helper functions from worker and client modules -from client import get_client, run_client +from client import get_client, run_client # pyrefly: ignore[missing-import] from dotenv import load_dotenv -from worker import get_worker, setup_worker +from worker import get_worker, setup_worker # pyrefly: ignore[missing-import] logging.basicConfig(level=logging.INFO, force=True) logger = logging.getLogger() diff --git a/python/samples/04-hosting/durabletask/07_single_agent_orchestration_hitl/sample.py b/python/samples/04-hosting/durabletask/07_single_agent_orchestration_hitl/sample.py index c99daf06f96..454897a3570 100644 --- a/python/samples/04-hosting/durabletask/07_single_agent_orchestration_hitl/sample.py +++ b/python/samples/04-hosting/durabletask/07_single_agent_orchestration_hitl/sample.py @@ -18,9 +18,9 @@ import logging # Import helper functions from worker and client modules -from client import get_client, run_interactive_client +from client import get_client, run_interactive_client # pyrefly: ignore[missing-import] from dotenv import load_dotenv -from worker import get_worker, setup_worker +from worker import get_worker, setup_worker # pyrefly: ignore[missing-import] logging.basicConfig(level=logging.INFO, force=True) logger = logging.getLogger() diff --git a/python/samples/04-hosting/foundry-hosted-agents/responses/using_deployed_agent.py b/python/samples/04-hosting/foundry-hosted-agents/responses/using_deployed_agent.py index 9d1d50b959e..35bd1745e0f 100644 --- a/python/samples/04-hosting/foundry-hosted-agents/responses/using_deployed_agent.py +++ b/python/samples/04-hosting/foundry-hosted-agents/responses/using_deployed_agent.py @@ -11,7 +11,7 @@ from agent_framework.foundry import FoundryAgent from azure.ai.projects.aio import AIProjectClient from azure.ai.projects.models import VersionRefIndicator -from azure.identity import AzureCliCredential +from azure.identity.aio import AzureCliCredential from dotenv import load_dotenv load_dotenv() diff --git a/python/scripts/task_runner.py b/python/scripts/task_runner.py index 8bd038359a1..a1331803cc4 100644 --- a/python/scripts/task_runner.py +++ b/python/scripts/task_runner.py @@ -226,3 +226,75 @@ def run_tasks( _run_sequential(work_items, task_args) else: _run_parallel(work_items, workspace_root, task_args) + + +def _run_command_subprocess( + label: str, + command: Sequence[str], + workspace_root: Path, +) -> tuple[str, int, str, str, float]: + """Run a single labelled command in ``workspace_root`` and capture its output.""" + start = time.monotonic() + result = subprocess.run(command, cwd=workspace_root, capture_output=True, text=True) + elapsed = time.monotonic() - start + return (label, result.returncode, result.stdout, result.stderr, elapsed) + + +def run_command_items( + command_items: list[tuple[str, Sequence[str]]], + workspace_root: Path, + *, + sequential: bool = False, +) -> None: + """Run labelled commands using the same model as :func:`run_tasks`. + + A single command streams its output live; multiple commands run in parallel + subprocesses with captured output and a ``✓``/``✗`` summary, mirroring the + pyright fan-out presentation. + """ + if not command_items: + print("[yellow]No commands to run[/yellow]") + return + + if sequential or len(command_items) == 1: + for label, command in command_items: + print(f"[cyan]>> {label}[/cyan]") + result = subprocess.run(command, cwd=workspace_root) + if result.returncode: + sys.exit(result.returncode) + return + + max_workers = min(len(command_items), os.cpu_count() or 4) + failures: list[tuple[str, str, str]] = [] + completed = 0 + total = len(command_items) + + print(f"[cyan]Running {total} task(s) in parallel (max {max_workers} workers)...[/cyan]") + + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = { + executor.submit(_run_command_subprocess, label, command, workspace_root): label + for label, command in command_items + } + for future in concurrent.futures.as_completed(futures): + label, returncode, stdout, stderr, elapsed = future.result() + completed += 1 + progress = f"[{completed}/{total}]" + if returncode == 0: + print(f" [green]✓[/green] {progress} {label} ({elapsed:.1f}s)") + else: + print(f" [red]✗[/red] {progress} {label} ({elapsed:.1f}s)") + failures.append((label, stdout, stderr)) + + if failures: + print(f"\n[red]{len(failures)} task(s) failed:[/red]") + for label, stdout, stderr in failures: + print(f"\n[red]{'=' * 60}[/red]") + print(f"[red]FAILED: {label}[/red]") + if stdout.strip(): + print(stdout) + if stderr.strip(): + sys.stderr.write(stderr) + sys.exit(1) + + print(f"\n[green]All {total} task(s) passed ✓[/green]") diff --git a/python/scripts/workspace_poe_tasks.py b/python/scripts/workspace_poe_tasks.py index d9e55e001d3..4d98401c456 100644 --- a/python/scripts/workspace_poe_tasks.py +++ b/python/scripts/workspace_poe_tasks.py @@ -10,7 +10,6 @@ from __future__ import annotations import argparse -import os import subprocess import sys from dataclasses import dataclass @@ -20,7 +19,13 @@ from packaging.specifiers import SpecifierSet from packaging.version import Version from rich import print -from task_runner import build_work_items, discover_projects, project_filter_matches, run_tasks +from task_runner import ( + build_work_items, + discover_projects, + project_filter_matches, + run_command_items, + run_tasks, +) WORKSPACE_ROOT = Path(__file__).resolve().parent.parent WORKSPACE_PYPROJECT = WORKSPACE_ROOT / "pyproject.toml" @@ -99,6 +104,11 @@ def add_cov_option(command: argparse.ArgumentParser) -> None: add_project_option(mypy) add_all_option(mypy) + test_typing = subparsers.add_parser("test-typing") + add_project_option(test_typing) + add_all_option(test_typing) + add_samples_option(test_typing) + typing = subparsers.add_parser("typing") add_project_option(typing) add_all_option(typing) @@ -116,6 +126,7 @@ def add_cov_option(command: argparse.ArgumentParser) -> None: prek_check.add_argument("files", nargs="*", default=["."], help="Files reported by pre-commit.") subparsers.add_parser("ci-mypy") + subparsers.add_parser("ci-test-typing") return parser.parse_known_args(argv) @@ -312,19 +323,170 @@ def run_aggregate_pyright(project_pattern: str, extra_args: list[str]) -> None: run_command(["uv", "run", "pyright", *extra_args, *project_paths]) -def run_aggregate_mypy(project_pattern: str, extra_args: list[str]) -> None: - """Run a single MyPy sweep across the selected project import roots.""" +# Type checkers that run over tests (and, where supported, samples). Pyright is the strict +# SOURCE-code checker, and ALSO runs over tests + samples in a relaxed ``basic`` profile +# (see pyrightconfig.tests.json / pyrightconfig.samples.json). mypy/pyrefly/ty/zuban +# exercise the public API the way users do. All five run by default and gate CI. ``zuban`` +# is the strictest of the mypy-compatible pair and only runs on tests (samples are +# script-style and unsupported by mypy/zuban -- see SAMPLE_TYPING_CHECKERS). +GATING_TEST_TYPING_CHECKERS = ("mypy", "pyrefly", "ty", "zuban", "pyright") + + +def _gating_checker_args() -> list[str]: + """Build ``--checker`` selectors for the CI-gating test/sample checkers.""" + args: list[str] = [] + for checker in GATING_TEST_TYPING_CHECKERS: + args.extend(["--checker", checker]) + return args + + +# Samples that are intentionally excluded from type checking (migrations, generated, etc.). +SAMPLE_TYPING_EXCLUDES = ( + "autogen-migration", + "semantic-kernel-migration", + "autogen", + "demos", + "_to_delete", + "05-end-to-end", + "harness", +) + + +def _mypy_command(paths: list[str], *, samples: bool) -> list[str]: + command = [ + "uv", + "run", + "mypy", + "--config-file", + "pyproject.toml", + "--explicit-package-bases", + "--namespace-packages", + ] + if samples: + for excluded in SAMPLE_TYPING_EXCLUDES: + command.extend(["--exclude", excluded]) + command.extend(paths) + return command + + +def _zuban_command(paths: list[str], *, samples: bool) -> list[str]: + command = ["uv", "run", "zuban", "mypy", "--config-file", "pyproject.toml"] + if samples: + for excluded in SAMPLE_TYPING_EXCLUDES: + command.extend(["--exclude", excluded]) + command.extend(paths) + return command + + +def _pyrefly_command(paths: list[str], *, samples: bool) -> list[str]: + config = "pyrefly.samples.toml" if samples else "pyrefly.toml" + command = ["uv", "run", "pyrefly", "check", "-c", config] + if samples: + for excluded in SAMPLE_TYPING_EXCLUDES: + command.extend(["--project-excludes", f"**/{excluded}/**"]) + command.extend(paths) + return command + + +def _ty_command(paths: list[str], *, samples: bool) -> list[str]: + command = ["uv", "run", "ty", "check"] + if samples: + command.extend(["--config-file", "ty.samples.toml"]) + for excluded in SAMPLE_TYPING_EXCLUDES: + command.extend(["--exclude", f"**/{excluded}/**"]) + command.extend(paths) + return command + + +def _pyright_command(paths: list[str], *, samples: bool) -> list[str]: + # Pyright owns source in strict mode; over tests + samples it runs in a relaxed + # ``basic`` profile via a dedicated config (see pyrightconfig.tests.json / + # pyrightconfig.samples.json). CLI paths override the config ``include``; the sample + # excludes live in the config itself (Pyright has no ``--exclude`` CLI flag). + config = sample_pyright_config() if samples else "pyrightconfig.tests.json" + return ["uv", "run", "pyright", "-p", config, *paths] + + +CHECKER_COMMANDS = { + "mypy": _mypy_command, + "zuban": _zuban_command, + "pyrefly": _pyrefly_command, + "ty": _ty_command, + "pyright": _pyright_command, +} + + +def _resolve_checkers(extra_args: list[str]) -> tuple[list[str], list[str]]: + """Split a ``--checker NAME`` selector (repeatable) from pass-through args. + + With no explicit ``--checker``, all gating checkers (mypy, pyrefly, ty, zuban, + pyright) run. A single checker can be isolated via e.g. ``--checker pyright``. + """ + checkers: list[str] = [] + passthrough: list[str] = [] + index = 0 + while index < len(extra_args): + argument = extra_args[index] + if argument == "--checker" and index + 1 < len(extra_args): + checkers.append(extra_args[index + 1]) + index += 2 + continue + passthrough.append(argument) + index += 1 + selected = checkers or list(GATING_TEST_TYPING_CHECKERS) + unknown = [name for name in selected if name not in CHECKER_COMMANDS] + if unknown: + print(f"[red]Unknown checker(s): {', '.join(unknown)}.[/red]") + raise SystemExit(2) + return selected, passthrough + + +def run_test_typing(project_pattern: str, extra_args: list[str]) -> None: + """Run the test-suite type checkers (mypy, pyrefly, ty, zuban, pyright) per package. + + Each (package, checker) pair is fanned out in parallel via the same executor + used by the pyright source fan-out, so the presentation and parallelism match. + """ + checkers, passthrough = _resolve_checkers(extra_args) projects = select_projects(project_pattern) if not projects: print("[yellow]No selected projects support the current Python version, skipping.[/yellow]") return - source_dirs = [relative_path(path) for path in collect_source_dirs(projects)] - if not source_dirs: - print("[yellow]No import roots found for the selected projects, skipping MyPy.[/yellow]") + command_items: list[tuple[str, list[str]]] = [] + for project in projects: + test_dirs = collect_test_dirs([project]) + if not test_dirs: + continue + paths = [relative_path(path) for path in test_dirs] + for checker in checkers: + command = CHECKER_COMMANDS[checker]([*paths, *passthrough], samples=False) + command_items.append((f"{checker} :: {project.name}", command)) + + run_command_items(command_items, WORKSPACE_ROOT) + + +# Checkers that work on the script-style samples tree. MyPy/zuban are excluded because +# samples are standalone scripts (numeric-prefixed dirs, duplicate filenames like main.py) +# that cannot be resolved into a module package tree without per-file invocation. Pyright +# handles the script-style tree fine and runs in the relaxed ``basic`` samples profile. +SAMPLE_TYPING_CHECKERS = ("pyrefly", "ty", "pyright") + + +def run_sample_typing(extra_args: list[str]) -> None: + """Run the sample-capable type checkers over samples/ in the relaxed/basic profile.""" + checkers, passthrough = _resolve_checkers(extra_args) + sample_checkers = [checker for checker in checkers if checker in SAMPLE_TYPING_CHECKERS] + skipped = [checker for checker in checkers if checker not in SAMPLE_TYPING_CHECKERS] + if skipped: + print(f"[yellow]Skipping {', '.join(skipped)} for samples (script-style tree is unsupported).[/yellow]") + if not sample_checkers: return - - run_command(["uv", "run", "mypy", "--config-file", "pyproject.toml", *extra_args, *source_dirs]) + command_items = [ + (f"{checker} :: samples", CHECKER_COMMANDS[checker](["samples", *passthrough], samples=True)) + for checker in sample_checkers + ] + run_command_items(command_items, WORKSPACE_ROOT) def run_aggregate_test(project_pattern: str, cov: bool, extra_args: list[str]) -> None: @@ -417,58 +579,10 @@ def run_prek_check(files: list[str]) -> None: print("[yellow]No sample files changed, skipping sample checks.[/yellow]") -def git_diff_name_only(*revisions: str) -> list[str] | None: - """Try a git diff strategy and return changed files if it succeeds.""" - result = subprocess.run( - ["git", "diff", "--name-only", *revisions, "--", "."], - cwd=WORKSPACE_ROOT, - capture_output=True, - text=True, - check=False, - ) - if result.returncode != 0: - return None - return [line for line in result.stdout.splitlines() if line] - - -def detect_ci_changed_files() -> list[str]: - """Detect changed files for change-based mypy runs.""" - base_ref = os.environ.get("GITHUB_BASE_REF") - if base_ref: - subprocess.run( - ["git", "fetch", "origin", base_ref, "--depth=1"], - cwd=WORKSPACE_ROOT, - capture_output=True, - text=True, - check=False, - ) - strategies = [ - (f"origin/{base_ref}...HEAD",), - ("FETCH_HEAD...HEAD",), - ("HEAD^...HEAD",), - ] - else: - strategies = [ - ("origin/main...HEAD",), - ("main...HEAD",), - ("HEAD~1",), - ] - - for strategy in strategies: - changed_files = git_diff_name_only(*strategy) - if changed_files is not None: - return changed_files or ["."] - - return ["."] - - -def run_ci_mypy() -> None: - """Run MyPy only where changes require it, mirroring CI behaviour.""" - changed_files = detect_ci_changed_files() - print("[cyan]Changed files for CI mypy:[/cyan]") - for file_path in changed_files: - print(f" {file_path}") - run_changed_package_tasks(["mypy"], changed_files) +def run_ci_test_typing() -> None: + """Run the gating test/sample type checkers across the workspace, mirroring CI.""" + run_test_typing("*", _gating_checker_args()) + run_sample_typing(_gating_checker_args()) def ensure_no_extra_args(command_name: str, extra_args: list[str]) -> None: @@ -594,23 +708,29 @@ def main() -> None: return if args.command == "mypy": - if args.all: - run_aggregate_mypy(args.project, extra_args) + # MyPy no longer runs on source code (Pyright owns source). The ``mypy`` task is a + # convenience alias that runs MyPy over the test suite. + run_test_typing(args.project, ["--checker", "mypy", *extra_args]) + return + + if args.command == "test-typing": + if args.samples: + if args.all or args.project != "*": + print("[red]--samples cannot be combined with --all or --package.[/red]") + raise SystemExit(2) + run_sample_typing(extra_args) return - run_fan_out(["mypy"], args.project, extra_args) + run_test_typing(args.project, extra_args) return if args.command == "typing": ensure_no_extra_args(args.command, extra_args) + # Pyright over source, then the multi-checker sweep over the tests. if args.all: - # Start MyPy first so combined typing runs follow the requested - # ordering even though completion still depends on runtime duration. - run_aggregate_mypy(args.project, []) run_aggregate_pyright(args.project, []) - return - # Preserve the same "MyPy first" ordering for the per-package fan-out - # path as well. - run_fan_out(["mypy", "pyright"], args.project, []) + else: + run_fan_out(["pyright"], args.project, []) + run_test_typing(args.project, []) return if args.command == "test": @@ -655,7 +775,7 @@ def main() -> None: check_selected=False, extra_args=[], ) - run_sample_pyright([]) + run_sample_typing(_gating_checker_args()) return run_syntax( project_pattern=args.project, @@ -665,6 +785,7 @@ def main() -> None: extra_args=[], ) run_fan_out(["pyright"], args.project, []) + run_test_typing(args.project, _gating_checker_args()) run_fan_out(["test"], args.project, []) # Sample validation and markdown lint are intentionally workspace-wide; # a package-scoped check should stay focused on the selected package set. @@ -676,7 +797,7 @@ def main() -> None: check_selected=False, extra_args=[], ) - run_sample_pyright([]) + run_sample_typing(_gating_checker_args()) run_markdown_code_lint() return @@ -685,9 +806,9 @@ def main() -> None: run_prek_check(args.files) return - if args.command == "ci-mypy": + if args.command in ("ci-mypy", "ci-test-typing"): ensure_no_extra_args(args.command, extra_args) - run_ci_mypy() + run_ci_test_typing() return print(f"[red]Unsupported command: {args.command}[/red]") diff --git a/python/tests/samples/getting_started/test_agent_samples.py b/python/tests/samples/getting_started/test_agent_samples.py index e310521b10d..f9ffb5f0077 100644 --- a/python/tests/samples/getting_started/test_agent_samples.py +++ b/python/tests/samples/getting_started/test_agent_samples.py @@ -7,142 +7,142 @@ import pytest from pytest import MonkeyPatch, mark, param -from samples.getting_started.agents.azure_ai.azure_ai_basic import ( +from samples.getting_started.agents.azure_ai.azure_ai_basic import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as azure_ai_basic, ) -from samples.getting_started.agents.azure_ai.azure_ai_with_code_interpreter import ( +from samples.getting_started.agents.azure_ai.azure_ai_with_code_interpreter import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as azure_ai_with_code_interpreter, ) -from samples.getting_started.agents.azure_ai.azure_ai_with_existing_agent import ( +from samples.getting_started.agents.azure_ai.azure_ai_with_existing_agent import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as azure_ai_with_existing_agent, ) -from samples.getting_started.agents.azure_ai.azure_ai_with_explicit_settings import ( +from samples.getting_started.agents.azure_ai.azure_ai_with_explicit_settings import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as azure_ai_with_explicit_settings, ) -from samples.getting_started.agents.azure_ai.azure_ai_with_function_tools import ( +from samples.getting_started.agents.azure_ai.azure_ai_with_function_tools import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] mixed_tools_example as azure_ai_with_function_tools_mixed, ) -from samples.getting_started.agents.azure_ai.azure_ai_with_function_tools import ( +from samples.getting_started.agents.azure_ai.azure_ai_with_function_tools import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] tools_on_agent_level as azure_ai_with_function_tools_agent, ) -from samples.getting_started.agents.azure_ai.azure_ai_with_function_tools import ( +from samples.getting_started.agents.azure_ai.azure_ai_with_function_tools import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] tools_on_run_level as azure_ai_with_function_tools_run, ) -from samples.getting_started.agents.azure_ai.azure_ai_with_local_mcp import ( +from samples.getting_started.agents.azure_ai.azure_ai_with_local_mcp import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as azure_ai_with_local_mcp, ) -from samples.getting_started.agents.azure_ai.azure_ai_with_thread import ( +from samples.getting_started.agents.azure_ai.azure_ai_with_thread import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as azure_ai_with_thread, ) -from samples.getting_started.agents.azure_openai.azure_assistants_basic import ( +from samples.getting_started.agents.azure_openai.azure_assistants_basic import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as azure_assistants_basic, ) -from samples.getting_started.agents.azure_openai.azure_assistants_with_code_interpreter import ( +from samples.getting_started.agents.azure_openai.azure_assistants_with_code_interpreter import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as azure_assistants_with_code_interpreter, ) -from samples.getting_started.agents.azure_openai.azure_assistants_with_existing_assistant import ( +from samples.getting_started.agents.azure_openai.azure_assistants_with_existing_assistant import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as azure_assistants_with_existing_assistant, ) -from samples.getting_started.agents.azure_openai.azure_assistants_with_explicit_settings import ( +from samples.getting_started.agents.azure_openai.azure_assistants_with_explicit_settings import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as azure_assistants_with_explicit_settings, ) -from samples.getting_started.agents.azure_openai.azure_assistants_with_function_tools import ( +from samples.getting_started.agents.azure_openai.azure_assistants_with_function_tools import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as azure_assistants_with_function_tools, ) -from samples.getting_started.agents.azure_openai.azure_assistants_with_thread import ( +from samples.getting_started.agents.azure_openai.azure_assistants_with_thread import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as azure_assistants_with_thread, ) -from samples.getting_started.agents.azure_openai.azure_chat_client_basic import ( +from samples.getting_started.agents.azure_openai.azure_chat_client_basic import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as azure_chat_client_basic, ) -from samples.getting_started.agents.azure_openai.azure_chat_client_with_explicit_settings import ( +from samples.getting_started.agents.azure_openai.azure_chat_client_with_explicit_settings import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as azure_chat_client_with_explicit_settings, ) -from samples.getting_started.agents.azure_openai.azure_chat_client_with_function_tools import ( +from samples.getting_started.agents.azure_openai.azure_chat_client_with_function_tools import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as azure_chat_client_with_function_tools, ) -from samples.getting_started.agents.azure_openai.azure_chat_client_with_thread import ( +from samples.getting_started.agents.azure_openai.azure_chat_client_with_thread import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as azure_chat_client_with_thread, ) -from samples.getting_started.agents.azure_openai.azure_responses_client_basic import ( +from samples.getting_started.agents.azure_openai.azure_responses_client_basic import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as azure_responses_client_basic, ) -from samples.getting_started.agents.azure_openai.azure_responses_client_with_code_interpreter import ( +from samples.getting_started.agents.azure_openai.azure_responses_client_with_code_interpreter import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as azure_responses_client_with_code_interpreter, ) -from samples.getting_started.agents.azure_openai.azure_responses_client_with_explicit_settings import ( +from samples.getting_started.agents.azure_openai.azure_responses_client_with_explicit_settings import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as azure_responses_client_with_explicit_settings, ) -from samples.getting_started.agents.azure_openai.azure_responses_client_with_function_tools import ( +from samples.getting_started.agents.azure_openai.azure_responses_client_with_function_tools import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as azure_responses_client_with_function_tools, ) -from samples.getting_started.agents.azure_openai.azure_responses_client_with_thread import ( +from samples.getting_started.agents.azure_openai.azure_responses_client_with_thread import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as azure_responses_client_with_thread, ) -from samples.getting_started.agents.openai.openai_assistants_basic import ( +from samples.getting_started.agents.openai.openai_assistants_basic import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as openai_assistants_basic, ) -from samples.getting_started.agents.openai.openai_assistants_with_code_interpreter import ( +from samples.getting_started.agents.openai.openai_assistants_with_code_interpreter import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as openai_assistants_with_code_interpreter, ) -from samples.getting_started.agents.openai.openai_assistants_with_existing_assistant import ( +from samples.getting_started.agents.openai.openai_assistants_with_existing_assistant import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as openai_assistants_with_existing_assistant, ) -from samples.getting_started.agents.openai.openai_assistants_with_explicit_settings import ( +from samples.getting_started.agents.openai.openai_assistants_with_explicit_settings import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as openai_assistants_with_explicit_settings, ) -from samples.getting_started.agents.openai.openai_assistants_with_file_search import ( +from samples.getting_started.agents.openai.openai_assistants_with_file_search import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as openai_assistants_with_file_search, ) -from samples.getting_started.agents.openai.openai_assistants_with_function_tools import ( +from samples.getting_started.agents.openai.openai_assistants_with_function_tools import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as openai_assistants_with_function_tools, ) -from samples.getting_started.agents.openai.openai_assistants_with_thread import ( +from samples.getting_started.agents.openai.openai_assistants_with_thread import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as openai_assistants_with_thread, ) -from samples.getting_started.agents.openai.openai_chat_client_basic import ( +from samples.getting_started.agents.openai.openai_chat_client_basic import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as openai_chat_client_basic, ) -from samples.getting_started.agents.openai.openai_chat_client_with_explicit_settings import ( +from samples.getting_started.agents.openai.openai_chat_client_with_explicit_settings import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as openai_chat_client_with_explicit_settings, ) -from samples.getting_started.agents.openai.openai_chat_client_with_function_tools import ( +from samples.getting_started.agents.openai.openai_chat_client_with_function_tools import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as openai_chat_client_with_function_tools, ) -from samples.getting_started.agents.openai.openai_chat_client_with_local_mcp import ( +from samples.getting_started.agents.openai.openai_chat_client_with_local_mcp import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as openai_chat_client_with_local_mcp, ) -from samples.getting_started.agents.openai.openai_chat_client_with_thread import ( +from samples.getting_started.agents.openai.openai_chat_client_with_thread import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as openai_chat_client_with_thread, ) -from samples.getting_started.agents.openai.openai_chat_client_with_web_search import ( +from samples.getting_started.agents.openai.openai_chat_client_with_web_search import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as openai_chat_client_with_web_search, ) -from samples.getting_started.agents.openai.openai_responses_client_basic import ( +from samples.getting_started.agents.openai.openai_responses_client_basic import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as openai_responses_client_basic, ) -from samples.getting_started.agents.openai.openai_responses_client_reasoning import ( +from samples.getting_started.agents.openai.openai_responses_client_reasoning import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as openai_responses_client_reasoning, ) -from samples.getting_started.agents.openai.openai_responses_client_with_code_interpreter import ( +from samples.getting_started.agents.openai.openai_responses_client_with_code_interpreter import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as openai_responses_client_with_code_interpreter, ) -from samples.getting_started.agents.openai.openai_responses_client_with_explicit_settings import ( +from samples.getting_started.agents.openai.openai_responses_client_with_explicit_settings import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as openai_responses_client_with_explicit_settings, ) -from samples.getting_started.agents.openai.openai_responses_client_with_file_search import ( +from samples.getting_started.agents.openai.openai_responses_client_with_file_search import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as openai_responses_client_with_file_search, ) -from samples.getting_started.agents.openai.openai_responses_client_with_function_tools import ( +from samples.getting_started.agents.openai.openai_responses_client_with_function_tools import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as openai_responses_client_with_function_tools, ) -from samples.getting_started.agents.openai.openai_responses_client_with_local_mcp import ( +from samples.getting_started.agents.openai.openai_responses_client_with_local_mcp import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as openai_responses_client_with_local_mcp, ) -from samples.getting_started.agents.openai.openai_responses_client_with_thread import ( +from samples.getting_started.agents.openai.openai_responses_client_with_thread import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as openai_responses_client_with_thread, ) -from samples.getting_started.agents.openai.openai_responses_client_with_web_search import ( +from samples.getting_started.agents.openai.openai_responses_client_with_web_search import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as openai_responses_client_with_web_search, ) @@ -591,4 +591,4 @@ def mock_input(prompt: str = "") -> str: return responses.pop(0) if responses else "exit" monkeypatch.setattr("builtins.input", mock_input) - await sample + await sample() diff --git a/python/tests/samples/getting_started/test_chat_client_samples.py b/python/tests/samples/getting_started/test_chat_client_samples.py index b145ba84e00..008ab7c7145 100644 --- a/python/tests/samples/getting_started/test_chat_client_samples.py +++ b/python/tests/samples/getting_started/test_chat_client_samples.py @@ -7,28 +7,28 @@ import pytest from pytest import MonkeyPatch, mark, param -from samples.getting_started.client.azure_ai_chat_client import ( +from samples.getting_started.client.azure_ai_chat_client import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as azure_ai_chat_client, ) -from samples.getting_started.client.azure_assistants_client import ( +from samples.getting_started.client.azure_assistants_client import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as azure_assistants_client, ) -from samples.getting_started.client.azure_chat_client import ( +from samples.getting_started.client.azure_chat_client import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as azure_chat_client, ) -from samples.getting_started.client.azure_responses_client import ( +from samples.getting_started.client.azure_responses_client import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as azure_responses_client, ) -from samples.getting_started.client.chat_response_cancellation import ( +from samples.getting_started.client.chat_response_cancellation import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as chat_response_cancellation, ) -from samples.getting_started.client.openai_assistants_client import ( +from samples.getting_started.client.openai_assistants_client import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as openai_assistants_client, ) -from samples.getting_started.client.openai_chat_client import ( +from samples.getting_started.client.openai_chat_client import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as openai_chat_client, ) -from samples.getting_started.client.openai_responses_client import ( +from samples.getting_started.client.openai_responses_client import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] main as openai_responses_client, ) @@ -130,4 +130,4 @@ def mock_input(prompt: str = "") -> str: return responses.pop(0) if responses else "exit" monkeypatch.setattr("builtins.input", mock_input) - await sample + await sample() diff --git a/python/tests/samples/getting_started/test_threads_samples.py b/python/tests/samples/getting_started/test_threads_samples.py index d0630d2181b..f4102f19d2f 100644 --- a/python/tests/samples/getting_started/test_threads_samples.py +++ b/python/tests/samples/getting_started/test_threads_samples.py @@ -7,8 +7,12 @@ import pytest from pytest import MonkeyPatch, mark, param -from samples.getting_started.threads.custom_chat_message_store_thread import main as threads_custom_store -from samples.getting_started.threads.suspend_resume_thread import main as threads_suspend_resume +from samples.getting_started.threads.custom_chat_message_store_thread import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] + main as threads_custom_store, +) +from samples.getting_started.threads.suspend_resume_thread import ( # pyrefly: ignore[missing-import] # ty: ignore[unresolved-import] + main as threads_suspend_resume, +) # Environment variable for controlling sample tests RUN_SAMPLES_TESTS = "RUN_SAMPLES_TESTS" @@ -49,4 +53,4 @@ def mock_input(prompt: str = "") -> str: return responses.pop(0) if responses else "exit" monkeypatch.setattr("builtins.input", mock_input) - await sample + await sample() diff --git a/python/tests/samples/hosting/test_toolbox_endpoint.py b/python/tests/samples/hosting/test_toolbox_endpoint.py index b08c5e58a9c..e4bfde8b6a5 100644 --- a/python/tests/samples/hosting/test_toolbox_endpoint.py +++ b/python/tests/samples/hosting/test_toolbox_endpoint.py @@ -6,7 +6,6 @@ implementation of _resolve_toolbox_endpoint(). """ -import importlib import importlib.util import sys from pathlib import Path @@ -36,8 +35,10 @@ def _load_sample(subdir: str, module_alias: str): spec = importlib.util.spec_from_file_location(module_alias, _RESPONSES_DIR / subdir / "main.py") - mod = importlib.util.module_from_spec(spec) # type: ignore[arg-type] - spec.loader.exec_module(mod) # type: ignore[union-attr] + assert spec is not None + mod = importlib.util.module_from_spec(spec) + assert spec.loader is not None + spec.loader.exec_module(mod) return mod diff --git a/python/ty.samples.toml b/python/ty.samples.toml new file mode 100644 index 00000000000..bf297701d7e --- /dev/null +++ b/python/ty.samples.toml @@ -0,0 +1,17 @@ +# Basic-mode ty profile for samples. Mirrors pyrefly.samples.toml: catch real mistakes +# in teaching code without forcing casts/overload workarounds for third-party SDK stubs. +[rules] +# Signature/cast/overload noise -- off for samples. +invalid-argument-type = "ignore" +invalid-return-type = "ignore" +no-matching-overload = "ignore" +invalid-assignment = "ignore" +invalid-typed-dict-field = "ignore" +not-subscriptable = "ignore" +missing-argument = "ignore" +unknown-argument = "ignore" +invalid-key = "ignore" +redundant-cast = "ignore" +unused-type-ignore-comment = "ignore" +# Kept on (real bugs a reader would hit): unresolved-import, unresolved-attribute, +# not-iterable, invalid-await. diff --git a/python/uv.lock b/python/uv.lock index 72aa2d01ad6..0ec542068f3 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -130,6 +130,7 @@ dev = [ { name = "opentelemetry-sdk", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "poethepoet", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "prek", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "pyrefly", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "pyright", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "pytest", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "pytest-asyncio", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -140,7 +141,9 @@ dev = [ { name = "rich", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "ruff", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "tomli", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "ty", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "uv", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "zuban", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, ] [package.metadata] @@ -155,6 +158,7 @@ dev = [ { name = "opentelemetry-sdk", specifier = "==1.40.0" }, { name = "poethepoet", specifier = "==0.46.0" }, { name = "prek", specifier = "==0.4.3" }, + { name = "pyrefly", specifier = "==1.0.0" }, { name = "pyright", specifier = "==1.1.408" }, { name = "pytest", specifier = "==9.0.3" }, { name = "pytest-asyncio", specifier = "==1.4.0" }, @@ -165,7 +169,9 @@ dev = [ { name = "rich", specifier = ">=13.7.1,<16.0.0" }, { name = "ruff", specifier = "==0.15.15" }, { name = "tomli", specifier = "==2.4.1" }, + { name = "ty", specifier = "==0.0.46" }, { name = "uv", specifier = "==0.11.17" }, + { name = "zuban", specifier = "==0.8.2" }, ] [[package]] @@ -5948,6 +5954,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/10/bd/c038d7cc38edc1aa5bf91ab8068b63d4308c66c4c8bb3cbba7dfbc049f9c/pyparsing-3.3.2-py3-none-any.whl", hash = "sha256:850ba148bd908d7e2411587e247a1e4f0327839c40e2e5e6d05a007ecc69911d", size = 122781, upload-time = "2026-01-21T03:57:55.912Z" }, ] +[[package]] +name = "pyrefly" +version = "1.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/9f/3a/9045b0097ac58979c7c30a4fa0e673db942d4adbc7b6d439bd54ae58c441/pyrefly-1.0.0.tar.gz", hash = "sha256:5c2b810ffcebd84be71de5df1223651edee951653a66935c6f091e957c452455", size = 5677995, upload-time = "2026-05-12T20:12:46.812Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f4/c6/90788819bac9c61dd7bacba53b79f3c12d47ccbe5e51b3d6d89f2387e1d2/pyrefly-1.0.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:e355a0908555348ed4b9585ef25c76ff566673e345c866c325f1633f44d890b6", size = 13122950, upload-time = "2026-05-12T20:12:20.711Z" }, + { url = "https://files.pythonhosted.org/packages/82/91/a3cf2a1e87d336eaa804a1e6fc93266faf6dc2a97eecdbc7eae289628022/pyrefly-1.0.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:a7038efc3a40f8294edee339895633cf22db268c0d434cdbcbefc34f78a9ecc3", size = 12599494, upload-time = "2026-05-12T20:12:23.495Z" }, + { url = "https://files.pythonhosted.org/packages/cd/ab/74d1e11e737e99b1c003ecc5d7d2e846c4ea1f328966bfdbbd0ac63fad0a/pyrefly-1.0.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da331ca515ed1c08791da2b5f664cf9c1294c48fd802133262e7d5d51e0f4416", size = 12995507, upload-time = "2026-05-12T20:12:25.951Z" }, + { url = "https://files.pythonhosted.org/packages/7c/ac/2df0899f8464c97e5d995f994c97c5cb5b0f58610432aa90d26d924e1db5/pyrefly-1.0.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c74219d8f3e63cdaa5501a0b21d1c9d37011820f9606728d0ed06f09ae86a878", size = 13947693, upload-time = "2026-05-12T20:12:29.188Z" }, + { url = "https://files.pythonhosted.org/packages/6b/3e/b247c24321e36f04b7d51f9ccf3df93e5009e4b29939524b36ec2e17dc2a/pyrefly-1.0.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c0d05543b1bb6ee6d64149eb5d6b2fb15aa72d3962d6a97abca0afaca8b0c131", size = 13925803, upload-time = "2026-05-12T20:12:31.904Z" }, + { url = "https://files.pythonhosted.org/packages/61/16/cfa2d61a4aa1e1f7bca48bb37acd01c6a09db4864b16a54f9587092765ff/pyrefly-1.0.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1382d5b1fcdb49a4de9f34d112d2bddf290a78ff93ee8149492ad5f1077ddffc", size = 13470398, upload-time = "2026-05-12T20:12:35.302Z" }, + { url = "https://files.pythonhosted.org/packages/cb/2b/6372c7dddb326223e24a46b17efd0d4bd7b4fe22c821e523157577eed2d2/pyrefly-1.0.0-py3-none-win32.whl", hash = "sha256:aa8b5d0e47080e3202a2547b39f7a5a61d2c781c712b3b67884f745ca2c759d2", size = 12222643, upload-time = "2026-05-12T20:12:38.618Z" }, + { url = "https://files.pythonhosted.org/packages/be/ad/1d23be700b6b2ddaeb362360c7145917a8edbbf7240ae428d40541772fce/pyrefly-1.0.0-py3-none-win_amd64.whl", hash = "sha256:c8abcb0f2082e83c890375128f9cff4aa4d3f210b85eea7b3046c1ae764e77f5", size = 13146369, upload-time = "2026-05-12T20:12:41.423Z" }, + { url = "https://files.pythonhosted.org/packages/8c/38/16589134f3012fd097a10dcc85771555f1a5fb76e04b682597180743af30/pyrefly-1.0.0-py3-none-win_arm64.whl", hash = "sha256:d150fa9e40e8392832be81c3bcfc0497c146674ce4d0f8e04e1ec29e775ffb8c", size = 12538326, upload-time = "2026-05-12T20:12:43.996Z" }, +] + [[package]] name = "pyright" version = "1.1.408" @@ -7394,6 +7417,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/16/e1/3079a9ff9b8e11b846c6ac5c8b5bfb7ff225eee721825310c91b3b50304f/tqdm-4.67.3-py3-none-any.whl", hash = "sha256:ee1e4c0e59148062281c49d80b25b67771a127c85fc9676d3be5f243206826bf", size = 78374, upload-time = "2026-02-03T17:35:50.982Z" }, ] +[[package]] +name = "ty" +version = "0.0.46" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/5a/7d/d95b5a9dea83472006be3ce5e480028c44b34138d84d0172e910f287fb69/ty-0.0.46.tar.gz", hash = "sha256:c6c2d7105b5633b49950b4c3a90d1ed2613eb9d794ad582bbbf6c4ffcb93accf", size = 5832380, upload-time = "2026-06-09T03:28:05.056Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0e/24/f9f7533c391610521f4164e6b8e37ef72d0c1ee8651bc0d9ce9e658b953b/ty-0.0.46-py3-none-linux_armv6l.whl", hash = "sha256:5e716337994699cbc1a1a7b7a3e6622306f2574c710330f9d9691c2c3d8391b0", size = 11756264, upload-time = "2026-06-09T03:28:20.112Z" }, + { url = "https://files.pythonhosted.org/packages/66/49/ff3d13655b9b5cc8176f4c3446bf7ec2df43c8ad9e5272d4adc5d952fa45/ty-0.0.46-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:51d618dec5403635690d0e3e298cd0ad3d84ebc6a576652939ef30ce96fce4b2", size = 11492723, upload-time = "2026-06-09T03:28:13.23Z" }, + { url = "https://files.pythonhosted.org/packages/82/4a/e7e3209e353c5835c7756339bbcdfda10852407b80fbb9ed46c17241873a/ty-0.0.46-py3-none-macosx_11_0_arm64.whl", hash = "sha256:acbafd6a2351b07a6cf4c945b0b1d47f6d2826faac2526a351dfa74d3a3cc664", size = 10892822, upload-time = "2026-06-09T03:27:51.179Z" }, + { url = "https://files.pythonhosted.org/packages/6c/20/4390c90434a9ddefcecb65e8df00e4c2700e9739dc0baf58bed36d25f713/ty-0.0.46-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:de5df602ffd760612ae36602bbad69b0123ff6cffd92e62aa92b7709317d69e3", size = 11408745, upload-time = "2026-06-09T03:27:58.049Z" }, + { url = "https://files.pythonhosted.org/packages/75/0c/f13a1bf9c6798530c773667095a6cf8f73ec9721db359423e7249bff7fbc/ty-0.0.46-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7abf5a10b30d8641faad90f6a19989daec941bb90261159e05cfeb04d2012046", size = 11544432, upload-time = "2026-06-09T03:27:53.519Z" }, + { url = "https://files.pythonhosted.org/packages/56/69/eb3710c13dff846a0362df04fadd8a39b64ccc244c0d02ce5285ede8eae5/ty-0.0.46-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8770404139c6ccee2ce2fc226478cfa4100915133c876c257e52197b8b92051d", size = 12031228, upload-time = "2026-06-09T03:28:29.816Z" }, + { url = "https://files.pythonhosted.org/packages/e9/68/5f5db9c84c1d44acdc67281089b372d9d818ee68123a60c59c66187095e2/ty-0.0.46-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f960d5a6e4860076924d2b86891d9872c4a3daa4663fb416e640b22cf3dbf68e", size = 12596073, upload-time = "2026-06-09T03:28:25.204Z" }, + { url = "https://files.pythonhosted.org/packages/14/be/cfd0bb272e6a1491f6de30c60da1f39c2b3c3524ec64a5c92b71365c9185/ty-0.0.46-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1d9000a4a3ed08fc37e8a2ff0b801cde06e1c2af3bc053677744bb5a1b751030", size = 12284885, upload-time = "2026-06-09T03:28:10.58Z" }, + { url = "https://files.pythonhosted.org/packages/a8/3a/2cd541f6320f5d6f70a45725c4e1016efedd5545348bb23b47ffb3e4c724/ty-0.0.46-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d1160e6dc86536109ab755f7142f36f4dda5333c8330cf230d61819494d27125", size = 12079480, upload-time = "2026-06-09T03:27:55.847Z" }, + { url = "https://files.pythonhosted.org/packages/de/91/8e0075bc6568fb477e7ef4d805c67fa6902b692cb4419e0bf5ce3c04c5bc/ty-0.0.46-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:b619c0efe007731f8221fa787701bfa4402da7a83eb26c61ae25e77b6ace6384", size = 12316547, upload-time = "2026-06-09T03:28:08.28Z" }, + { url = "https://files.pythonhosted.org/packages/00/28/b96cbfeda019a4044c6a8cd06ff84d08b631d4ba7d9a1e6dc0311df3563a/ty-0.0.46-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:ad98fccb6a8a94c4121b993761a0deee602f5826c4162e0a91f4f8118ddadd42", size = 11392846, upload-time = "2026-06-09T03:28:00.418Z" }, + { url = "https://files.pythonhosted.org/packages/3b/d0/4d77f699a95ac7a13b94ca1a58682667cfe974f91557d9e2a9fc0b808a7f/ty-0.0.46-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:74536b13c3cc3f5944408669c202d4c57c3d19ff154732df8e6145718aef9191", size = 11559017, upload-time = "2026-06-09T03:28:17.619Z" }, + { url = "https://files.pythonhosted.org/packages/88/62/1d6f6b51c2b132da8011c6a41ead0c1fd2a0b17ea72304bcf6ce084d581a/ty-0.0.46-py3-none-musllinux_1_2_i686.whl", hash = "sha256:5e50b1e96ced41b609e24ed27d9e4f508584ed7f4d0bb717ca8c8d75d2fd1b7c", size = 11666509, upload-time = "2026-06-09T03:28:22.454Z" }, + { url = "https://files.pythonhosted.org/packages/fe/9a/6643894bc12cb30c281f4c8bf37f6d30c1fbd9484ef39a12b0ea6dae3c1c/ty-0.0.46-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:0a7d9f58d26d938e5d2f607481b7a412d8c00d675a1ec72004fa9d6b3b9def99", size = 12180448, upload-time = "2026-06-09T03:28:32.329Z" }, + { url = "https://files.pythonhosted.org/packages/86/68/0f3b7bb03a7da676ef51b1c0af0bde1e500d69d5f0c807ed63b6f30b66dd/ty-0.0.46-py3-none-win32.whl", hash = "sha256:26db0ce89c573e60132d14e9688c9329a1633b1a8c26fe457025c7c406f7d5e6", size = 10960002, upload-time = "2026-06-09T03:28:02.832Z" }, + { url = "https://files.pythonhosted.org/packages/b1/f4/91ff618b2dee39d0633d23e1adac0174aa1de80df17e270acac534034dbc/ty-0.0.46-py3-none-win_amd64.whl", hash = "sha256:90e8e6d446b9cb7cb4bede9fca7b3c99fd1e2355605ecf431c131a51db2a5e93", size = 12097413, upload-time = "2026-06-09T03:28:27.495Z" }, + { url = "https://files.pythonhosted.org/packages/e5/2e/300174fca375a27a7c28dd80e990d857d7b3e3b25980c65063f980aa2f17/ty-0.0.46-py3-none-win_arm64.whl", hash = "sha256:ebd320d82605079b901a095dc4711037a0c488b4ace79a602fef4df0d3f4cf74", size = 11439595, upload-time = "2026-06-09T03:28:15.355Z" }, +] + [[package]] name = "typer" version = "0.23.1" @@ -8017,3 +8065,22 @@ sdist = { url = "https://files.pythonhosted.org/packages/e3/02/0f2892c661036d50e wheels = [ { url = "https://files.pythonhosted.org/packages/2e/54/647ade08bf0db230bfea292f893923872fd20be6ac6f53b2b936ba839d75/zipp-3.23.0-py3-none-any.whl", hash = "sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e", size = 10276, upload-time = "2025-06-08T17:06:38.034Z" }, ] + +[[package]] +name = "zuban" +version = "0.8.2" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/95/d8/9a24dc2c22250fc416bff06c4ac4664c73ef7ec558b8d23049f493af73b5/zuban-0.8.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:07d3a498c514cb51ec881987b0e841043ea466674a9b3bed8411890bacb4d0e3", size = 11486917, upload-time = "2026-06-08T23:29:02.829Z" }, + { url = "https://files.pythonhosted.org/packages/23/62/3636a5474a9f9361ca83030cb82fefb4b1ff7771c75ab957c9025a3a7a97/zuban-0.8.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:64ed57aea375c74671d81e4d4655328a47b2157c7827e2c50cc90f80b29e8417", size = 11205360, upload-time = "2026-06-08T23:29:05.612Z" }, + { url = "https://files.pythonhosted.org/packages/49/41/0650f89e1f2ea0e865dd5e23614e6f3a205153032f0a9e970c5f4da885d1/zuban-0.8.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8bc47f343a939407ff881b0e9f0da36c7a3a509b3e6ca237512aadf30d02053b", size = 28385720, upload-time = "2026-06-08T23:29:08.177Z" }, + { url = "https://files.pythonhosted.org/packages/d7/52/58ecbdbf9668fd81c85f69b94609a5a813a94a18d9ce73ee1548cf85d98d/zuban-0.8.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:26ede2418f48affde0e9092108ec0be344b1323d7e87d11ecac8e5e89d383d7e", size = 28603725, upload-time = "2026-06-08T23:29:11.778Z" }, + { url = "https://files.pythonhosted.org/packages/96/ad/efc2e98a4492d010459161e03f98952227744c8f5ae8ebcc13ca88d50ad6/zuban-0.8.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b83e0cc35cccb2f60b50d8c6658ca3ac332914ad58d8e3d3d5886a13f761771b", size = 29769216, upload-time = "2026-06-08T23:29:14.865Z" }, + { url = "https://files.pythonhosted.org/packages/f7/60/a40add4cf31694f64727a34fc4a2093c599f1c911fdb6fbf61bf40d437f5/zuban-0.8.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8c4b26c1a383234c340d4956c04d8528e5e1654a08571d9d3eaced4973d21f87", size = 30851896, upload-time = "2026-06-08T23:29:18.04Z" }, + { url = "https://files.pythonhosted.org/packages/17/6f/7e32ec7c1677dfd7ac5821ac1d6cc6d4552f65fe06b098bd3515f5c5699d/zuban-0.8.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:32699e64acdb7d8984eddc3095b8f6fcf37fa00c52d54e74fed920440fe36ebe", size = 28551589, upload-time = "2026-06-08T23:29:20.922Z" }, + { url = "https://files.pythonhosted.org/packages/68/26/f5e8434aa8ea478a949b4dab8bd94abb47ff46f4777943a3309c33bfda94/zuban-0.8.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:f2ada854ceddeae6f7a8ecf709577235f5845770ebd6d51b26dde4358e78b645", size = 29007499, upload-time = "2026-06-08T23:29:23.954Z" }, + { url = "https://files.pythonhosted.org/packages/5c/37/70e6696e0e9d202fa84be8437b9dbf32e6d5854d9a201b9ca24965496391/zuban-0.8.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:dd6c2be104e6e9e1693e15d0a7775c5841a9793534351738c24b7377c58b487c", size = 29670474, upload-time = "2026-06-08T23:29:26.769Z" }, + { url = "https://files.pythonhosted.org/packages/74/0c/1c6f0a239fc2c93c46e733fa5d6f9ac376f9e0aec909467d50f6071234bf/zuban-0.8.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:299debc250b729abb0ebecd0bd53cac3c6b26cdb32dc4e79e601f9a1f584d3cd", size = 29029010, upload-time = "2026-06-08T23:29:29.959Z" }, + { url = "https://files.pythonhosted.org/packages/a6/7b/d5cdb140e9979d3f7784d9565f21692cafb851cceb21cba13348bcfcfa30/zuban-0.8.2-py3-none-win32.whl", hash = "sha256:3521196b132c2650e01a085d211dfcd849de9e9409e193e8f8ad85629c60fa42", size = 10026570, upload-time = "2026-06-08T23:29:32.596Z" }, + { url = "https://files.pythonhosted.org/packages/ee/cb/25a952a3594aa7ee32dee648513a1ea758c2def98742966d24484bb3d8ab/zuban-0.8.2-py3-none-win_amd64.whl", hash = "sha256:07463d337d293efb918990900c52b976646810b8f741999e8f288fb3da6e6594", size = 10798493, upload-time = "2026-06-08T23:29:34.805Z" }, +]