Skip to content

Add BaseAIHook and Update usages#67438

Open
gopidesupavan wants to merge 2 commits into
apache:mainfrom
gopidesupavan:add-baseaihook
Open

Add BaseAIHook and Update usages#67438
gopidesupavan wants to merge 2 commits into
apache:mainfrom
gopidesupavan:add-baseaihook

Conversation

@gopidesupavan
Copy link
Copy Markdown
Member

@gopidesupavan gopidesupavan commented May 24, 2026

Add BaseAIHook and Update usage

Summary

Introduce BaseAIHook, a backend-neutral contract for multi-turn LLM agents in the common-ai provider. AgentOperator and @task.agent now resolve the agent runtime from the connection conn_type (for example pydanticai, pydanticai-bedrock, pydanticai-azure) and delegate all framework-specific work to the hook.

PydanticAIHook is the first implementation. All LLM operators and LLMRetryPolicy are migrated to a shared AgentRunRequest / run_agent API. SQLToolset is migrated to the new framework-agnostic BaseToolset interface.

This lays the foundation for additional agent backends without adding parallel operator classes per framework. A follow-up PR will add AWS Strands as the next hook implementation; this contract also opens the door for Google ADK and other agent runtimes behind the same AgentOperator / @task.agent surface.


Motivation

Before this change:

  • AgentOperator contained pydantic-ai-specific logic (tool wrapping, durable caching, agent construction).
  • PydanticAIHook.create_agent() / run_agent() used ad-hoc keyword arguments.
  • Tool logging and durable execution were handled in the operator layer via LoggingToolset / CachingToolset wrappers.

The operator should stay framework-agnostic. Hooks should own agent lifecycle, tool resolution, durable execution, and normalized results.


Design

BaseAIHook contract

New abstract hook with:

Method / property Purpose
get_model() Return backend model/client
get_conn() Compatibility shim → get_model()
create_agent(request) Build (but do not run) the agent
run_agent(agent, request) Execute and return AgentRunResult
_tool_spec_to_native(spec) Convert ToolSpec → native tool representation
get_agent_hook(conn_id) Resolve hook from connection conn_type

Capability flags: supports_toolsets, supports_durable, supports_usage_limits.

Parameter objects

  • AgentRunRequest — prompt, output type, instructions, toolsets, usage limits, message history, durable context, agent params
  • AgentRunResult — output, message history, model name, usage, tool names, durable stats
  • ToolSpec — framework-neutral tool descriptor (name, description, JSON schema, callable)
  • BaseToolset — abstract as_tools() → list[ToolSpec]
  • DurableContext / DurableStats — durable execution identity and cache statistics

Shared hook helpers

Moved into BaseAIHook:

  • _resolve_tools() — converts BaseToolset, plain callables, and native tool objects
  • _logged_callable() — per-tool real-time logging
  • _cached_callable() — per-tool durable step caching
  • _init_durable()DurableStorage / DurableStepCounter setup

PydanticAIHook implementation

  • Implements full BaseAIHook contract
  • Splits toolsets into two paths:
    • AbstractToolset (HookToolset, MCPToolset, DataFusionToolset, third-party) → Agent(toolsets=[...]) with LoggingToolset / CachingToolset wrapping when enabled
    • BaseToolset / callables / native Tool → resolved via _resolve_toolsAgent(tools=[...])
  • Durable model caching via CachingModel in run_agent
  • get_model() replaces direct get_conn() usage; get_conn() delegates for backward compatibility

AgentOperator thinning

Operator execution is now:

request = self._build_request(prompt=self.prompt)
agent = self.llm_hook.create_agent(request)
run_result = self.llm_hook.run_agent(agent, request)

No pydantic-ai imports at runtime (except UsageLimits under TYPE_CHECKING).

Early validation via _validate_hook_capabilities() checks hook support for toolsets, durable, and usage limits.

SQLToolsetBaseToolset

SQLToolset no longer implements pydantic-ai's AbstractToolset. It implements BaseToolset.as_tools() returning four ToolSpec objects with JSON schemas (list_tables, get_schema, query, check_query).

