From 6cf0eb4310a50e607a7230a82d7f08a4ff02a7f3 Mon Sep 17 00:00:00 2001 From: MarioCadenas Date: Sat, 13 Jun 2026 21:28:22 +0200 Subject: [PATCH] feat(agents): forward generation params (temperature/top_p/stop/...) to the serving adapter The Databricks agent adapter previously sent only max_tokens to the serving endpoint. Add optional pass-through for the standard OpenAI-compatible generation params (temperature, top_p, stop, frequency_penalty, presence_penalty), declared on the agent definition (code + markdown frontmatter) and on the DatabricksAdapter options. Only set keys are forwarded into the request body; undefined keys are omitted so the endpoint applies its own defaults. Co-authored-by: Isaac Signed-off-by: MarioCadenas --- packages/appkit/src/agents/databricks.ts | 50 ++++++++++++++++ .../src/agents/tests/databricks.test.ts | 57 ++++++++++++++++++- packages/appkit/src/beta.ts | 6 +- packages/appkit/src/core/agent/load-agents.ts | 35 ++++++++++++ packages/appkit/src/core/agent/types.ts | 11 ++++ packages/appkit/src/plugins/agents/agents.ts | 9 ++- 6 files changed, 165 insertions(+), 3 deletions(-) diff --git a/packages/appkit/src/agents/databricks.ts b/packages/appkit/src/agents/databricks.ts index 6e2e78d60..e2614592c 100644 --- a/packages/appkit/src/agents/databricks.ts +++ b/packages/appkit/src/agents/databricks.ts @@ -26,6 +26,44 @@ function isRecord(value: unknown): value is Record { return typeof value === "object" && value !== null; } +/** + * Optional generation parameters forwarded to the OpenAI-compatible serving + * request body. Names match the serving API wire keys. Only keys that are set + * are sent — undefined values are omitted so the endpoint applies its own + * defaults. Ranges are not validated here; the serving endpoint validates. + */ +export interface GenerationParams { + /** Sampling temperature. */ + temperature?: number; + /** Nucleus sampling probability mass (`top_p`). */ + top_p?: number; + /** Stop sequence(s) that end generation. */ + stop?: string | string[]; + /** Penalize tokens by frequency. */ + frequency_penalty?: number; + /** Penalize tokens by prior presence. */ + presence_penalty?: number; +} + +const GENERATION_PARAM_KEYS = [ + "temperature", + "top_p", + "stop", + "frequency_penalty", + "presence_penalty", +] as const satisfies readonly (keyof GenerationParams)[]; + +/** Copy only the set generation params onto the request body. */ +function applyGenerationParams( + body: Record, + params: GenerationParams, +): void { + for (const key of GENERATION_PARAM_KEYS) { + const value = params[key]; + if (value !== undefined) body[key] = value; + } +} + function extractLlamaToolJsonSlice(text: string): string | undefined { const start = text.indexOf("[{"); if (start < 0) return undefined; @@ -83,6 +121,8 @@ interface RawFetchAdapterOptions { authenticate: () => Promise>; maxSteps?: number; maxTokens?: number; + /** Optional generation params forwarded to the serving request body. */ + generationParams?: GenerationParams; /** Max length of one SSE line (including an incomplete tail in the buffer). */ maxSseLineChars?: number; /** Max total length of assistant `delta.content` across the stream. */ @@ -101,6 +141,7 @@ interface StreamBodyAdapterOptions { streamBody: StreamBody; maxSteps?: number; maxTokens?: number; + generationParams?: GenerationParams; maxSseLineChars?: number; maxStreamTextChars?: number; maxToolArgumentsChars?: number; @@ -134,6 +175,7 @@ interface ServingEndpointOptions { endpointName: string; maxSteps?: number; maxTokens?: number; + generationParams?: GenerationParams; maxSseLineChars?: number; maxStreamTextChars?: number; maxToolArgumentsChars?: number; @@ -142,6 +184,7 @@ interface ServingEndpointOptions { interface ModelServingOptions { maxSteps?: number; maxTokens?: number; + generationParams?: GenerationParams; workspaceClient?: WorkspaceClientLike; maxSseLineChars?: number; maxStreamTextChars?: number; @@ -237,6 +280,7 @@ export class DatabricksAdapter implements AgentAdapter { private streamBody: StreamBody; private maxSteps: number; private maxTokens: number; + private generationParams: GenerationParams; private maxSseLineChars: number; private maxStreamTextChars: number; private maxToolArgumentsChars: number; @@ -244,6 +288,7 @@ export class DatabricksAdapter implements AgentAdapter { constructor(options: DatabricksAdapterOptions) { this.maxSteps = options.maxSteps ?? 10; this.maxTokens = options.maxTokens ?? 4096; + this.generationParams = options.generationParams ?? {}; this.maxSseLineChars = options.maxSseLineChars ?? DEFAULT_MAX_SSE_LINE_CHARS; this.maxStreamTextChars = @@ -296,6 +341,7 @@ export class DatabricksAdapter implements AgentAdapter { endpointName, maxSteps, maxTokens, + generationParams, maxSseLineChars, maxStreamTextChars, maxToolArgumentsChars, @@ -313,6 +359,7 @@ export class DatabricksAdapter implements AgentAdapter { ), maxSteps, maxTokens, + generationParams, maxSseLineChars, maxStreamTextChars, maxToolArgumentsChars, @@ -367,6 +414,7 @@ export class DatabricksAdapter implements AgentAdapter { endpointName: resolvedEndpoint, maxSteps: options?.maxSteps, maxTokens: options?.maxTokens, + generationParams: options?.generationParams, maxSseLineChars: options?.maxSseLineChars, maxStreamTextChars: options?.maxStreamTextChars, maxToolArgumentsChars: options?.maxToolArgumentsChars, @@ -492,6 +540,8 @@ export class DatabricksAdapter implements AgentAdapter { max_tokens: this.maxTokens, }; + applyGenerationParams(body, this.generationParams); + if (tools.length > 0) { body.tools = tools; } diff --git a/packages/appkit/src/agents/tests/databricks.test.ts b/packages/appkit/src/agents/tests/databricks.test.ts index 84f0c6717..665f60051 100644 --- a/packages/appkit/src/agents/tests/databricks.test.ts +++ b/packages/appkit/src/agents/tests/databricks.test.ts @@ -1,6 +1,10 @@ import type { AgentEvent, AgentToolDefinition, Message } from "shared"; import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; -import { DatabricksAdapter, parseTextToolCalls } from "../databricks"; +import { + DatabricksAdapter, + type GenerationParams, + parseTextToolCalls, +} from "../databricks"; const mockAuthenticate = vi .fn() @@ -93,6 +97,7 @@ function createAdapter(overrides?: { authenticate?: () => Promise>; maxSteps?: number; maxTokens?: number; + generationParams?: GenerationParams; maxSseLineChars?: number; maxStreamTextChars?: number; maxToolArgumentsChars?: number; @@ -566,6 +571,56 @@ describe("DatabricksAdapter", () => { }); }); + test("forwards set generation params to the request body", async () => { + globalThis.fetch = mockFetch([textDelta("Hi"), sseChunk("[DONE]")]); + + const adapter = createAdapter({ + generationParams: { + temperature: 0.2, + top_p: 0.9, + stop: ["END"], + frequency_penalty: 0.5, + presence_penalty: 0.1, + }, + }); + + for await (const _ of adapter.run( + { messages: createTestMessages(), tools: [], threadId: "t1" }, + { executeTool: vi.fn() }, + )) { + // drain + } + + const [, init] = (globalThis.fetch as any).mock.calls[0]; + const body = JSON.parse(init.body); + expect(body.temperature).toBe(0.2); + expect(body.top_p).toBe(0.9); + expect(body.stop).toEqual(["END"]); + expect(body.frequency_penalty).toBe(0.5); + expect(body.presence_penalty).toBe(0.1); + }); + + test("omits generation param keys that are not set", async () => { + globalThis.fetch = mockFetch([textDelta("Hi"), sseChunk("[DONE]")]); + + const adapter = createAdapter({ generationParams: { temperature: 0.7 } }); + + for await (const _ of adapter.run( + { messages: createTestMessages(), tools: [], threadId: "t1" }, + { executeTool: vi.fn() }, + )) { + // drain + } + + const [, init] = (globalThis.fetch as any).mock.calls[0]; + const body = JSON.parse(init.body); + expect(body.temperature).toBe(0.7); + expect(body).not.toHaveProperty("top_p"); + expect(body).not.toHaveProperty("stop"); + expect(body).not.toHaveProperty("frequency_penalty"); + expect(body).not.toHaveProperty("presence_penalty"); + }); + test("forwards tool thread fields from input messages to the request body", async () => { globalThis.fetch = mockFetch([textDelta("Done"), sseChunk("[DONE]")]); diff --git a/packages/appkit/src/beta.ts b/packages/appkit/src/beta.ts index 3f5bba80c..cffaf1adc 100644 --- a/packages/appkit/src/beta.ts +++ b/packages/appkit/src/beta.ts @@ -18,7 +18,11 @@ export type { ToolAnnotations, ToolProvider, } from "shared"; -export { DatabricksAdapter, parseTextToolCalls } from "./agents/databricks"; +export { + DatabricksAdapter, + type GenerationParams, + parseTextToolCalls, +} from "./agents/databricks"; // Agent runtime export { createAgent } from "./core/agent/create-agent"; diff --git a/packages/appkit/src/core/agent/load-agents.ts b/packages/appkit/src/core/agent/load-agents.ts index 13b2ff70d..910156f70 100644 --- a/packages/appkit/src/core/agent/load-agents.ts +++ b/packages/appkit/src/core/agent/load-agents.ts @@ -3,6 +3,7 @@ import fs from "node:fs/promises"; import path from "node:path"; import yaml from "js-yaml"; import type { AgentAdapter } from "shared"; +import type { GenerationParams } from "../../agents/databricks"; import type { AgentDefinition, AgentTool, @@ -75,6 +76,12 @@ interface Frontmatter { agents?: string[]; maxSteps?: number; maxTokens?: number; + /** + * Optional OpenAI-compatible generation params forwarded to the serving + * request body (`temperature`, `top_p`, `stop`, `frequency_penalty`, + * `presence_penalty`). Parsed defensively in {@link buildDefinition}. + */ + generationParams?: Record; default?: boolean; baseSystemPrompt?: false | string; ephemeral?: boolean; @@ -120,6 +127,7 @@ const ALLOWED_KEYS = new Set([ "agents", "maxSteps", "maxTokens", + "generationParams", "default", "baseSystemPrompt", "ephemeral", @@ -340,6 +348,32 @@ export function parseFrontmatter( return { data: data as Frontmatter, content: match[2].trim() }; } +/** + * Defensively maps a frontmatter `generationParams` map to {@link GenerationParams}. + * Picks only known keys with the expected wire types; ignores everything else. + * Returns `undefined` when no valid key is present. + */ +function parseGenerationParams(value: unknown): GenerationParams | undefined { + if (typeof value !== "object" || value === null || Array.isArray(value)) { + return undefined; + } + const raw = value as Record; + const out: GenerationParams = {}; + if (typeof raw.temperature === "number") out.temperature = raw.temperature; + if (typeof raw.top_p === "number") out.top_p = raw.top_p; + if (typeof raw.frequency_penalty === "number") + out.frequency_penalty = raw.frequency_penalty; + if (typeof raw.presence_penalty === "number") + out.presence_penalty = raw.presence_penalty; + if ( + typeof raw.stop === "string" || + (Array.isArray(raw.stop) && raw.stop.every((s) => typeof s === "string")) + ) { + out.stop = raw.stop as string | string[]; + } + return Object.keys(out).length > 0 ? out : undefined; +} + function buildDefinition( name: string, raw: string, @@ -364,6 +398,7 @@ function buildDefinition( tools: Object.keys(tools).length > 0 ? tools : undefined, maxSteps: typeof fm.maxSteps === "number" ? fm.maxSteps : undefined, maxTokens: typeof fm.maxTokens === "number" ? fm.maxTokens : undefined, + generationParams: parseGenerationParams(fm.generationParams), baseSystemPrompt, ephemeral: typeof fm.ephemeral === "boolean" ? fm.ephemeral : undefined, }; diff --git a/packages/appkit/src/core/agent/types.ts b/packages/appkit/src/core/agent/types.ts index cf47845f7..0dbbcfac2 100644 --- a/packages/appkit/src/core/agent/types.ts +++ b/packages/appkit/src/core/agent/types.ts @@ -5,6 +5,7 @@ import type { ThreadStore, ToolAnnotations, } from "shared"; +import type { GenerationParams } from "../../agents/databricks"; import type { McpHostPolicyConfig } from "../../connectors/mcp"; import type { FunctionTool } from "./tools/function-tool"; import type { HostedTool } from "./tools/hosted-tools"; @@ -162,6 +163,14 @@ export interface AgentDefinition { baseSystemPrompt?: BaseSystemPromptOption; maxSteps?: number; maxTokens?: number; + /** + * Optional generation parameters (`temperature`, `top_p`, `stop`, + * `frequency_penalty`, `presence_penalty`) forwarded to the OpenAI-compatible + * serving request body. Only set keys are sent. Applied only when AppKit + * builds the adapter itself (string or omitted `model`); when you pass a + * pre-built `AgentAdapter`, configure generation params on it directly. + */ + generationParams?: GenerationParams; /** * When true, the thread used for a chat request against this agent is * deleted from `ThreadStore` after the stream completes (success or @@ -309,6 +318,8 @@ export interface RegisteredAgent { baseSystemPrompt?: BaseSystemPromptOption; maxSteps?: number; maxTokens?: number; + /** Mirrors `AgentDefinition.generationParams`. */ + generationParams?: GenerationParams; /** Mirrors `AgentDefinition.ephemeral` — skip thread persistence. */ ephemeral?: boolean; } diff --git a/packages/appkit/src/plugins/agents/agents.ts b/packages/appkit/src/plugins/agents/agents.ts index c63ec094f..ce40c514e 100644 --- a/packages/appkit/src/plugins/agents/agents.ts +++ b/packages/appkit/src/plugins/agents/agents.ts @@ -446,6 +446,7 @@ export class AgentsPlugin extends Plugin implements ToolProvider { baseSystemPrompt: def.baseSystemPrompt, maxSteps: def.maxSteps, maxTokens: def.maxTokens, + generationParams: def.generationParams, ephemeral: def.ephemeral, }; } @@ -458,9 +459,15 @@ export class AgentsPlugin extends Plugin implements ToolProvider { // Per-agent adapter knobs from `AgentDefinition` / markdown frontmatter. // Only applied when AppKit builds the adapter itself (string or omitted // model). Users who pass a pre-built `AgentAdapter` own these settings. - const adapterOptions: { maxSteps?: number; maxTokens?: number } = {}; + const adapterOptions: { + maxSteps?: number; + maxTokens?: number; + generationParams?: AgentDefinition["generationParams"]; + } = {}; if (def.maxSteps !== undefined) adapterOptions.maxSteps = def.maxSteps; if (def.maxTokens !== undefined) adapterOptions.maxTokens = def.maxTokens; + if (def.generationParams !== undefined) + adapterOptions.generationParams = def.generationParams; if (!source) { const { DatabricksAdapter } = await import("../../agents/databricks");