diff --git a/apps/vscode-e2e/AGENTS.md b/apps/vscode-e2e/AGENTS.md index 5e318436c2..8d9a4defac 100644 --- a/apps/vscode-e2e/AGENTS.md +++ b/apps/vscode-e2e/AGENTS.md @@ -160,6 +160,26 @@ ZAI_API_KEY= TEST_FILE=zai.test pnpm --filter @roo-code/vscode-e2e test:ci When adding a new test to this suite, add a matching fixture to the `installZAiFetchInterceptor` call in `suiteSetup`. Use a short unique prefix (e.g. `"zai-glm-e2e-mytest:"`) that won't appear in ``. +### Gemini (`suite/providers/gemini.test.ts`) + +Gemini routes through aimock via `googleGeminiBaseUrl: aimockUrl`. aimock has native Gemini SSE support and can proxy to `https://generativelanguage.googleapis.com` in record mode. The model ID defaults to `gemini-3-flash-preview` but can be overridden via `GEMINI_MODEL_ID`. + +The test only runs when aimock is active (replay or record). Live runs without aimock are not supported because `GEMINI_MODEL_ID` must match the fixture. + +**Record** (refresh fixtures from the real Gemini API): + +```sh +GEMINI_API_KEY= TEST_FILE=providers/gemini.test pnpm --filter @roo-code/vscode-e2e test:record +``` + +After recording, inspect the generated `fixtures/gemini-*.json`, extract the response blocks into `fixtures/gemini.json`, then delete the raw files. + +**Verify in mock mode** (no key needed): + +```sh +TEST_FILE=providers/gemini.test pnpm --filter @roo-code/vscode-e2e test:ci:mock +``` + ### xAI Grok (`suite/providers/xai.test.ts`) xAI uses the **Responses API** (`POST https://api.x.ai/v1/responses`), which is not OpenAI-compatible. aimock can't intercept it. The suite instead patches `globalThis.fetch` to intercept requests to that endpoint. By default it replays hand-crafted SSE events; when a local `fixtures/xai.json` recording exists, it can replay recorded real-API SSE events for reference. diff --git a/apps/vscode-e2e/fixtures/gemini.json b/apps/vscode-e2e/fixtures/gemini.json new file mode 100644 index 0000000000..c5cafde45f --- /dev/null +++ b/apps/vscode-e2e/fixtures/gemini.json @@ -0,0 +1,49 @@ +{ + "fixtures": [ + { + "match": { + "model": "gemini-3-flash-preview", + "userMessage": "gemini-e2e:reasoning-high: what is 2+2? Reply with only the number." + }, + "response": { + "toolCalls": [ + { + "name": "attempt_completion", + "arguments": "{\"result\":\"4\"}", + "id": "call_gemini_reasoning_high_done" + } + ] + } + }, + { + "match": { + "model": "gemini-3-flash-preview", + "userMessage": "gemini-e2e:reasoning-low: what is 2+2? Reply with only the number." + }, + "response": { + "toolCalls": [ + { + "name": "attempt_completion", + "arguments": "{\"result\":\"4\"}", + "id": "call_gemini_reasoning_low_done" + } + ] + } + }, + { + "match": { + "model": "gemini-3-flash-preview", + "userMessage": "gemini-e2e:reasoning-disable: what is 2+2? Reply with only the number." + }, + "response": { + "toolCalls": [ + { + "name": "attempt_completion", + "arguments": "{\"result\":\"4\"}", + "id": "call_gemini_reasoning_disable_done" + } + ] + } + } + ] +} diff --git a/apps/vscode-e2e/src/runTest.ts b/apps/vscode-e2e/src/runTest.ts index e1ba4ada77..ca968af06e 100644 --- a/apps/vscode-e2e/src/runTest.ts +++ b/apps/vscode-e2e/src/runTest.ts @@ -31,12 +31,17 @@ async function main() { const testGrep = getCliFlagValue("--grep") || process.env.TEST_GREP const testFile = getCliFlagValue("--file") || process.env.TEST_FILE const isDeepSeekTest = isDeepSeekTargetedRun(testFile, testGrep) + const isGeminiTest = testFile?.toLowerCase().includes("gemini.test") ?? false if (isRecord && isDeepSeekTest && !process.env.DEEPSEEK_API_KEY) { throw new Error("AIMOCK_RECORD=true requires DEEPSEEK_API_KEY to record DeepSeek fixtures") } - if (isRecord && !isDeepSeekTest && !process.env.OPENROUTER_API_KEY) { + if (isRecord && isGeminiTest && !process.env.GEMINI_API_KEY && !process.env.GOOGLE_API_KEY) { + throw new Error("AIMOCK_RECORD=true requires GEMINI_API_KEY to record Gemini fixtures") + } + + if (isRecord && !isDeepSeekTest && !isGeminiTest && !process.env.OPENROUTER_API_KEY) { throw new Error("AIMOCK_RECORD=true requires OPENROUTER_API_KEY to record fixtures") } @@ -78,6 +83,8 @@ async function main() { openai: isDeepSeekTest ? "https://api.deepseek.com" : "https://openrouter.ai/api", // aimock forwards the x-api-key header from the Anthropic SDK to the real API. anthropic: "https://api.anthropic.com", + // aimock forwards the x-goog-api-key header from the Google AI SDK. + ...(isGeminiTest && { gemini: "https://generativelanguage.googleapis.com" }), }, fixturePath: fixturesDir, }, diff --git a/apps/vscode-e2e/src/suite/providers/gemini.test.ts b/apps/vscode-e2e/src/suite/providers/gemini.test.ts new file mode 100644 index 0000000000..7355c1cefd --- /dev/null +++ b/apps/vscode-e2e/src/suite/providers/gemini.test.ts @@ -0,0 +1,288 @@ +import * as assert from "assert" + +import { RooCodeEventName, type ClineMessage } from "@roo-code/types" + +import { setDefaultSuiteTimeout } from "../test-utils" +import { waitUntilCompleted } from "../utils" + +const GEMINI_API_KEY = process.env.GEMINI_API_KEY ?? process.env.GOOGLE_API_KEY +const GEMINI_MODEL_ID = process.env.GEMINI_MODEL_ID ?? "gemini-3-flash-preview" + +type FunctionDeclaration = { + name: string + parametersJsonSchema?: Record +} + +type GeminiToolConfig = { + functionCallingConfig?: { + mode?: string + allowedFunctionNames?: string[] + } +} + +type CapturedGeminiRequest = { + model?: string + lastUserMessage: string + thinkingConfig?: Record + toolConfig?: GeminiToolConfig + hasTools: boolean + toolDeclarationCount: number + functionDeclarations: FunctionDeclaration[] +} + +function findInvalidSchemaPatterns(schema: unknown, path = ""): string[] { + if (!schema || typeof schema !== "object" || Array.isArray(schema)) { + return [] + } + + const obj = schema as Record + const violations: string[] = [] + + if ("additionalProperties" in obj) { + violations.push(`${path}.additionalProperties (stripped for Gemini compatibility)`) + } + + if ("default" in obj) { + violations.push(`${path}.default (stripped for Gemini compatibility)`) + } + + if ("$schema" in obj) { + violations.push(`${path}.$schema (JSON Schema metadata stripped for Gemini compatibility)`) + } + + if ("type" in obj && Array.isArray(obj.type)) { + violations.push(`${path}.type is an array ${JSON.stringify(obj.type)} (Gemini requires a single string type)`) + } + + for (const [key, value] of Object.entries(obj)) { + if (key === "properties" && value && typeof value === "object") { + for (const [propName, propSchema] of Object.entries(value as Record)) { + violations.push(...findInvalidSchemaPatterns(propSchema, `${path}.properties.${propName}`)) + } + } else if (key === "items") { + violations.push(...findInvalidSchemaPatterns(value, `${path}.items`)) + } else if (key === "anyOf" || key === "oneOf" || key === "allOf") { + violations.push(`${path}.${key} (collapsed for Gemini compatibility)`) + if (Array.isArray(value)) { + value.forEach((item, i) => violations.push(...findInvalidSchemaPatterns(item, `${path}.${key}[${i}]`))) + } + } + } + + return violations +} + +function getRequestUrl(input: RequestInfo | URL): string { + return typeof input === "string" ? input : input instanceof URL ? input.href : (input as Request).url +} + +function isUrlWithOrigin(rawUrl: string, expectedOrigin: string): boolean { + try { + return new URL(rawUrl).origin === expectedOrigin + } catch { + return false + } +} + +function isGeminiGenerateContentUrl(rawUrl: string): boolean { + try { + const pathname = new URL(rawUrl).pathname + return pathname.includes(":streamGenerateContent") || pathname.includes(":generateContent") + } catch { + return false + } +} + +function extractGeminiModel(rawUrl: string): string | undefined { + try { + const pathname = new URL(rawUrl).pathname + const match = pathname.match(/\/models\/([^:]+):(streamGenerateContent|generateContent)$/) + return match?.[1] + } catch { + return undefined + } +} + +function extractLastUserMessage( + contents?: Array<{ + role?: string + parts?: Array<{ text?: string }> + }>, +): string { + const lastUser = [...(contents ?? [])].reverse().find((content) => content.role === "user") + + if (!lastUser?.parts) { + return "" + } + + return lastUser.parts + .map((part) => (typeof part?.text === "string" ? part.text : JSON.stringify(part ?? ""))) + .join("") +} + +function installGeminiRequestCapture(capture: CapturedGeminiRequest[], baseUrl: string): () => void { + const originalFetch = globalThis.fetch + const targetOrigin = new URL(baseUrl).origin + + globalThis.fetch = async function (input: RequestInfo | URL, init?: RequestInit): Promise { + const url = getRequestUrl(input) + + if (isUrlWithOrigin(url, targetOrigin) && isGeminiGenerateContentUrl(url)) { + const body = init?.body && typeof init.body === "string" ? JSON.parse(init.body) : {} + const tools = Array.isArray(body.tools) ? body.tools : [] + const functionDeclarations: FunctionDeclaration[] = tools.flatMap( + (tool: { functionDeclarations?: FunctionDeclaration[] }) => + Array.isArray(tool.functionDeclarations) ? tool.functionDeclarations : [], + ) + + capture.push({ + model: extractGeminiModel(url), + lastUserMessage: extractLastUserMessage(body.contents), + thinkingConfig: + body.generationConfig && typeof body.generationConfig === "object" + ? (body.generationConfig.thinkingConfig as Record | undefined) + : undefined, + toolConfig: + body.toolConfig && typeof body.toolConfig === "object" + ? (body.toolConfig as GeminiToolConfig) + : undefined, + hasTools: tools.length > 0, + toolDeclarationCount: functionDeclarations.length, + functionDeclarations, + }) + } + + return originalFetch.call(globalThis, input, init as RequestInit) + } as typeof globalThis.fetch + + return () => { + globalThis.fetch = originalFetch + } +} + +suite("Gemini provider", function () { + setDefaultSuiteTimeout(this) + + let restoreFetch: (() => void) | undefined + const requests: CapturedGeminiRequest[] = [] + + setup(function () { + const aimockUrl = process.env.AIMOCK_URL + const isReplay = aimockUrl && process.env.AIMOCK_RECORD !== "true" + const isRecordRun = aimockUrl && process.env.AIMOCK_RECORD === "true" && !!GEMINI_API_KEY + // Live runs without aimock are not supported — GEMINI_MODEL_ID must match the fixture. + if (!isReplay && !isRecordRun) { + this.skip() + } + }) + + suiteSetup(() => { + restoreFetch = installGeminiRequestCapture( + requests, + process.env.AIMOCK_URL || "https://generativelanguage.googleapis.com", + ) + }) + + suiteTeardown(async () => { + restoreFetch?.() + restoreFetch = undefined + + const aimockUrl = process.env.AIMOCK_URL + const isRecord = process.env.AIMOCK_RECORD === "true" + await globalThis.api.setConfiguration({ + apiProvider: "openrouter" as const, + openRouterApiKey: aimockUrl && !isRecord ? "mock-key" : process.env.OPENROUTER_API_KEY!, + openRouterModelId: "openai/gpt-4.1", + ...(aimockUrl && { openRouterBaseUrl: `${aimockUrl}/v1` }), + }) + }) + + for (const reasoningEffort of ["high", "low", "disable"] as const) { + test(`Should complete a task end-to-end using ${GEMINI_MODEL_ID} via Gemini provider with reasoning effort "${reasoningEffort}"`, async () => { + requests.length = 0 + + const api = globalThis.api + const aimockUrl = process.env.AIMOCK_URL + const isRecord = process.env.AIMOCK_RECORD === "true" + const promptTag = `gemini-e2e:reasoning-${reasoningEffort}` + + await api.setConfiguration({ + apiProvider: "gemini" as const, + geminiApiKey: aimockUrl && !isRecord ? "mock-key" : GEMINI_API_KEY!, + apiModelId: GEMINI_MODEL_ID, + enableReasoningEffort: reasoningEffort !== "disable", + reasoningEffort: reasoningEffort, + ...(aimockUrl && { googleGeminiBaseUrl: aimockUrl }), + }) + + const messages: ClineMessage[] = [] + const messageHandler = ({ message }: { message: ClineMessage }) => { + if (message.type === "say" && message.partial === false) { + messages.push(message) + } + } + + api.on(RooCodeEventName.Message, messageHandler) + + try { + const taskId = await api.startNewTask({ + configuration: { mode: "ask", alwaysAllowModeSwitch: true, autoApprovalEnabled: true }, + text: `${promptTag}: what is 2+2? Reply with only the number.`, + }) + + await waitUntilCompleted({ api, taskId }) + } finally { + api.off(RooCodeEventName.Message, messageHandler) + } + + const firstRequest = requests.find((request) => request.lastUserMessage.includes(promptTag)) + assert.ok(firstRequest, "Gemini provider should issue a generate content request for the task prompt") + assert.strictEqual(firstRequest.model, GEMINI_MODEL_ID) + assert.ok(firstRequest.hasTools, "Gemini provider should include tool declarations in the request") + assert.ok( + firstRequest.toolDeclarationCount > 0, + "Gemini provider should declare at least one callable tool", + ) + assert.strictEqual( + firstRequest.toolConfig?.functionCallingConfig?.allowedFunctionNames, + undefined, + "Gemini requests should not send allowedFunctionNames; the Gemini backend returns generic INVALID_ARGUMENT for larger or history-incompatible restriction lists", + ) + + // Verify tool schemas are sanitized for Gemini compatibility. Gemini documents + // function declaration schemas as a selected OpenAPI-style subset with + // single-value `type` plus `nullable`; live testing also showed opaque + // INVALID_ARGUMENT failures from broader third-party MCP schema metadata. + for (const decl of firstRequest.functionDeclarations) { + const violations = findInvalidSchemaPatterns( + decl.parametersJsonSchema, + `${decl.name}.parametersJsonSchema`, + ) + assert.strictEqual( + violations.length, + 0, + `Tool "${decl.name}" has Gemini-incompatible schema: ${violations.join("; ")}`, + ) + } + + if (reasoningEffort === "disable") { + assert.strictEqual( + firstRequest.thinkingConfig, + undefined, + "Reasoning-disabled Gemini requests should omit thinkingConfig", + ) + } else { + assert.ok( + firstRequest.thinkingConfig, + `Gemini requests with reasoningEffort="${reasoningEffort}" should include thinkingConfig`, + ) + } + + const completionMessage = messages.find( + ({ say, text }) => (say === "completion_result" || say === "text") && text?.trim() === "4", + ) + + assert.ok(completionMessage, "Task should complete with the expected Gemini provider response") + }) + } +}) diff --git a/src/api/providers/__tests__/gemini-handler.spec.ts b/src/api/providers/__tests__/gemini-handler.spec.ts index 7f157570e8..110f60289c 100644 --- a/src/api/providers/__tests__/gemini-handler.spec.ts +++ b/src/api/providers/__tests__/gemini-handler.spec.ts @@ -190,7 +190,7 @@ describe("GeminiHandler backend support", () => { }, ] - it("should pass allowedFunctionNames to toolConfig when provided", async () => { + it("should ignore allowedFunctionNames because Gemini rejects larger restriction lists", async () => { const options = { apiProvider: "gemini", } as ApiHandlerOptions @@ -208,15 +208,10 @@ describe("GeminiHandler backend support", () => { .next() const config = stub.mock.calls[0][0].config - expect(config.toolConfig).toEqual({ - functionCallingConfig: { - mode: FunctionCallingConfigMode.ANY, - allowedFunctionNames: ["read_file", "write_to_file"], - }, - }) + expect(config.toolConfig).toBeUndefined() }) - it("should include all tools but restrict callable functions via allowedFunctionNames", async () => { + it("should include all tools when allowedFunctionNames is provided", async () => { const options = { apiProvider: "gemini", } as ApiHandlerOptions @@ -236,11 +231,78 @@ describe("GeminiHandler backend support", () => { const config = stub.mock.calls[0][0].config // All tools should be passed to the model expect(config.tools[0].functionDeclarations).toHaveLength(3) - // But only read_file should be allowed to be called - expect(config.toolConfig.functionCallingConfig.allowedFunctionNames).toEqual(["read_file"]) + expect(config.toolConfig).toBeUndefined() }) - it("should take precedence over tool_choice when allowedFunctionNames is provided", async () => { + it("should not pass large allowedFunctionNames lists to Gemini", async () => { + const options = { + apiProvider: "gemini", + } as ApiHandlerOptions + const handler = new GeminiHandler(options) + const stub = vi.fn().mockReturnValue((async function* () {})()) + // @ts-ignore access private client + handler["client"].models.generateContentStream = stub + + const manyTools = Array.from({ length: 30 }, (_, index) => ({ + type: "function" as const, + function: { + name: `tool_${index}`, + description: `Tool ${index}`, + parameters: { type: "object", properties: {} }, + }, + })) + + await handler + .createMessage("test", [] as any, { + taskId: "test-task", + tools: manyTools, + allowedFunctionNames: manyTools.map((tool) => tool.function.name), + }) + .next() + + const config = stub.mock.calls[0][0].config + expect(config.tools[0].functionDeclarations).toHaveLength(30) + expect(config.toolConfig).toBeUndefined() + }) + + it("should not pass allowedFunctionNames even when history includes tool calls", async () => { + const options = { + apiProvider: "gemini", + } as ApiHandlerOptions + const handler = new GeminiHandler(options) + const stub = vi.fn().mockReturnValue((async function* () {})()) + // @ts-ignore access private client + handler["client"].models.generateContentStream = stub + + const manyTools = Array.from({ length: 30 }, (_, index) => ({ + type: "function" as const, + function: { + name: `tool_${index}`, + description: `Tool ${index}`, + parameters: { type: "object", properties: {} }, + }, + })) + const messages = [ + { + role: "assistant", + content: [{ type: "tool_use", id: "tool-call-29", name: "tool_29", input: {} }], + }, + ] + + await handler + .createMessage("test", messages as any, { + taskId: "test-task", + tools: manyTools, + allowedFunctionNames: manyTools.slice(0, 29).map((tool) => tool.function.name), + }) + .next() + + const config = stub.mock.calls[0][0].config + expect(config.tools[0].functionDeclarations).toHaveLength(30) + expect(config.toolConfig).toBeUndefined() + }) + + it("should fall back to tool_choice when allowedFunctionNames is provided", async () => { const options = { apiProvider: "gemini", } as ApiHandlerOptions @@ -259,9 +321,8 @@ describe("GeminiHandler backend support", () => { .next() const config = stub.mock.calls[0][0].config - // allowedFunctionNames should take precedence - mode should be ANY, not AUTO - expect(config.toolConfig.functionCallingConfig.mode).toBe(FunctionCallingConfigMode.ANY) - expect(config.toolConfig.functionCallingConfig.allowedFunctionNames).toEqual(["read_file"]) + expect(config.toolConfig.functionCallingConfig.mode).toBe(FunctionCallingConfigMode.AUTO) + expect(config.toolConfig.functionCallingConfig.allowedFunctionNames).toBeUndefined() }) it("should fall back to tool_choice when allowedFunctionNames is empty", async () => { @@ -309,4 +370,358 @@ describe("GeminiHandler backend support", () => { expect(config.toolConfig).toBeUndefined() }) }) + + describe("Gemini schema compatibility", () => { + it("should strip broad JSON Schema metadata from function declarations", async () => { + const options = { + apiProvider: "gemini", + } as ApiHandlerOptions + const handler = new GeminiHandler(options) + const stub = vi.fn().mockReturnValue((async function* () {})()) + // @ts-ignore access private client + handler["client"].models.generateContentStream = stub + + await handler + .createMessage("test", [] as any, { + taskId: "test-task", + tools: [ + { + type: "function", + function: { + name: "mcp_tool", + description: "MCP tool", + parameters: { + $schema: "https://json-schema.org/draft/2020-12/schema", + type: "object", + additionalProperties: false, + default: {}, + properties: { + query: { + type: "string", + default: "", + }, + options: { + type: "object", + additionalProperties: true, + properties: { + limit: { type: "integer", default: 10 }, + }, + }, + }, + }, + }, + }, + ], + }) + .next() + + const schema = stub.mock.calls[0][0].config.tools[0].functionDeclarations[0].parametersJsonSchema + expect(JSON.stringify(schema)).not.toContain("additionalProperties") + expect(JSON.stringify(schema)).not.toContain('"default"') + expect(JSON.stringify(schema)).not.toContain("$schema") + expect(schema).toEqual({ + type: "object", + properties: { + query: { type: "string" }, + options: { + type: "object", + properties: { + limit: { type: "integer" }, + }, + }, + }, + }) + }) + + it("should collapse composition and type arrays in function declaration schemas", async () => { + const options = { + apiProvider: "gemini", + } as ApiHandlerOptions + const handler = new GeminiHandler(options) + const stub = vi.fn().mockReturnValue((async function* () {})()) + // @ts-ignore access private client + handler["client"].models.generateContentStream = stub + + await handler + .createMessage("test", [] as any, { + taskId: "test-task", + tools: [ + { + type: "function", + function: { + name: "union_tool", + description: "Union tool", + parameters: { + type: "object", + properties: { + value: { + anyOf: [{ type: "string", description: "A value" }, { type: "null" }], + }, + mode: { + type: ["string", "null"], + enum: ["fast", "safe", null], + }, + config: { + allOf: [ + { type: "object", properties: { enabled: { type: "boolean" } } }, + { description: "Config object" }, + ], + }, + }, + }, + }, + }, + ], + }) + .next() + + const schema = stub.mock.calls[0][0].config.tools[0].functionDeclarations[0].parametersJsonSchema + expect(JSON.stringify(schema)).not.toContain("anyOf") + expect(JSON.stringify(schema)).not.toContain("oneOf") + expect(JSON.stringify(schema)).not.toContain("allOf") + expect(Array.isArray(schema.properties.mode.type)).toBe(false) + expect(schema).toEqual({ + type: "object", + properties: { + value: { type: "string", description: "A value", nullable: true }, + mode: { type: "string", enum: ["fast", "safe", null], nullable: true }, + config: { + type: "object", + properties: { enabled: { type: "boolean" } }, + description: "Config object", + }, + }, + }) + }) + + it("should deep-merge allOf fragments instead of overwriting earlier properties", async () => { + const options = { apiProvider: "gemini" } as ApiHandlerOptions + const handler = new GeminiHandler(options) + const stub = vi.fn().mockReturnValue((async function* () {})()) + // @ts-ignore access private client + handler["client"].models.generateContentStream = stub + + await handler + .createMessage("test", [] as any, { + taskId: "test-task", + tools: [ + { + type: "function", + function: { + name: "multi_allof_tool", + description: "Tool with multi-fragment allOf", + parameters: { + allOf: [ + { + type: "object", + properties: { a: { type: "string" } }, + required: ["a"], + }, + { + type: "object", + properties: { b: { type: "integer" } }, + required: ["b"], + }, + ], + }, + }, + }, + ], + }) + .next() + + const schema = stub.mock.calls[0][0].config.tools[0].functionDeclarations[0].parametersJsonSchema + // Both property blocks must survive the merge — previously `b` overwrote `a` + expect(schema.properties).toEqual({ + a: { type: "string" }, + b: { type: "integer" }, + }) + expect(schema.required).toEqual(expect.arrayContaining(["a", "b"])) + }) + + it("should resolve $ref entries before dropping $defs", async () => { + const options = { apiProvider: "gemini" } as ApiHandlerOptions + const handler = new GeminiHandler(options) + const stub = vi.fn().mockReturnValue((async function* () {})()) + // @ts-ignore access private client + handler["client"].models.generateContentStream = stub + + await handler + .createMessage("test", [] as any, { + taskId: "test-task", + tools: [ + { + type: "function", + function: { + name: "ref_tool", + description: "Tool with $ref", + parameters: { + type: "object", + $defs: { + Config: { + type: "object", + properties: { timeout: { type: "integer" } }, + required: ["timeout"], + }, + }, + properties: { + cfg: { $ref: "#/$defs/Config" }, + name: { type: "string" }, + }, + required: ["cfg", "name"], + }, + }, + }, + ], + }) + .next() + + const schema = stub.mock.calls[0][0].config.tools[0].functionDeclarations[0].parametersJsonSchema + // $defs must be gone, $ref must be inlined + expect(JSON.stringify(schema)).not.toContain("$defs") + expect(JSON.stringify(schema)).not.toContain("$ref") + expect(schema.properties.cfg).toEqual({ + type: "object", + properties: { timeout: { type: "integer" } }, + required: ["timeout"], + }) + expect(schema.properties.name).toEqual({ type: "string" }) + }) + + it("should preserve top-level properties and required entries when allOf is also present", async () => { + const options = { apiProvider: "gemini" } as ApiHandlerOptions + const handler = new GeminiHandler(options) + const stub = vi.fn().mockReturnValue((async function* () {})()) + // @ts-ignore access private client + handler["client"].models.generateContentStream = stub + + await handler + .createMessage("test", [] as any, { + taskId: "test-task", + tools: [ + { + type: "function", + function: { + name: "mixed_allof_tool", + description: "Tool with top-level and allOf schema fragments", + parameters: { + type: "object", + properties: { a: { type: "string" } }, + required: ["a"], + allOf: [ + { + type: "object", + properties: { b: { type: "integer" } }, + required: ["b"], + }, + ], + }, + }, + }, + ], + }) + .next() + + const schema = stub.mock.calls[0][0].config.tools[0].functionDeclarations[0].parametersJsonSchema + expect(schema.properties).toEqual({ + a: { type: "string" }, + b: { type: "integer" }, + }) + expect(schema.required).toEqual(expect.arrayContaining(["a", "b"])) + }) + + it("should stop recursive $ref expansion before the sanitized schema becomes cyclic", async () => { + const options = { apiProvider: "gemini" } as ApiHandlerOptions + const handler = new GeminiHandler(options) + const stub = vi.fn().mockReturnValue((async function* () {})()) + // @ts-ignore access private client + handler["client"].models.generateContentStream = stub + + await handler + .createMessage("test", [] as any, { + taskId: "test-task", + tools: [ + { + type: "function", + function: { + name: "recursive_ref_tool", + description: "Tool with recursive $ref", + parameters: { + type: "object", + $defs: { + Node: { + type: "object", + properties: { + value: { type: "string" }, + next: { $ref: "#/$defs/Node" }, + }, + required: ["value"], + }, + }, + properties: { + root: { $ref: "#/$defs/Node" }, + }, + required: ["root"], + }, + }, + }, + ], + }) + .next() + + const schema = stub.mock.calls[0][0].config.tools[0].functionDeclarations[0].parametersJsonSchema + expect(() => JSON.stringify(schema)).not.toThrow() + expect(JSON.stringify(schema)).not.toContain("$ref") + expect(schema.properties.root).toEqual({ + type: "object", + properties: { + value: { type: "string" }, + next: {}, + }, + required: ["value"], + }) + }) + + it("should preserve parameter names that collide with stripped schema keywords", async () => { + const options = { apiProvider: "gemini" } as ApiHandlerOptions + const handler = new GeminiHandler(options) + const stub = vi.fn().mockReturnValue((async function* () {})()) + // @ts-ignore access private client + handler["client"].models.generateContentStream = stub + + await handler + .createMessage("test", [] as any, { + taskId: "test-task", + tools: [ + { + type: "function", + function: { + name: "keyword_param_tool", + description: "Tool whose parameter names match JSON Schema keywords", + parameters: { + type: "object", + properties: { + default: { type: "string" }, + additionalProperties: { type: "boolean" }, + $schema: { type: "string" }, + normal: { type: "integer" }, + }, + required: ["default", "additionalProperties"], + }, + }, + }, + ], + }) + .next() + + const schema = stub.mock.calls[0][0].config.tools[0].functionDeclarations[0].parametersJsonSchema + expect(schema.properties).toEqual({ + default: { type: "string" }, + additionalProperties: { type: "boolean" }, + $schema: { type: "string" }, + normal: { type: "integer" }, + }) + expect(schema.required).toEqual(expect.arrayContaining(["default", "additionalProperties"])) + }) + }) }) diff --git a/src/api/providers/gemini.ts b/src/api/providers/gemini.ts index a49073ea33..3ae3f821b1 100644 --- a/src/api/providers/gemini.ts +++ b/src/api/providers/gemini.ts @@ -33,6 +33,146 @@ type GeminiHandlerOptions = ApiHandlerOptions & { isVertex?: boolean } +// Gemini documents function declaration schemas as a selected OpenAPI-style +// subset with single-value `type` plus `nullable`. In practice, third-party +// MCP schemas often include broader JSON Schema metadata/composition that has +// produced opaque INVALID_ARGUMENT responses. Keep the outbound schema narrow. +const GEMINI_SCHEMA_COMPATIBILITY_DROP_KEYS = new Set([ + "$schema", + "$id", + "$defs", + "additionalProperties", + "default", + "definitions", +]) + +function sanitizeSchemaForGemini( + schema: unknown, + defs?: Record, + activeRefs: Set = new Set(), +): unknown { + if (!schema || typeof schema !== "object") { + return schema + } + + if (Array.isArray(schema)) { + return schema.map((item) => sanitizeSchemaForGemini(item, defs, activeRefs)) + } + + const source = schema as Record + + // Extract $defs / definitions from the root schema on the first call so + // they can be used to resolve $ref entries encountered deeper in the tree. + const resolvedDefs = defs ?? ((source.$defs ?? source.definitions) as Record | undefined) + + // Resolve local JSON Pointer $ref before any other processing. + // Without this, dropping $defs leaves dangling references that Gemini rejects. + if (typeof source.$ref === "string" && resolvedDefs) { + const match = source.$ref.match(/^#\/(?:\$defs|definitions)\/(.+)$/) + if (match) { + const resolved = resolvedDefs[match[1]] + if (resolved !== undefined) { + // Recursive MCP schemas are valid JSON Schema but not something Gemini + // can consume directly. Stop at the recursive edge so we still send a + // finite, serializable schema instead of overflowing the stack. + if (activeRefs.has(match[1])) { + return {} + } + + activeRefs.add(match[1]) + try { + return sanitizeSchemaForGemini(resolved, resolvedDefs, activeRefs) + } finally { + activeRefs.delete(match[1]) + } + } + } + } + + const result: Record = {} + let nullable = source.nullable === true + + const composition = source.anyOf ?? source.oneOf + if (Array.isArray(composition)) { + const variants = composition.filter((variant) => { + return variant && typeof variant === "object" && !Array.isArray(variant) + ? (variant as Record).type !== "null" + : true + }) + nullable = nullable || variants.length < composition.length + Object.assign(result, sanitizeSchemaForGemini(variants[0] ?? {}, resolvedDefs, activeRefs)) + } + + if (Array.isArray(source.allOf)) { + for (const variant of source.allOf) { + const sanitized = sanitizeSchemaForGemini(variant, resolvedDefs, activeRefs) + if (sanitized && typeof sanitized === "object" && !Array.isArray(sanitized)) { + const s = sanitized as Record + // Deep-merge properties so later allOf fragments don't overwrite + // earlier ones (last-write-wins Object.assign drops prior keys). + if (s.properties && typeof s.properties === "object") { + result.properties = { + ...(result.properties as Record | undefined), + ...(s.properties as Record), + } + } + if (Array.isArray(s.required)) { + const existing = Array.isArray(result.required) ? (result.required as string[]) : [] + result.required = [...new Set([...existing, ...(s.required as string[])])] + } + const { properties: _p, required: _r, ...rest } = s + Object.assign(result, rest) + } + } + } + + for (const [key, value] of Object.entries(source)) { + if (GEMINI_SCHEMA_COMPATIBILITY_DROP_KEYS.has(key) || key === "anyOf" || key === "oneOf" || key === "allOf") { + continue + } + + if (key === "properties" && value && typeof value === "object" && !Array.isArray(value)) { + // Iterate the property map directly so that property names that happen + // to match schema keywords (e.g. "default", "additionalProperties") are + // preserved as-is; only each property's schema value is sanitized. + const sanitizedProperties: Record = {} + for (const [propName, propSchema] of Object.entries(value as Record)) { + sanitizedProperties[propName] = sanitizeSchemaForGemini(propSchema, resolvedDefs, activeRefs) + } + result.properties = { + ...(result.properties as Record | undefined), + ...sanitizedProperties, + } + continue + } + + if (key === "required" && Array.isArray(value)) { + const existing = Array.isArray(result.required) ? (result.required as string[]) : [] + result.required = [ + ...new Set([...existing, ...value.filter((item): item is string => typeof item === "string")]), + ] + continue + } + + if (key === "type" && Array.isArray(value)) { + const nonNullTypes = value.filter((item) => item !== "null") + if (nonNullTypes.length > 0) { + result.type = nonNullTypes[0] + } + nullable = nullable || nonNullTypes.length < value.length + continue + } + + result[key] = sanitizeSchemaForGemini(value, resolvedDefs, activeRefs) + } + + if (nullable) { + result.nullable = true + } + + return result +} + export class GeminiHandler extends BaseProvider implements SingleCompletionHandler { protected options: ApiHandlerOptions @@ -132,13 +272,16 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl // Google built-in tools (Grounding, URL Context) are mutually exclusive // with function declarations in the Gemini API, so we always use // function declarations when tools are provided. + const functionDeclarations = (metadata?.tools ?? []).map((tool) => ({ + name: (tool as any).function.name, + description: (tool as any).function.description, + parametersJsonSchema: sanitizeSchemaForGemini((tool as any).function.parameters), + })) + const availableFunctionNameSet = new Set(functionDeclarations.map((declaration) => declaration.name)) + const tools: GenerateContentConfig["tools"] = [ { - functionDeclarations: (metadata?.tools ?? []).map((tool) => ({ - name: (tool as any).function.name, - description: (tool as any).function.description, - parametersJsonSchema: (tool as any).function.parameters, - })), + functionDeclarations, }, ] @@ -161,19 +304,13 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl ...(tools.length > 0 ? { tools } : {}), } - // Handle allowedFunctionNames for mode-restricted tool access. - // When provided, all tool definitions are passed to the model (so it can reference - // historical tool calls in conversation), but only the specified tools can be invoked. - // This takes precedence over tool_choice to ensure mode restrictions are honored. - if (metadata?.allowedFunctionNames && metadata.allowedFunctionNames.length > 0) { - config.toolConfig = { - functionCallingConfig: { - // Use ANY mode to allow calling any of the allowed functions - mode: FunctionCallingConfigMode.ANY, - allowedFunctionNames: metadata.allowedFunctionNames, - }, - } - } else if (metadata?.tool_choice) { + // Do not pass metadata.allowedFunctionNames to Gemini. Live API testing showed + // that allowedFunctionNames triggers a generic 400 INVALID_ARGUMENT at 26 or more + // names. It can also + // reject prior function calls if their names are absent from the current + // allowed list. We still pass all declarations for history compatibility; + // mode/tool restrictions are enforced by the tool execution layer. + if (metadata?.tool_choice) { const choice = metadata.tool_choice let mode: FunctionCallingConfigMode let allowedFunctionNames: string[] | undefined @@ -186,8 +323,13 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl // "required" means the model must call at least one tool; Gemini uses ANY for this. mode = FunctionCallingConfigMode.ANY } else if (typeof choice === "object" && "function" in choice && choice.type === "function") { - mode = FunctionCallingConfigMode.ANY - allowedFunctionNames = [choice.function.name] + const selectedToolName = choice.function.name + if (availableFunctionNameSet.has(selectedToolName)) { + mode = FunctionCallingConfigMode.ANY + allowedFunctionNames = [selectedToolName] + } else { + mode = FunctionCallingConfigMode.AUTO + } } else { // Fall back to AUTO for unknown values to avoid unintentionally broadening tool access. mode = FunctionCallingConfigMode.AUTO diff --git a/src/core/tools/__tests__/validateToolUse.spec.ts b/src/core/tools/__tests__/validateToolUse.spec.ts index 29455e3688..9e4a8bbd0c 100644 --- a/src/core/tools/__tests__/validateToolUse.spec.ts +++ b/src/core/tools/__tests__/validateToolUse.spec.ts @@ -186,6 +186,14 @@ describe("mode-validator", () => { ) }) + it("blocks mode-disallowed tools even if a provider declared them", () => { + // Gemini may receive all tool declarations for history compatibility, so + // execution-time validation must remain the final mode restriction guard. + expect(() => validateToolUse("write_to_file", askMode, [])).toThrow( + 'Tool "write_to_file" is not allowed in ask mode.', + ) + }) + it("does not throw for allowed tools in architect mode", () => { expect(() => validateToolUse("read_file", "architect", [])).not.toThrow() })