HookToolset, MCPToolset, and DataFusionToolset remain AbstractToolset and continue to work unchanged through the pydantic-ai routing path.


Other changes

All LLM operators migrated

These now use BaseAIHook.get_agent_hook() and AgentRunRequest:

  • LLMOperator
  • LLMBranchOperator
  • LLMSQLOperator
  • LLMSchemaCompareOperator
  • LLMFileAnalysisOperator
  • LLMRetryPolicy

Logging utilities

  • log_run_summary() now accepts AgentRunResult directly
  • Removed wrap_toolsets_for_logging() from the operator path; logging is handled in the hook layer

Examples and docs

  • Updated example_pydantic_ai_hook.py to use BaseAIHook.get_agent_hook() + AgentRunRequest
  • Updated docs/operators/agent.rst, docs/toolsets.rst, AGENTS.md
  • Changelog entry for the new contract

Tests

  • New test_base_ai.py — dataclasses, _resolve_tools, logging/caching wrappers
  • Expanded test_pydantic_ai.py — contract, durable init, AbstractToolset routing/wrapping
  • Updated operator, decorator, and policy tests to mock BaseAIHook and assert AgentRunRequest forwarding
  • Rewritten test_sql.py for BaseToolset.as_tools() API

Breaking changes

PydanticAIHook API

Before:

agent = hook.create_agent(output_type=str, instructions="...", toolsets=[...])
result = hook.run_agent(agent, prompt="hello", usage_limits=limits)

After:

request = AgentRunRequest(prompt="hello", output_type=str, instructions="...", toolsets=[...], usage_limits=limits)
agent = hook.create_agent(request)
result = hook.run_agent(agent, request)

get_conn() still works (delegates to get_model()).

SQLToolset direct pydantic-ai usage

Before: pass SQLToolset(...) directly to pydantic-ai Agent(toolsets=[...]).

After: use via AgentOperator / @task.agent, or build through the hook:

request = AgentRunRequest(prompt="...", toolsets=[SQLToolset(db_conn_id="my_db")])
agent = hook.create_agent(request)
result = hook.run_agent(agent, request)

SQLToolset is now a BaseToolset, not an AbstractToolset.


Migration guide

Custom code calling PydanticAIHook directly

Replace kwargs-style create_agent / run_agent with AgentRunRequest:

from airflow.providers.common.ai.hooks.base_ai import AgentRunRequest, BaseAIHook

hook = BaseAIHook.get_agent_hook("pydanticai_default", hook_params={"model_id": "openai:gpt-5"})
request = AgentRunRequest(
    prompt="Analyze this dataset",
    output_type=str,
    instructions="You are a data analyst.",
    toolsets=[SQLToolset(db_conn_id="postgres_default")],
)
agent = hook.create_agent(request)
result = hook.run_agent(agent, request)
print(result.output)

DAG authors using operators / decorators

No DAG changes required for:

  • AgentOperator / @task.agent
  • LLMOperator / @task.llm
  • Other LLM decorators

Connection conn_type continues to select the backend.

Adding a new agent backend

Subclass BaseAIHook and implement:

  1. get_model()
  2. create_agent(request)
  3. run_agent(agent, request)
  4. _tool_spec_to_native(spec)

Register the hook in provider.yaml. Reuse shared helpers (_resolve_tools, _logged_callable, _cached_callable, _init_durable) where applicable.


Known limitations / follow-ups

  • HookToolset / DataFusionToolset could be migrated to BaseToolset in a follow-up; they work today via the AbstractToolset pass-through path.

Test plan

  • tests/unit/common/ai/hooks/test_base_ai.py

  • tests/unit/common/ai/hooks/test_pydantic_ai.py

  • tests/unit/common/ai/operators/test_agent.py

  • tests/unit/common/ai/operators/test_llm.py

  • tests/unit/common/ai/operators/test_llm_branch.py

  • tests/unit/common/ai/operators/test_llm_sql.py

  • tests/unit/common/ai/operators/test_llm_schema_compare.py

  • tests/unit/common/ai/operators/test_llm_file_analysis.py

  • tests/unit/common/ai/decorators/test_agent.py

  • tests/unit/common/ai/decorators/test_llm*.py

  • tests/unit/common/ai/policies/test_retry.py

  • tests/unit/common/ai/toolsets/test_sql.py

  • tests/unit/common/ai/utils/test_logging.py

  • Follow-up: AWS Strands agent hook (StrandsAIHook implementing BaseAIHook)

  • Follow-up: Google ADK agent hook (same contract, new conn_type registration)

  • Follow-up: migrate HookToolset, MCPToolset, and DataFusionToolset from pydantic-ai AbstractToolset to BaseToolset

