diff --git a/agent-schema.json b/agent-schema.json index e70a17de6..98afc085d 100644 --- a/agent-schema.json +++ b/agent-schema.json @@ -1825,7 +1825,7 @@ }, "allow_private_ips": { "type": "boolean", - "description": "Opt in to dialling non-public IP addresses (valid for type 'fetch', 'api', 'openapi', and remote MCP toolsets). By default protected HTTP clients refuse connections \u2014 after DNS resolution, so DNS rebinding is also blocked \u2014 to loopback, RFC1918 private ranges, link-local (including the cloud metadata endpoint at 169.254.169.254), multicast and the unspecified address. Set this to true when an agent legitimately needs to call internal services. For fetch, 'allowed_domains' / 'blocked_domains' are evaluated independently and still apply." + "description": "Opt in to dialling non-public IP addresses (valid for type 'fetch', 'api', 'openapi', 'a2a', and remote MCP toolsets). By default protected HTTP clients refuse connections \u2014 after DNS resolution, so DNS rebinding is also blocked \u2014 to loopback, RFC1918 private ranges, link-local (including the cloud metadata endpoint at 169.254.169.254), multicast and the unspecified address. Set this to true when an agent legitimately needs to call internal services. For fetch, 'allowed_domains' / 'blocked_domains' are evaluated independently and still apply." }, "url": { "type": "string", diff --git a/pkg/config/latest/types.go b/pkg/config/latest/types.go index 104f0caa4..94f00f3e2 100644 --- a/pkg/config/latest/types.go +++ b/pkg/config/latest/types.go @@ -883,7 +883,7 @@ type Toolset struct { // For the `lsp` tool FileTypes []string `json:"file_types,omitempty"` - // HTTP timeout in seconds for `fetch`, `api`, and `openapi` toolsets. + // HTTP timeout in seconds for `fetch`, `api`, `openapi`, and `a2a` toolsets. // Defaults to 30 seconds when omitted. Timeout int `json:"timeout,omitempty"` @@ -899,8 +899,8 @@ type Toolset struct { // `allowed_domains`. BlockedDomains []string `json:"blocked_domains,omitempty" yaml:"blocked_domains,omitempty"` - // For the `fetch`, `api`, `openapi` and remote `mcp` toolsets — opt in to - // dialling non-public IP addresses. + // For the `fetch`, `api`, `openapi`, `a2a` and remote `mcp` toolsets — opt in + // to dialling non-public IP addresses. // // By default, protected HTTP clients refuse connections (after DNS // resolution, so DNS rebinding is also blocked) to loopback (127/8, diff --git a/pkg/config/latest/validate.go b/pkg/config/latest/validate.go index d6ac582c1..18c16e683 100644 --- a/pkg/config/latest/validate.go +++ b/pkg/config/latest/validate.go @@ -145,8 +145,8 @@ func (t *Toolset) validate() error { if len(t.BlockedDomains) > 0 && t.Type != "fetch" { return errors.New("blocked_domains can only be used with type 'fetch'") } - if t.AllowPrivateIPsEnabled() && t.Type != "fetch" && t.Type != "mcp" && t.Type != "api" && t.Type != "openapi" { - return errors.New("allow_private_ips can only be used with type 'fetch', 'api', 'openapi' or remote MCP toolsets") + if t.AllowPrivateIPsEnabled() && t.Type != "fetch" && t.Type != "mcp" && t.Type != "api" && t.Type != "openapi" && t.Type != "a2a" { + return errors.New("allow_private_ips can only be used with type 'fetch', 'api', 'openapi', 'a2a' or remote MCP toolsets") } if len(t.AllowedDomains) > 0 && len(t.BlockedDomains) > 0 { return errors.New("allowed_domains and blocked_domains are mutually exclusive") @@ -235,7 +235,7 @@ func (t *Toolset) validate() error { return errors.New("either command, remote or ref must be set, but only one of those") } if t.AllowPrivateIPsEnabled() && t.Remote.URL == "" && t.Ref == "" { - return errors.New("allow_private_ips can only be used with type 'fetch', 'api', 'openapi' or remote MCP toolsets") + return errors.New("allow_private_ips can only be used with type 'fetch', 'api', 'openapi', 'a2a' or remote MCP toolsets") } if t.Remote.OAuth != nil { if t.Remote.URL == "" { diff --git a/pkg/config/toolset_validate_test.go b/pkg/config/toolset_validate_test.go index 4dfd2c1e0..ace283df6 100644 --- a/pkg/config/toolset_validate_test.go +++ b/pkg/config/toolset_validate_test.go @@ -294,7 +294,7 @@ agents: - type: shell allow_private_ips: true `, - wantErr: "allow_private_ips can only be used with type 'fetch', 'api', 'openapi' or remote MCP toolsets", + wantErr: "allow_private_ips can only be used with type 'fetch', 'api', 'openapi', 'a2a' or remote MCP toolsets", }, { name: "allow_private_ips on fetch toolset is accepted", @@ -333,6 +333,18 @@ agents: - type: openapi url: http://10.0.0.1/openapi.json allow_private_ips: true +`, + }, + { + name: "allow_private_ips on a2a toolset is accepted", + config: ` +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: a2a + url: http://10.0.0.1/.well-known/agent-card.json + allow_private_ips: true `, }, { @@ -360,7 +372,7 @@ agents: allow_private_ips: true command: docker `, - wantErr: "allow_private_ips can only be used with type 'fetch', 'api', 'openapi' or remote MCP toolsets", + wantErr: "allow_private_ips can only be used with type 'fetch', 'api', 'openapi', 'a2a' or remote MCP toolsets", }, { name: "empty allowed_domains entry is rejected", diff --git a/pkg/httpclient/safeclient.go b/pkg/httpclient/safeclient.go index 763bfa4fb..fc1e60f4b 100644 --- a/pkg/httpclient/safeclient.go +++ b/pkg/httpclient/safeclient.go @@ -5,6 +5,14 @@ import ( "time" ) +// DefaultToolHTTPTimeout is the HTTP client timeout used by the built-in +// HTTP-based toolsets (`fetch`, `api`, `openapi`, `a2a`) when the operator +// does not override it via `timeout:` in the agent config. +// +// Centralised so the four toolsets agree on a single default — changing +// this value uniformly affects every HTTP-based built-in tool. +const DefaultToolHTTPTimeout = 30 * time.Second + // NewSafeClient returns the HTTP client used by built-in tools that issue // outbound calls to URLs the operator (or a fetched OpenAPI spec) supplies. // diff --git a/pkg/tools/a2a/a2a.go b/pkg/tools/a2a/a2a.go index 8f5f671c4..161a3e1c8 100644 --- a/pkg/tools/a2a/a2a.go +++ b/pkg/tools/a2a/a2a.go @@ -8,8 +8,10 @@ import ( "errors" "fmt" "log/slog" + "net/http" "strings" "sync" + "time" "github.com/a2aproject/a2a-go/a2a" "github.com/a2aproject/a2a-go/a2aclient" @@ -25,12 +27,31 @@ import ( // Toolset implements tools.ToolSet for A2A remote agents. type Toolset struct { - name string - url string - headers map[string]string - client *a2aclient.Client - card *a2a.AgentCard - mu sync.RWMutex + name string + url string + headers map[string]string + timeout time.Duration + allowPrivateIPs bool + client *a2aclient.Client + card *a2a.AgentCard + mu sync.RWMutex +} + +// Option configures a Toolset. +type Option func(*Toolset) + +// WithTimeout overrides the default HTTP client timeout (see +// [httpclient.DefaultToolHTTPTimeout]) used both for fetching the agent +// card and for streaming messages. +func WithTimeout(d time.Duration) Option { + return func(t *Toolset) { t.timeout = d } +} + +// WithAllowPrivateIPs disables SSRF dial-time protection so the a2a tool +// can reach internal services. Off by default; matches the behaviour of +// the same flag on `fetch`, `api`, `openapi` and remote `mcp`. +func WithAllowPrivateIPs(allow bool) Option { + return func(t *Toolset) { t.allowPrivateIPs = allow } } // Verify interface compliance @@ -44,16 +65,29 @@ var ( func CreateToolSet(ctx context.Context, toolset latest.Toolset, runConfig *config.RuntimeConfig) (tools.ToolSet, error) { expander := js.NewJsExpander(runConfig.EnvProvider()) headers := expander.ExpandMap(ctx, toolset.Headers) - return NewToolset(toolset.Name, toolset.URL, headers), nil + + var opts []Option + if toolset.Timeout > 0 { + opts = append(opts, WithTimeout(time.Duration(toolset.Timeout)*time.Second)) + } + if toolset.AllowPrivateIPsEnabled() { + opts = append(opts, WithAllowPrivateIPs(true)) + } + return NewToolset(toolset.Name, toolset.URL, headers, opts...), nil } // NewToolset creates a new A2A toolset for the given URL. -func NewToolset(name, url string, headers map[string]string) *Toolset { - return &Toolset{ +func NewToolset(name, url string, headers map[string]string, opts ...Option) *Toolset { + t := &Toolset{ name: name, url: url, headers: headers, + timeout: httpclient.DefaultToolHTTPTimeout, + } + for _, opt := range opts { + opt(t) } + return t } // Instructions returns instructions for using the A2A toolset. @@ -124,19 +158,30 @@ func (t *Toolset) Tools(_ context.Context) ([]tools.Tool, error) { // Start connects to the A2A agent and fetches the agent card. func (t *Toolset) Start(ctx context.Context) error { - slog.DebugContext(ctx, "Starting A2A toolset", "url", t.url) - - card, err := agentcard.DefaultResolver.Resolve(ctx, t.url) + slog.DebugContext(ctx, "Starting A2A toolset", "url", t.url, "timeout", t.timeout, "allow_private_ips", t.allowPrivateIPs) + + // Use the SSRF-safe client to fetch the agent card so a malicious or + // misconfigured `url:` cannot reach loopback / RFC1918 / link-local + // addresses (cloud metadata at 169.254.169.254 in particular). The + // `allow_private_ips: true` opt-in disables this for legitimate + // internal-service use. + resolver := agentcard.NewResolver(httpclient.NewSafeClient(t.timeout, t.allowPrivateIPs)) + card, err := resolver.Resolve(ctx, t.url) if err != nil { return fmt.Errorf("failed to fetch A2A agent card: %w", err) } - // Use a longer timeout for the HTTP client since LLM responses can take a while. - // The default a2a-go HTTP client has only a 5-second timeout which is too short. - httpClient := httpclient.NewHTTPClient(ctx) - httpClient.Transport = upstream.NewHeaderTransport(httpClient.Transport, t.headers) + httpClient := httpclient.NewSafeClient(t.timeout, t.allowPrivateIPs) + base := httpClient.Transport + if base == nil { + base = http.DefaultTransport + } + httpClient.Transport = upstream.NewHeaderTransport(base, t.headers) - client, err := a2aclient.NewFromCard(ctx, card, a2aclient.WithJSONRPCTransport(httpClient)) + client, err := a2aclient.NewFromCard(ctx, card, + a2aclient.WithDefaultsDisabled(), + a2aclient.WithJSONRPCTransport(httpClient), + ) if err != nil { return fmt.Errorf("failed to create A2A client: %w", err) } diff --git a/pkg/tools/a2a/a2a_test.go b/pkg/tools/a2a/a2a_test.go new file mode 100644 index 000000000..64357960e --- /dev/null +++ b/pkg/tools/a2a/a2a_test.go @@ -0,0 +1,86 @@ +package a2a + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + goa2a "github.com/a2aproject/a2a-go/a2a" + "github.com/a2aproject/a2a-go/a2asrv" + "github.com/a2aproject/a2a-go/a2asrv/eventqueue" + + "github.com/docker/docker-agent/pkg/tools" +) + +func TestToolSetRejectsPrivateIPForAgentCard(t *testing.T) { + t.Parallel() + + toolSet := NewToolset("test", "http://127.0.0.1/.well-known/agent-card.json", nil) + + if err := toolSet.Start(t.Context()); err == nil { + t.Fatal("Start() expected error") + } +} + +func TestToolSetStreamingWithAllowPrivateIPs(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(a2asrv.NewJSONRPCHandler(a2asrv.NewHandler(testA2AHandler{}))) + t.Cleanup(server.Close) + + cardServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(goa2a.AgentCard{ + Name: "test", + Description: "test", + URL: server.URL, + Version: "1.0.0", + ProtocolVersion: string(goa2a.Version), + PreferredTransport: goa2a.TransportProtocolJSONRPC, + Capabilities: goa2a.AgentCapabilities{Streaming: true}, + DefaultInputModes: []string{"text/plain"}, + DefaultOutputModes: []string{"text/plain"}, + Skills: []goa2a.AgentSkill{{ + ID: "test", + Name: "test", + Description: "test", + Tags: []string{"test"}, + }}, + }) + })) + t.Cleanup(cardServer.Close) + + toolSet := NewToolset("test", cardServer.URL, nil, WithAllowPrivateIPs(true)) + + if err := toolSet.Start(t.Context()); err != nil { + t.Fatalf("Start() error = %v", err) + } + + toolList, err := toolSet.Tools(t.Context()) + if err != nil { + t.Fatalf("Tools() error = %v", err) + } + if len(toolList) != 1 { + t.Fatalf("Tools() returned %d tools, want 1", len(toolList)) + } + + result, err := toolList[0].Handler(t.Context(), tools.ToolCall{Function: tools.FunctionCall{Arguments: `{"message":"hello"}`}}) + if err != nil { + t.Fatalf("Handler() error = %v", err) + } + if result == nil || result.Output != "ok" { + t.Fatalf("Handler() result = %+v, want output %q", result, "ok") + } +} + +type testA2AHandler struct{} + +func (testA2AHandler) Execute(ctx context.Context, reqCtx *a2asrv.RequestContext, queue eventqueue.Queue) error { + return queue.Write(ctx, goa2a.NewMessageForTask(goa2a.MessageRoleAgent, reqCtx, goa2a.TextPart{Text: "ok"})) +} + +func (testA2AHandler) Cancel(context.Context, *a2asrv.RequestContext, eventqueue.Queue) error { + return nil +} diff --git a/pkg/tools/builtin/api/api.go b/pkg/tools/builtin/api/api.go index 9792e6d5b..d9683cdd6 100644 --- a/pkg/tools/builtin/api/api.go +++ b/pkg/tools/builtin/api/api.go @@ -29,8 +29,6 @@ type ToolSet struct { allowPrivateIPs bool } -const defaultHTTPTimeout = 30 * time.Second - // Verify interface compliance var ( _ tools.ToolSet = (*ToolSet)(nil) @@ -114,7 +112,8 @@ func CreateToolSet(ctx context.Context, toolset latest.Toolset, runConfig *confi // Option configures an api ToolSet. type Option func(*ToolSet) -// WithTimeout overrides the default 30s HTTP client timeout. +// WithTimeout overrides the default HTTP client timeout (see +// [httpclient.DefaultToolHTTPTimeout]). func WithTimeout(d time.Duration) Option { return func(t *ToolSet) { t.timeout = d } } @@ -131,7 +130,7 @@ func New(apiConfig latest.APIToolConfig, expander *js.Expander, opts ...Option) t := &ToolSet{ config: apiConfig, expander: expander, - timeout: defaultHTTPTimeout, + timeout: httpclient.DefaultToolHTTPTimeout, } for _, opt := range opts { opt(t) diff --git a/pkg/tools/builtin/fetch/fetch.go b/pkg/tools/builtin/fetch/fetch.go index 3b11ab10d..87772629e 100644 --- a/pkg/tools/builtin/fetch/fetch.go +++ b/pkg/tools/builtin/fetch/fetch.go @@ -480,7 +480,7 @@ func CreateToolSet(ctx context.Context, toolset latest.Toolset, runConfig *confi func New(options ...ToolOption) *ToolSet { tool := &ToolSet{ handler: &fetchHandler{ - timeout: 30 * time.Second, + timeout: httpclient.DefaultToolHTTPTimeout, }, } diff --git a/pkg/tools/builtin/openapi/openapi.go b/pkg/tools/builtin/openapi/openapi.go index 10d9edb5f..d9343be19 100644 --- a/pkg/tools/builtin/openapi/openapi.go +++ b/pkg/tools/builtin/openapi/openapi.go @@ -27,8 +27,6 @@ import ( "github.com/docker/docker-agent/pkg/useragent" ) -const defaultHTTPTimeout = 30 * time.Second - // CreateToolSet is used by the tools registry. func CreateToolSet(ctx context.Context, toolset latest.Toolset, runConfig *config.RuntimeConfig) (tools.ToolSet, error) { expander := js.NewJsExpander(runConfig.EnvProvider()) @@ -64,8 +62,9 @@ var ( // Option configures an openapi ToolSet. type Option func(*ToolSet) -// WithTimeout overrides the default 30s HTTP client timeout used both for -// fetching the spec and for the generated tools' HTTP calls. +// WithTimeout overrides the default HTTP client timeout (see +// [httpclient.DefaultToolHTTPTimeout]) used both for fetching the spec +// and for the generated tools' HTTP calls. func WithTimeout(d time.Duration) Option { return func(t *ToolSet) { t.timeout = d } } @@ -83,7 +82,7 @@ func New(specURL string, headers map[string]string, opts ...Option) *ToolSet { t := &ToolSet{ specURL: specURL, headers: headers, - timeout: defaultHTTPTimeout, + timeout: httpclient.DefaultToolHTTPTimeout, } for _, opt := range opts { opt(t)