Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions packages/appkit/src/agents/databricks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,44 @@ function isRecord(value: unknown): value is Record<string, unknown> {
return typeof value === "object" && value !== null;
}

/**
* Optional generation parameters forwarded to the OpenAI-compatible serving
* request body. Names match the serving API wire keys. Only keys that are set
* are sent — undefined values are omitted so the endpoint applies its own
* defaults. Ranges are not validated here; the serving endpoint validates.
*/
export interface GenerationParams {
/** Sampling temperature. */
temperature?: number;
/** Nucleus sampling probability mass (`top_p`). */
top_p?: number;
/** Stop sequence(s) that end generation. */
stop?: string | string[];
/** Penalize tokens by frequency. */
frequency_penalty?: number;
/** Penalize tokens by prior presence. */
presence_penalty?: number;
}

const GENERATION_PARAM_KEYS = [
"temperature",
"top_p",
"stop",
"frequency_penalty",
"presence_penalty",
] as const satisfies readonly (keyof GenerationParams)[];

/** Copy only the set generation params onto the request body. */
function applyGenerationParams(
body: Record<string, unknown>,
params: GenerationParams,
): void {
for (const key of GENERATION_PARAM_KEYS) {
const value = params[key];
if (value !== undefined) body[key] = value;
}
}

function extractLlamaToolJsonSlice(text: string): string | undefined {
const start = text.indexOf("[{");
if (start < 0) return undefined;
Expand Down Expand Up @@ -83,6 +121,8 @@ interface RawFetchAdapterOptions {
authenticate: () => Promise<Record<string, string>>;
maxSteps?: number;
maxTokens?: number;
/** Optional generation params forwarded to the serving request body. */
generationParams?: GenerationParams;
/** Max length of one SSE line (including an incomplete tail in the buffer). */
maxSseLineChars?: number;
/** Max total length of assistant `delta.content` across the stream. */
Expand All @@ -101,6 +141,7 @@ interface StreamBodyAdapterOptions {
streamBody: StreamBody;
maxSteps?: number;
maxTokens?: number;
generationParams?: GenerationParams;
maxSseLineChars?: number;
maxStreamTextChars?: number;
maxToolArgumentsChars?: number;
Expand Down Expand Up @@ -134,6 +175,7 @@ interface ServingEndpointOptions {
endpointName: string;
maxSteps?: number;
maxTokens?: number;
generationParams?: GenerationParams;
maxSseLineChars?: number;
maxStreamTextChars?: number;
maxToolArgumentsChars?: number;
Expand All @@ -142,6 +184,7 @@ interface ServingEndpointOptions {
interface ModelServingOptions {
maxSteps?: number;
maxTokens?: number;
generationParams?: GenerationParams;
workspaceClient?: WorkspaceClientLike;
maxSseLineChars?: number;
maxStreamTextChars?: number;
Expand Down Expand Up @@ -237,13 +280,15 @@ export class DatabricksAdapter implements AgentAdapter {
private streamBody: StreamBody;
private maxSteps: number;
private maxTokens: number;
private generationParams: GenerationParams;
private maxSseLineChars: number;
private maxStreamTextChars: number;
private maxToolArgumentsChars: number;

constructor(options: DatabricksAdapterOptions) {
this.maxSteps = options.maxSteps ?? 10;
this.maxTokens = options.maxTokens ?? 4096;
this.generationParams = options.generationParams ?? {};
this.maxSseLineChars =
options.maxSseLineChars ?? DEFAULT_MAX_SSE_LINE_CHARS;
this.maxStreamTextChars =
Expand Down Expand Up @@ -296,6 +341,7 @@ export class DatabricksAdapter implements AgentAdapter {
endpointName,
maxSteps,
maxTokens,
generationParams,
maxSseLineChars,
maxStreamTextChars,
maxToolArgumentsChars,
Expand All @@ -313,6 +359,7 @@ export class DatabricksAdapter implements AgentAdapter {
),
maxSteps,
maxTokens,
generationParams,
maxSseLineChars,
maxStreamTextChars,
maxToolArgumentsChars,
Expand Down Expand Up @@ -367,6 +414,7 @@ export class DatabricksAdapter implements AgentAdapter {
endpointName: resolvedEndpoint,
maxSteps: options?.maxSteps,
maxTokens: options?.maxTokens,
generationParams: options?.generationParams,
maxSseLineChars: options?.maxSseLineChars,
maxStreamTextChars: options?.maxStreamTextChars,
maxToolArgumentsChars: options?.maxToolArgumentsChars,
Expand Down Expand Up @@ -492,6 +540,8 @@ export class DatabricksAdapter implements AgentAdapter {
max_tokens: this.maxTokens,
};

applyGenerationParams(body, this.generationParams);

if (tools.length > 0) {
body.tools = tools;
}
Expand Down
57 changes: 56 additions & 1 deletion packages/appkit/src/agents/tests/databricks.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import type { AgentEvent, AgentToolDefinition, Message } from "shared";
import { afterEach, beforeEach, describe, expect, test, vi } from "vitest";
import { DatabricksAdapter, parseTextToolCalls } from "../databricks";
import {
DatabricksAdapter,
type GenerationParams,
parseTextToolCalls,
} from "../databricks";

const mockAuthenticate = vi
.fn()
Expand Down Expand Up @@ -93,6 +97,7 @@ function createAdapter(overrides?: {
authenticate?: () => Promise<Record<string, string>>;
maxSteps?: number;
maxTokens?: number;
generationParams?: GenerationParams;
maxSseLineChars?: number;
maxStreamTextChars?: number;
maxToolArgumentsChars?: number;
Expand Down Expand Up @@ -566,6 +571,56 @@ describe("DatabricksAdapter", () => {
});
});

test("forwards set generation params to the request body", async () => {
globalThis.fetch = mockFetch([textDelta("Hi"), sseChunk("[DONE]")]);

const adapter = createAdapter({
generationParams: {
temperature: 0.2,
top_p: 0.9,
stop: ["END"],
frequency_penalty: 0.5,
presence_penalty: 0.1,
},
});

for await (const _ of adapter.run(
{ messages: createTestMessages(), tools: [], threadId: "t1" },
{ executeTool: vi.fn() },
)) {
// drain
}

const [, init] = (globalThis.fetch as any).mock.calls[0];
const body = JSON.parse(init.body);
expect(body.temperature).toBe(0.2);
expect(body.top_p).toBe(0.9);
expect(body.stop).toEqual(["END"]);
expect(body.frequency_penalty).toBe(0.5);
expect(body.presence_penalty).toBe(0.1);
});

test("omits generation param keys that are not set", async () => {
globalThis.fetch = mockFetch([textDelta("Hi"), sseChunk("[DONE]")]);

const adapter = createAdapter({ generationParams: { temperature: 0.7 } });

for await (const _ of adapter.run(
{ messages: createTestMessages(), tools: [], threadId: "t1" },
{ executeTool: vi.fn() },
)) {
// drain
}

const [, init] = (globalThis.fetch as any).mock.calls[0];
const body = JSON.parse(init.body);
expect(body.temperature).toBe(0.7);
expect(body).not.toHaveProperty("top_p");
expect(body).not.toHaveProperty("stop");
expect(body).not.toHaveProperty("frequency_penalty");
expect(body).not.toHaveProperty("presence_penalty");
});

test("forwards tool thread fields from input messages to the request body", async () => {
globalThis.fetch = mockFetch([textDelta("Done"), sseChunk("[DONE]")]);

Expand Down
6 changes: 5 additions & 1 deletion packages/appkit/src/beta.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@ export type {
ToolAnnotations,
ToolProvider,
} from "shared";
export { DatabricksAdapter, parseTextToolCalls } from "./agents/databricks";
export {
DatabricksAdapter,
type GenerationParams,
parseTextToolCalls,
} from "./agents/databricks";

// Agent runtime
export { createAgent } from "./core/agent/create-agent";
Expand Down
35 changes: 35 additions & 0 deletions packages/appkit/src/core/agent/load-agents.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import fs from "node:fs/promises";
import path from "node:path";
import yaml from "js-yaml";
import type { AgentAdapter } from "shared";
import type { GenerationParams } from "../../agents/databricks";
import type {
AgentDefinition,
AgentTool,
Expand Down Expand Up @@ -75,6 +76,12 @@ interface Frontmatter {
agents?: string[];
maxSteps?: number;
maxTokens?: number;
/**
* Optional OpenAI-compatible generation params forwarded to the serving
* request body (`temperature`, `top_p`, `stop`, `frequency_penalty`,
* `presence_penalty`). Parsed defensively in {@link buildDefinition}.
*/
generationParams?: Record<string, unknown>;
default?: boolean;
baseSystemPrompt?: false | string;
ephemeral?: boolean;
Expand Down Expand Up @@ -120,6 +127,7 @@ const ALLOWED_KEYS = new Set([
"agents",
"maxSteps",
"maxTokens",
"generationParams",
"default",
"baseSystemPrompt",
"ephemeral",
Expand Down Expand Up @@ -340,6 +348,32 @@ export function parseFrontmatter(
return { data: data as Frontmatter, content: match[2].trim() };
}

/**
* Defensively maps a frontmatter `generationParams` map to {@link GenerationParams}.
* Picks only known keys with the expected wire types; ignores everything else.
* Returns `undefined` when no valid key is present.
*/
function parseGenerationParams(value: unknown): GenerationParams | undefined {
if (typeof value !== "object" || value === null || Array.isArray(value)) {
return undefined;
}
const raw = value as Record<string, unknown>;
const out: GenerationParams = {};
if (typeof raw.temperature === "number") out.temperature = raw.temperature;
if (typeof raw.top_p === "number") out.top_p = raw.top_p;
if (typeof raw.frequency_penalty === "number")
out.frequency_penalty = raw.frequency_penalty;
if (typeof raw.presence_penalty === "number")
out.presence_penalty = raw.presence_penalty;
if (
typeof raw.stop === "string" ||
(Array.isArray(raw.stop) && raw.stop.every((s) => typeof s === "string"))
) {
out.stop = raw.stop as string | string[];
}
return Object.keys(out).length > 0 ? out : undefined;
}

function buildDefinition(
name: string,
raw: string,
Expand All @@ -364,6 +398,7 @@ function buildDefinition(
tools: Object.keys(tools).length > 0 ? tools : undefined,
maxSteps: typeof fm.maxSteps === "number" ? fm.maxSteps : undefined,
maxTokens: typeof fm.maxTokens === "number" ? fm.maxTokens : undefined,
generationParams: parseGenerationParams(fm.generationParams),
baseSystemPrompt,
ephemeral: typeof fm.ephemeral === "boolean" ? fm.ephemeral : undefined,
};
Expand Down
11 changes: 11 additions & 0 deletions packages/appkit/src/core/agent/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import type {
ThreadStore,
ToolAnnotations,
} from "shared";
import type { GenerationParams } from "../../agents/databricks";
import type { McpHostPolicyConfig } from "../../connectors/mcp";
import type { FunctionTool } from "./tools/function-tool";
import type { HostedTool } from "./tools/hosted-tools";
Expand Down Expand Up @@ -162,6 +163,14 @@ export interface AgentDefinition {
baseSystemPrompt?: BaseSystemPromptOption;
maxSteps?: number;
maxTokens?: number;
/**
* Optional generation parameters (`temperature`, `top_p`, `stop`,
* `frequency_penalty`, `presence_penalty`) forwarded to the OpenAI-compatible
* serving request body. Only set keys are sent. Applied only when AppKit
* builds the adapter itself (string or omitted `model`); when you pass a
* pre-built `AgentAdapter`, configure generation params on it directly.
*/
generationParams?: GenerationParams;
/**
* When true, the thread used for a chat request against this agent is
* deleted from `ThreadStore` after the stream completes (success or
Expand Down Expand Up @@ -309,6 +318,8 @@ export interface RegisteredAgent {
baseSystemPrompt?: BaseSystemPromptOption;
maxSteps?: number;
maxTokens?: number;
/** Mirrors `AgentDefinition.generationParams`. */
generationParams?: GenerationParams;
/** Mirrors `AgentDefinition.ephemeral` — skip thread persistence. */
ephemeral?: boolean;
}
Expand Down
9 changes: 8 additions & 1 deletion packages/appkit/src/plugins/agents/agents.ts
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,7 @@ export class AgentsPlugin extends Plugin implements ToolProvider {
baseSystemPrompt: def.baseSystemPrompt,
maxSteps: def.maxSteps,
maxTokens: def.maxTokens,
generationParams: def.generationParams,
ephemeral: def.ephemeral,
};
}
Expand All @@ -458,9 +459,15 @@ export class AgentsPlugin extends Plugin implements ToolProvider {
// Per-agent adapter knobs from `AgentDefinition` / markdown frontmatter.
// Only applied when AppKit builds the adapter itself (string or omitted
// model). Users who pass a pre-built `AgentAdapter` own these settings.
const adapterOptions: { maxSteps?: number; maxTokens?: number } = {};
const adapterOptions: {
maxSteps?: number;
maxTokens?: number;
generationParams?: AgentDefinition["generationParams"];
} = {};
if (def.maxSteps !== undefined) adapterOptions.maxSteps = def.maxSteps;
if (def.maxTokens !== undefined) adapterOptions.maxTokens = def.maxTokens;
if (def.generationParams !== undefined)
adapterOptions.generationParams = def.generationParams;

if (!source) {
const { DatabricksAdapter } = await import("../../agents/databricks");
Expand Down
Loading