Was generative AI tooling used to co-author this PR?
  • Yes
  • No

Was generative AI tooling used to co-author this PR?
  • Yes (please specify the tool below)

  • Read the Pull Request Guidelines for more information. Note: commit author/co-author name and email in commits become permanently public when merged.
  • For fundamental code changes, an Airflow Improvement Proposal (AIP) is needed.
  • When adding dependency, check compliance with the ASF 3rd Party License Policy.
  • For significant user-facing changes create newsfragment: {pr_number}.significant.rst, in airflow-core/newsfragments. You can add this file in a follow-up commit after the PR is created so you know the PR number.

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR introduces a new BaseAIHook contract in the common-ai provider to make multi-turn agent execution backend-neutral. Operators/decorators now construct an AgentRunRequest, resolve the runtime hook from the connection conn_type, and delegate agent lifecycle/tool resolution/durable execution to the hook implementation (starting with PydanticAIHook).

Changes:

  • Add BaseAIHook + shared request/response/tool abstractions (AgentRunRequest, AgentRunResult, ToolSpec, BaseToolset) and shared helper logic (tool resolution, logging, caching).
  • Refactor AgentOperator, LLM operators/decorators, and LLMRetryPolicy to use get_agent_hook() and the shared create_agent(request) / run_agent(agent, request) flow.
  • Migrate SQLToolset to the framework-agnostic BaseToolset interface and update logging utilities, docs, examples, and tests accordingly.

Reviewed changes

Copilot reviewed 34 out of 35 changed files in this pull request and generated 4 comments.

Show a summary per file
File Description
uv.lock Lockfile update (adds additional jpype1 wheel entries).
providers/common/ai/tests/unit/common/ai/utils/test_logging.py Update logging tests to validate AgentRunResult-based summaries.
providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py Rewrite SQLToolset tests for BaseToolset.as_tools() + direct tool callables.
providers/common/ai/tests/unit/common/ai/policies/test_retry.py Update retry policy tests to mock BaseAIHook.get_agent_hook() + request forwarding.
providers/common/ai/tests/unit/common/ai/operators/test_llm.py Update LLMOperator tests to assert AgentRunRequest construction and run_agent usage.
providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py Update LLMSQL operator tests for request-based hook invocation.
providers/common/ai/tests/unit/common/ai/operators/test_llm_schema_compare.py Update schema compare operator tests to validate request contents + run_agent call.
providers/common/ai/tests/unit/common/ai/operators/test_llm_file_analysis.py Update file analysis operator tests for BaseAIHook + request flow.
providers/common/ai/tests/unit/common/ai/operators/test_llm_branch.py Update branch operator tests to use BaseAIHook and request-based execution.
providers/common/ai/tests/unit/common/ai/operators/test_agent.py Update AgentOperator tests for capability validation + request/durable context forwarding.
providers/common/ai/tests/unit/common/ai/hooks/test_pydantic_ai.py Expand PydanticAIHook tests for BaseAIHook contract, tool routing, and durable behavior.
providers/common/ai/tests/unit/common/ai/hooks/test_base_ai.py Add new unit tests covering BaseAIHook dataclasses and tool/log/cache helpers.
providers/common/ai/tests/unit/common/ai/decorators/test_llm.py Update @task.llm tests to mock BaseAIHook.get_agent_hook() and validate request prompt.
providers/common/ai/tests/unit/common/ai/decorators/test_llm_sql.py Update @task.llm_sql tests for request-based hook execution.
providers/common/ai/tests/unit/common/ai/decorators/test_llm_schema_compare.py Update schema compare decorator tests for BaseAIHook.get_agent_hook() flow.
providers/common/ai/tests/unit/common/ai/decorators/test_llm_file_analysis.py Update file analysis decorator tests for request-based execution.
providers/common/ai/tests/unit/common/ai/decorators/test_llm_branch.py Update branch decorator tests to mock BaseAIHook.get_agent_hook() and validate behavior.
providers/common/ai/tests/unit/common/ai/decorators/test_agent.py Update @task.agent tests for request forwarding and toolset passthrough.
providers/common/ai/src/airflow/providers/common/ai/utils/logging.py Make logging backend-neutral by consuming AgentRunResult directly.
providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py Convert SQLToolset from pydantic-ai AbstractToolset to framework-neutral BaseToolset.
providers/common/ai/src/airflow/providers/common/ai/policies/retry.py Migrate LLMRetryPolicy to use BaseAIHook + AgentRunRequest and run_agent.
providers/common/ai/src/airflow/providers/common/ai/operators/llm.py Refactor LLMOperator to build AgentRunRequest and call hook create_agent/run_agent.
providers/common/ai/src/airflow/providers/common/ai/operators/llm_sql.py Refactor LLMSQL operator to use AgentRunRequest and hook execution.
providers/common/ai/src/airflow/providers/common/ai/operators/llm_schema_compare.py Refactor schema compare operator to request-based hook execution.
providers/common/ai/src/airflow/providers/common/ai/operators/llm_file_analysis.py Refactor file analysis operator to request-based hook execution.
providers/common/ai/src/airflow/providers/common/ai/operators/llm_branch.py Refactor branch operator to request-based hook execution.
providers/common/ai/src/airflow/providers/common/ai/operators/agent.py Thin AgentOperator: capability validation + request building + hook-driven execution/durable/tool logging.
providers/common/ai/src/airflow/providers/common/ai/hooks/pydantic_ai.py Implement BaseAIHook contract for pydantic-ai, including tool routing and durable execution support.
providers/common/ai/src/airflow/providers/common/ai/hooks/base_ai.py Add new BaseAIHook contract, request/result dataclasses, tool abstraction, and shared helpers.
providers/common/ai/src/airflow/providers/common/ai/example_dags/example_pydantic_ai_hook.py Update example DAG to use BaseAIHook.get_agent_hook() + AgentRunRequest.
providers/common/ai/docs/toolsets.rst Update toolset docs to describe mixed toolset routing and BaseToolset usage.
providers/common/ai/docs/operators/agent.rst Update AgentOperator docs to explain backend selection via connection conn_type and new toolset shapes.
providers/common/ai/docs/hooks/index.rst Update hook selection docs to reflect conn_type-driven backend selection for agents.
providers/common/ai/docs/changelog.rst Add changelog entry for BaseAIHook contract introduction.
providers/common/ai/AGENTS.md Update contributor guidance to describe BaseAIHook and backend-neutral agent design.

Comment on lines +269 to +273
if self._durable_storage is not None and self._durable_counter is not None:
from airflow.providers.common.ai.durable.caching_model import CachingModel

resolved_model = infer_model(agent.model)
caching_model = CachingModel(
Comment on lines +287 to +288
elif inspect.isfunction(ts):
specs = [ToolSpec(name=ts.__name__, description=ts.__doc__ or "", parameters={}, fn=ts)]
Comment on lines 85 to +88
the durable execution (step-level caching with retry replay), HITL review
integration, and automatic tool call logging that ``AgentOperator`` provides.
integration, and the automatic tool call logging and routing that
``AgentOperator`` provides via
:class:`~airflow.providers.common.ai.toolsets.logging.LoggingToolset`.
Comment on lines +327 to +343
storage = MagicMock()
counter = MagicMock()
counter.next_step.return_value = 1
storage.load_tool_result.return_value = (True, "cached_value")

calls = []

def fn():
calls.append(1)
return "computed"

wrapped = BaseAIHook._cached_callable(fn, storage, counter)
result = wrapped()

assert result == "cached_value"
assert calls == []
counter.replayed_tool += 1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants