diff --git a/packages/ai/README.md b/packages/ai/README.md index 61b5b771b..56ca19c6d 100644 --- a/packages/ai/README.md +++ b/packages/ai/README.md @@ -705,6 +705,18 @@ try { } ``` +### Execution and retries + +By default an AI task runs the provider call through a **direct** or +**concurrency-limited** execution strategy that invokes the provider **once** +and surfaces any failure to the caller — the built-in strategies do **not** +retry automatically. Provider errors are still classified (retryable vs. +permanent, with a rate-limit `retry-after` when the provider supplies one) for +diagnostics and for queue-backed execution, but that classification only drives +automatic retries when a task is run through a persistent job queue, not through +the default in-process strategies. If you need provider-level retries, run the +task on a queue-backed strategy or implement retry/backoff in your own caller. + ## Advanced Configuration ### Model Input Resolution diff --git a/packages/ai/src/capability/StreamEventAccumulator.ts b/packages/ai/src/capability/StreamEventAccumulator.ts index 21ca0110e..92833b2b9 100644 --- a/packages/ai/src/capability/StreamEventAccumulator.ts +++ b/packages/ai/src/capability/StreamEventAccumulator.ts @@ -23,8 +23,10 @@ const isNonEmptyObject = (v: unknown): v is Record => * - `finish` → captured separately via {@link observeFinish}; mandatory before * {@link materialize}. * - * Mixed-mode (both text-delta and object-delta on the same stream) is - * rejected at materialise time. + * Text and object deltas on DISTINCT ports coexist and compose into one object + * (e.g. tool-calling streams text on `text` and tool calls on `toolCalls`); + * accumulated deltas take precedence over the finish payload, mirroring the + * task-graph StreamProcessor so `.run()` and streaming produce identical output. * * This accumulator is **only** instantiated at explicit terminal-consumer * sites (AiTask.execute, StreamProcessor `ctx.shouldAccumulate` branch). Do @@ -125,13 +127,6 @@ export class StreamEventAccumulator { err.lastEventType = lastEventType; throw err; } - if (this.hasTextDeltas && this.hasObjectDeltas) { - throw new Error( - "StreamEventAccumulator: stream mixed text-delta and object-delta events. " + - "Mixed-mode streams are not supported." - ); - } - // One-shot: finish carries the complete payload. if (!this.hasTextDeltas && !this.hasObjectDeltas && !this.hasSnapshots) { if (isNonEmptyObject(this.finishData)) return this.finishData; @@ -146,18 +141,24 @@ export class StreamEventAccumulator { return this.snapshotAccumulator as T; } - // Text-delta mode — per-port map → object. - if (this.hasTextDeltas && !this.hasObjectDeltas) { - const result: Record = {}; - for (const [port, text] of this.textAccumulator) result[port] = text; - if (isNonEmptyObject(this.finishData)) Object.assign(result, this.finishData); - return result as unknown as T; + // Delta mode — text and/or object deltas. Accumulated deltas take precedence + // over the finish payload, so a streaming finish (empty `{}` or a structural + // default scaffold like `{ text: "", toolCalls: [] }`) can never clobber + // streamed content. Text and object deltas on DISTINCT ports compose (e.g. + // tool-calling streams text on `text` and tool calls on `toolCalls`); a + // same-port collision resolves object-last. Mirrors the task-graph + // StreamProcessor so `.run()` and streaming produce identical output. + const fd = this.finishData; + const result: Record = + fd !== null && typeof fd === "object" && !Array.isArray(fd) + ? { ...(fd as Record) } + : {}; + for (const [port, text] of this.textAccumulator) { + if (text.length > 0) result[port] = text; } - - // Object-delta mode — per-port object → output, then merge finish. - const merged: Record = {}; - for (const [port, obj] of this.objectAccumulator) merged[port] = obj; - if (isNonEmptyObject(this.finishData)) Object.assign(merged, this.finishData as object); - return merged as unknown as T; + for (const [port, obj] of this.objectAccumulator) { + result[port] = obj; + } + return result as unknown as T; } } diff --git a/packages/ai/src/job/AiJob.ts b/packages/ai/src/job/AiJob.ts index fb0928c69..55cbe140e 100644 --- a/packages/ai/src/job/AiJob.ts +++ b/packages/ai/src/job/AiJob.ts @@ -113,7 +113,13 @@ export function classifyProviderError(err: unknown, taskType: string, provider: : typeof (err as any)?.statusCode === "number" ? (err as any).statusCode : (() => { - const m = message.match(/\b([45]\d{2})\b/); + // Only treat a 4xx/5xx number as an HTTP status when it appears in + // an HTTP-shaped context (e.g. "HTTP 503", "status: 429"). A bare + // number like the "512" in "Sequence length 512 exceeds limit" or a + // model id must not be scavenged as a status, or it misclassifies. + const m = message.match( + /\b(?:HTTP\/?\d?\.?\d?\s*|status(?:\s*code)?\s*(?:[:=]\s*)?)([45]\d{2})\b/i + ); return m ? parseInt(m[1], 10) : undefined; })(); diff --git a/packages/ai/src/provider-utils/ToolCallParsers.ts b/packages/ai/src/provider-utils/ToolCallParsers.ts index b3440862b..732880c14 100644 --- a/packages/ai/src/provider-utils/ToolCallParsers.ts +++ b/packages/ai/src/provider-utils/ToolCallParsers.ts @@ -322,9 +322,13 @@ export const parseLlama: ParserFn = (text) => { // Uses balanced-brace scanning instead of regex to avoid ReDoS if (calls.length === 0) { const blocks = findBalancedBlocks(text, "{", "}"); + let firstBlockStart: number | undefined; for (const block of blocks) { const parsed = tryParseJson(block.text) as Record | undefined; if (parsed?.name && (parsed.parameters !== undefined || parsed.arguments !== undefined)) { + if (firstBlockStart === undefined) { + firstBlockStart = block.start; + } calls.push( makeToolCall( parsed.name as string, @@ -334,8 +338,11 @@ export const parseLlama: ParserFn = (text) => { ); } } - if (calls.length > 0) { - content = text.slice(0, text.indexOf(calls[0].name) - '{"name": "'.length).trim(); + if (calls.length > 0 && firstBlockStart !== undefined) { + // Slice the leading content to the actual block start rather than + // reconstructing the offset from a literal `{"name": "` prefix, which + // breaks when the model emits a different key order or no space. + content = text.slice(0, firstBlockStart).trim(); } } diff --git a/packages/ai/src/provider/AiProvider.ts b/packages/ai/src/provider/AiProvider.ts index a4f8e1613..fbb862721 100644 --- a/packages/ai/src/provider/AiProvider.ts +++ b/packages/ai/src/provider/AiProvider.ts @@ -191,6 +191,11 @@ export abstract class AiProvider const registry = getAiProviderRegistry(); if (registry.getProvider(this.name)) { registry.unregisterProvider(this.name); + // unregisterProvider only clears registry maps; a worker-backed provider + // also has a worker (and, for factory workers, an idle timer) registered + // on the WorkerManager. Without removing it here, the re-registration + // below throws "Worker is already registered." + await globalServiceRegistry.get(WORKER_MANAGER).terminateWorker(this.name); } await this.onInitialize(context); @@ -254,7 +259,12 @@ export abstract class AiProvider } catch (err) { // Clean up the partially-registered provider so the registry isn't left // in an inconsistent state (e.g., functions registered but no queue). + // For worker-backed registration the worker (and its idle timer) was + // already registered above, so remove it too or it leaks under this name. registry.unregisterProvider(this.name); + if (!isInline && options.worker) { + await globalServiceRegistry.get(WORKER_MANAGER).terminateWorker(this.name); + } throw err; } } diff --git a/packages/ai/src/task/AiChatTask.ts b/packages/ai/src/task/AiChatTask.ts index d8bf2cff6..98adc243b 100644 --- a/packages/ai/src/task/AiChatTask.ts +++ b/packages/ai/src/task/AiChatTask.ts @@ -342,6 +342,15 @@ export class AiChatTask extends StreamingAiTask w.length > 0) above should already exclude them, but - // make the contract local). - if (word.length === 0 || /^\s+$/.test(word)) continue; - // Escape every regex metacharacter so user input like `foo(`, `\\`, - // `*abc`, or `[` is treated as a literal token rather than being - // parsed as a regex (which would throw `SyntaxError`). - const regex = new RegExp(escapeRegExp(word), "gi"); - const matches = chunkLower.match(regex); - if (matches) { - keywordScore += matches.length; - } - } + const keywordScore = this.keywordMatchCount(queryWords, chunkLower); const exactMatchBonus = hasQueryWords && chunkLower.includes(queryLower) ? 0.5 : 0; const normalizedKeywordScore = hasQueryWords @@ -245,12 +231,72 @@ export class RerankerTask extends Task index); + if (scores.some((s) => typeof s === "number")) { + scoreOrder.sort((a, b) => (scores[b] ?? -Infinity) - (scores[a] ?? -Infinity)); + } + const scoreRank = new Array(n); + scoreOrder.forEach((originalIndex, rank) => (scoreRank[originalIndex] = rank)); + + // Ranking B — lexical keyword overlap with the query (plus an exact-phrase + // nudge), so RRF fuses a keyword signal with the retrieval signal. + const queryLower = query.toLowerCase(); + const queryWords = queryLower.split(/\s+/).filter((w) => w.length > 0); + const keywordScores = chunks.map((chunk) => { + const chunkLower = chunk.toLowerCase(); + const base = this.keywordMatchCount(queryWords, chunkLower); + return base + (queryWords.length > 0 && chunkLower.includes(queryLower) ? 0.5 : 0); + }); + const keywordOrder = chunks.map((_, index) => index); + keywordOrder.sort((a, b) => keywordScores[b] - keywordScores[a]); + const keywordRank = new Array(n); + keywordOrder.forEach((originalIndex, rank) => (keywordRank[originalIndex] = rank)); + const items: RankedItem[] = chunks.map((chunk, index) => ({ chunk, - score: 1 / (k + index + 1), + score: 1 / (k + scoreRank[index] + 1) + 1 / (k + keywordRank[index] + 1), metadata: metadata[index], originalIndex: index, })); diff --git a/packages/ai/src/task/ToolCallingTask.ts b/packages/ai/src/task/ToolCallingTask.ts index e970160dd..e97632476 100644 --- a/packages/ai/src/task/ToolCallingTask.ts +++ b/packages/ai/src/task/ToolCallingTask.ts @@ -346,17 +346,26 @@ export class ToolCallingTask extends StreamingAiTask< input: ToolCallingTaskInput, executeContext: IExecuteContext ): Promise { - const result = await super.execute(input, executeContext); + // Register the session disposer BEFORE running so it still fires if + // super.execute() throws or the stream aborts mid-iteration — the provider + // may already have allocated the session on the first run-fn invocation. + // The resourceScope is first-registration-wins and disposes via allSettled, + // so computing the session id up front and registering early is safe. + await this.getJobInput(input); this.registerSessionDispose(input, executeContext); - return result; + return super.execute(input, executeContext); } override async *executeStream( input: ToolCallingTaskInput, context: IExecuteContext ): AsyncIterable> { - yield* super.executeStream(input, context); + // Register the session disposer BEFORE streaming for the same reason as + // execute(): an abort or throw mid-stream must still leave the disposer + // registered so disposeSession runs on scope teardown. + await this.getJobInput(input); this.registerSessionDispose(input, context); + yield* super.executeStream(input, context); } } diff --git a/packages/ai/src/task/base/AiTask.ts b/packages/ai/src/task/base/AiTask.ts index 9f541f543..8e9a086c7 100644 --- a/packages/ai/src/task/base/AiTask.ts +++ b/packages/ai/src/task/base/AiTask.ts @@ -111,8 +111,11 @@ export class AiTask< /** * Capabilities this task requires from the model selected at execution time. - * Gates strictly: throws unless `model.capabilities ⊇ task.requires`. - * An empty array passes vacuously (pure-compute subclasses that don't dispatch). + * A model that declares capabilities must include every required one, or + * dispatch throws. A model that declares NO capabilities passes (treated as + * unverified-allow), so inline ModelConfigs without a capabilities list still + * run — matching {@link validateInput} / {@link narrowInput}. An empty + * `requires` passes vacuously (pure-compute subclasses that don't dispatch). */ public static readonly requires: readonly Capability[] = []; @@ -149,21 +152,23 @@ export class AiTask< } /** - * Throws TaskConfigurationError if the model lacks any capability listed in - * the task class's static `requires`. Both execute() and executeStream() must - * call this before dispatch — gating is task-side, not strategy-side. + * Throws TaskConfigurationError when the model declares capabilities that omit + * one required by the task class's static `requires`. A model that declares no + * capabilities passes (see {@link requires}). Both execute() and + * executeStream() must call this before dispatch — gating is task-side, not + * strategy-side. Shares {@link modelMeetsRequires} with validateInput / + * narrowInput so all three gates apply one policy. */ protected gateOrThrow(model: ModelConfig): void { const taskClass = this.constructor as typeof AiTask; const requires = taskClass.requires; + if (modelMeetsRequires(model, requires)) return; const modelCaps = (model.capabilities as readonly Capability[] | undefined) ?? []; const missing = requires.filter((r) => !modelCaps.includes(r)); - if (missing.length > 0) { - throw new TaskConfigurationError( - `Model "${model.model_id ?? "(inline config)"}" is missing capabilities required by ` + - `${taskClass.type}: ${missing.join(", ")}.` - ); - } + throw new TaskConfigurationError( + `Model "${model.model_id ?? "(inline config)"}" is missing capabilities required by ` + + `${taskClass.type}: ${missing.join(", ")}.` + ); } override async execute( diff --git a/packages/test/src/test/ai-provider-api/OllamaProvider.test.ts b/packages/test/src/test/ai-provider-api/OllamaProvider.test.ts index 83ec596eb..78bf221e6 100644 --- a/packages/test/src/test/ai-provider-api/OllamaProvider.test.ts +++ b/packages/test/src/test/ai-provider-api/OllamaProvider.test.ts @@ -97,6 +97,7 @@ describe("OllamaQueuedProvider.inferCapabilities", () => { const caps = provider.inferCapabilities(model("llama3:8b")); const sorted = [...caps].sort(); expect(sorted).toEqual([ + "json-mode", "model.info", "model.search", "text.generation", @@ -116,6 +117,7 @@ describe("OllamaQueuedProvider.inferCapabilities", () => { const caps = provider.inferCapabilities(model("llava:13b")); const sorted = [...caps].sort(); expect(sorted).toEqual([ + "json-mode", "model.info", "model.search", "text.generation", @@ -140,6 +142,7 @@ describe("OLLAMA_RUN_FNS shape", () => { const sets = OLLAMA_RUN_FNS.map((r) => [...r.serves].sort().join(",")); expect(sets).toContain("text.generation"); expect(sets).toContain("text.generation,tool-use"); + expect(sets).toContain("json-mode,text.generation"); expect(sets).toContain("text.rewriter"); expect(sets).toContain("text.summary"); expect(sets).toContain("text.embedding"); @@ -152,11 +155,10 @@ describe("OLLAMA_RUN_FNS shape", () => { expect(candidates.some((r) => r.serves.length === 1)).toBe(true); }); - it("does NOT register image generation, image editing, json-mode, or count-tokens", () => { + it("does NOT register image generation, image editing, or count-tokens", () => { const sets = OLLAMA_RUN_FNS.map((r) => [...r.serves].sort().join(",")); expect(sets).not.toContain("image.generation"); expect(sets).not.toContain("image.editing"); - expect(sets).not.toContain("json-mode,text.generation"); expect(sets).not.toContain("model.count-tokens"); }); }); diff --git a/packages/test/src/test/ai-provider-api/Ollama_StructuredGenerationStream.test.ts b/packages/test/src/test/ai-provider-api/Ollama_StructuredGenerationStream.test.ts new file mode 100644 index 000000000..a38de1e97 --- /dev/null +++ b/packages/test/src/test/ai-provider-api/Ollama_StructuredGenerationStream.test.ts @@ -0,0 +1,129 @@ +/** + * @license + * Copyright 2026 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { createEmitQueue } from "@workglow/ai"; +import { createOllamaStructuredGenerationStream } from "@workglow/ollama/ai-runtime"; +import { describe, expect, it, vi } from "vitest"; + +type FakeStream = { + abort: ReturnType; + [Symbol.asyncIterator](): AsyncIterator<{ message: { content: string } }>; +}; + +function makeFakeStream(chunks: string[]): FakeStream { + let aborted = false; + const abort = vi.fn(() => { + aborted = true; + }); + return { + abort, + async *[Symbol.asyncIterator]() { + for (const c of chunks) { + if (aborted) return; + yield { message: { content: c } }; + } + }, + }; +} + +function isAbortError(err: unknown): boolean { + if (!err || typeof err !== "object") return false; + const e = err as { name?: unknown; message?: unknown }; + if (e.name === "AbortError") return true; + return typeof e.message === "string" && e.message.toLowerCase().includes("abort"); +} + +const model = { + model_id: "ollama:test", + provider_config: { model_name: "llama3.2" }, +} as any; +const input = { prompt: "weather in Paris", outputSchema: { type: "object" } } as any; + +describe("createOllamaStructuredGenerationStream", () => { + it("passes the output schema as `format` and emits progressive object-delta on the `object` port", async () => { + const fakeStream = makeFakeStream(['{"city":"Par', 'is","temp":2', "0}"]); + const chat = vi.fn().mockResolvedValue(fakeStream); + const getClient = vi.fn().mockResolvedValue({ chat }); + const streamFn = createOllamaStructuredGenerationStream(getClient); + + const events: any[] = []; + await streamFn(input, model, undefined as any, (e) => events.push(e)); + + // The output schema is handed to Ollama as the chat `format` parameter. + expect(chat).toHaveBeenCalledTimes(1); + expect(chat.mock.calls[0][0]).toMatchObject({ format: input.outputSchema, stream: true }); + + const objectDeltas = events.filter((e) => e.type === "object-delta"); + expect(objectDeltas.length).toBeGreaterThan(0); + for (const d of objectDeltas) { + expect(d.port).toBe("object"); + expect(typeof d.objectDelta).toBe("object"); + expect(d.objectDelta).not.toBeNull(); + } + + // Per the structured-generation convention, the terminal finish carries the + // fully parsed object in data.object. + const finish = events[events.length - 1]; + expect(finish).toEqual({ type: "finish", data: { object: { city: "Paris", temp: 20 } } }); + }); + + it("recovers finish.data.object via partial JSON when the stream ends with truncated JSON", async () => { + // No closing brace: JSON.parse fails and parsePartialJson recovers the object. + const fakeStream = makeFakeStream(['{"city":"Paris","temp":20']); + const getClient = vi.fn().mockResolvedValue({ chat: vi.fn().mockResolvedValue(fakeStream) }); + const streamFn = createOllamaStructuredGenerationStream(getClient); + + const events: any[] = []; + await streamFn(input, model, undefined as any, (e) => events.push(e)); + + const finish = events[events.length - 1]; + expect(finish.type).toBe("finish"); + expect(finish.data.object).toMatchObject({ city: "Paris", temp: 20 }); + }); + + it("rejects before any client call or emit when the signal is already aborted", async () => { + const controller = new AbortController(); + controller.abort(); + const getClient = vi.fn().mockResolvedValue({ chat: vi.fn() }); + const streamFn = createOllamaStructuredGenerationStream(getClient); + + const events: any[] = []; + let caught: unknown; + await streamFn(input, model, controller.signal, (e) => events.push(e)).catch((e) => { + caught = e; + }); + + expect(isAbortError(caught)).toBe(true); + expect(getClient).not.toHaveBeenCalled(); + expect(events).toHaveLength(0); + }); + + it("aborts the underlying stream exactly once when the signal aborts mid-iteration", async () => { + const chunks = ['{"a":1', ',"b":2', ',"c":3', ',"d":4', ',"e":5}']; + const fakeStream = makeFakeStream(chunks); + const getClient = vi.fn().mockResolvedValue({ chat: vi.fn().mockResolvedValue(fakeStream) }); + const streamFn = createOllamaStructuredGenerationStream(getClient); + + const controller = new AbortController(); + const objectDeltas: any[] = []; + const q = createEmitQueue(); + const runP = streamFn(input, model, controller.signal, (e) => q.push(e)).then( + () => q.close(), + () => q.close() + ); + for await (const ev of q.iterable) { + if (ev.type === "object-delta") { + objectDeltas.push(ev); + if (objectDeltas.length === 2) controller.abort(); + } + } + await runP; + + expect(fakeStream.abort).toHaveBeenCalledTimes(1); + expect(objectDeltas.length).toBeLessThan(chunks.length); + expect(objectDeltas.length).toBeGreaterThanOrEqual(2); + }); +}); diff --git a/packages/test/src/test/ai/StreamEventAccumulator.test.ts b/packages/test/src/test/ai/StreamEventAccumulator.test.ts index 4a2ed4e12..79b29e0cd 100644 --- a/packages/test/src/test/ai/StreamEventAccumulator.test.ts +++ b/packages/test/src/test/ai/StreamEventAccumulator.test.ts @@ -66,12 +66,29 @@ describe("StreamEventAccumulator", () => { }); }); - it("throws when both text-delta and object-delta were observed", () => { + it("composes text-delta and object-delta on distinct ports", () => { const acc = new StreamEventAccumulator(); acc.observe({ type: "text-delta", port: "text", textDelta: "x" } as StreamEvent); acc.observe({ type: "object-delta", port: "obj", objectDelta: { a: 1 } } as StreamEvent); acc.observeFinish({ type: "finish", data: {} }); - expect(() => acc.materialize()).toThrow(/mixed/i); + expect(acc.materialize()).toEqual({ text: "x", obj: { a: 1 } }); + }); + + it("accumulated deltas win over a structural finish scaffold (no clobber)", () => { + // A tool-calling run-fn streams tool calls as object-delta then emits a + // structural default scaffold on finish. The streamed calls must survive, + // and the absent `text` port falls back to the scaffold default. + const acc = new StreamEventAccumulator(); + acc.observe({ + type: "object-delta", + port: "toolCalls", + objectDelta: [{ id: "call_0", name: "search", input: {} }], + } as StreamEvent); + acc.observeFinish({ type: "finish", data: { text: "", toolCalls: [] } }); + expect(acc.materialize()).toEqual({ + text: "", + toolCalls: [{ id: "call_0", name: "search", input: {} }], + }); }); it("throws on error event", () => { diff --git a/packages/test/src/test/ai/collectStream.test.ts b/packages/test/src/test/ai/collectStream.test.ts index c9c55aa87..403644622 100644 --- a/packages/test/src/test/ai/collectStream.test.ts +++ b/packages/test/src/test/ai/collectStream.test.ts @@ -157,14 +157,14 @@ describe("collectStream", () => { expect(result).toEqual({ text: "Hello world", summary: "Hi there" }); }); - it("mixed text+object deltas: throws rather than silently dropping data", async () => { + it("mixed text+object deltas on distinct ports: composes into one object", async () => { type Output = Record; const stream = makeStream( { type: "text-delta", port: "text", textDelta: "hello" }, { type: "object-delta", port: "result", objectDelta: { a: 1 } }, { type: "finish", data: {} as Output } ); - await expect(collectStream(stream)).rejects.toThrow(/mixed/i); + await expect(collectStream(stream)).resolves.toEqual({ text: "hello", result: { a: 1 } }); }); it("first finish wins: duplicate finish event does not corrupt result", async () => { diff --git a/providers/google-gemini/src/ai/common/Gemini_TextGeneration.ts b/providers/google-gemini/src/ai/common/Gemini_TextGeneration.ts index cd9e7c136..3ddad0b34 100644 --- a/providers/google-gemini/src/ai/common/Gemini_TextGeneration.ts +++ b/providers/google-gemini/src/ai/common/Gemini_TextGeneration.ts @@ -14,6 +14,29 @@ import { getApiKey, getModelName, loadGeminiSDK } from "./Gemini_Client"; import type { GeminiModelConfig } from "./Gemini_ModelSchema"; import { buildGeminiContents } from "./Gemini_ToolCalling"; +interface GeminiGenerationConfig { + maxOutputTokens?: number; + temperature?: number; + topP?: number; + frequencyPenalty?: number; + presencePenalty?: number; +} + +/** + * Maps the canonical sampling params onto Gemini's `generationConfig`, only + * setting fields that are defined so callers that omit a param keep the + * provider's default (matching the OpenAI/Anthropic adapters). + */ +function buildGenerationConfig(input: TextGenerationTaskInput): GeminiGenerationConfig { + const config: GeminiGenerationConfig = {}; + if (input.maxTokens !== undefined) config.maxOutputTokens = input.maxTokens; + if (input.temperature !== undefined) config.temperature = input.temperature; + if (input.topP !== undefined) config.topP = input.topP; + if (input.frequencyPenalty !== undefined) config.frequencyPenalty = input.frequencyPenalty; + if (input.presencePenalty !== undefined) config.presencePenalty = input.presencePenalty; + return config; +} + /** * Inputs that the unified `["text.generation"]` runFn handles. Both * {@link TextGenerationTask} and {@link AiChatTask} declare @@ -58,10 +81,7 @@ export const Gemini_TextGeneration_Stream: AiProviderRunFn< const genModel = genAI.getGenerativeModel({ model: getModelName(model), systemInstruction: unified.systemPrompt || undefined, - generationConfig: { - maxOutputTokens: input.maxTokens, - temperature: input.temperature, - }, + generationConfig: buildGenerationConfig(input), }); const contents = buildGeminiContents( @@ -81,11 +101,7 @@ export const Gemini_TextGeneration_Stream: AiProviderRunFn< // Prompt path — simple single-user-message generation. const genModel = genAI.getGenerativeModel({ model: getModelName(model), - generationConfig: { - maxOutputTokens: input.maxTokens, - temperature: input.temperature, - topP: input.topP, - }, + generationConfig: buildGenerationConfig(input), }); const result = await genModel.generateContentStream( diff --git a/providers/ollama/src/ai/common/Ollama_Capabilities.ts b/providers/ollama/src/ai/common/Ollama_Capabilities.ts index 4a7acc196..d01ebece9 100644 --- a/providers/ollama/src/ai/common/Ollama_Capabilities.ts +++ b/providers/ollama/src/ai/common/Ollama_Capabilities.ts @@ -70,6 +70,7 @@ export function inferOllamaCapabilities(model: CapabilityHints): readonly Capabi "text.rewriter", "text.summary", "tool-use", + "json-mode", "vision-input", "model.info", "model.search", @@ -85,6 +86,7 @@ export function inferOllamaCapabilities(model: CapabilityHints): readonly Capabi "text.rewriter", "text.summary", "tool-use", + "json-mode", "model.info", "model.search", ]; diff --git a/providers/ollama/src/ai/common/Ollama_CapabilitySets.ts b/providers/ollama/src/ai/common/Ollama_CapabilitySets.ts index 677d7178b..c656b47d6 100644 --- a/providers/ollama/src/ai/common/Ollama_CapabilitySets.ts +++ b/providers/ollama/src/ai/common/Ollama_CapabilitySets.ts @@ -16,6 +16,7 @@ import type { Capability } from "@workglow/ai/worker"; */ export const OLLAMA_TEXT_GENERATION = ["text.generation"] as const satisfies Capability[]; export const OLLAMA_TOOL_USE = ["text.generation", "tool-use"] as const satisfies Capability[]; +export const OLLAMA_JSON_MODE = ["text.generation", "json-mode"] as const satisfies Capability[]; export const OLLAMA_TEXT_REWRITER = ["text.rewriter"] as const satisfies Capability[]; export const OLLAMA_TEXT_SUMMARY = ["text.summary"] as const satisfies Capability[]; export const OLLAMA_TEXT_EMBEDDING = ["text.embedding"] as const satisfies Capability[]; @@ -26,6 +27,7 @@ export const OLLAMA_MODEL_INFO = ["model.info"] as const satisfies Capability[]; export const OLLAMA_CAPABILITY_SETS = [ OLLAMA_TEXT_GENERATION, OLLAMA_TOOL_USE, + OLLAMA_JSON_MODE, OLLAMA_TEXT_REWRITER, OLLAMA_TEXT_SUMMARY, OLLAMA_TEXT_EMBEDDING, diff --git a/providers/ollama/src/ai/common/Ollama_JobRunFns.browser.ts b/providers/ollama/src/ai/common/Ollama_JobRunFns.browser.ts index 26215bebc..cb630b559 100644 --- a/providers/ollama/src/ai/common/Ollama_JobRunFns.browser.ts +++ b/providers/ollama/src/ai/common/Ollama_JobRunFns.browser.ts @@ -6,6 +6,7 @@ import type { AiProviderRunFnRegistration, ToolCallingTaskInput } from "@workglow/ai"; import { + OLLAMA_JSON_MODE, OLLAMA_MODEL_INFO, OLLAMA_MODEL_SEARCH, OLLAMA_TEXT_EMBEDDING, @@ -18,6 +19,7 @@ import { getClient } from "./Ollama_Client.browser"; import { createOllamaModelInfoStream } from "./Ollama_ModelInfo"; import type { OllamaModelConfig } from "./Ollama_ModelSchema"; import { createOllamaModelSearchStream } from "./Ollama_ModelSearch"; +import { createOllamaStructuredGenerationStream } from "./Ollama_StructuredGeneration"; import { createOllamaTextEmbeddingStream } from "./Ollama_TextEmbedding"; import { createOllamaTextGenerationStream } from "./Ollama_TextGeneration"; import { createOllamaTextRewriterStream } from "./Ollama_TextRewriter"; @@ -39,6 +41,7 @@ function buildBrowserToolCallingMessages(input: ToolCallingTaskInput): Array<{ } export const Ollama_TextGeneration_Stream = createOllamaTextGenerationStream(getClient); +export const Ollama_StructuredGeneration_Stream = createOllamaStructuredGenerationStream(getClient); export const Ollama_TextRewriter_Stream = createOllamaTextRewriterStream(getClient); export const Ollama_TextSummary_Stream = createOllamaTextSummaryStream(getClient); export const Ollama_TextEmbedding_Stream = createOllamaTextEmbeddingStream(getClient); @@ -52,6 +55,7 @@ export const Ollama_ModelSearch_Stream = createOllamaModelSearchStream(getClient export const OLLAMA_RUN_FNS: readonly AiProviderRunFnRegistration[] = [ { serves: OLLAMA_TEXT_GENERATION, runFn: Ollama_TextGeneration_Stream }, { serves: OLLAMA_TOOL_USE, runFn: Ollama_ToolCalling_Stream }, + { serves: OLLAMA_JSON_MODE, runFn: Ollama_StructuredGeneration_Stream }, { serves: OLLAMA_TEXT_REWRITER, runFn: Ollama_TextRewriter_Stream }, { serves: OLLAMA_TEXT_SUMMARY, runFn: Ollama_TextSummary_Stream }, { serves: OLLAMA_TEXT_EMBEDDING, runFn: Ollama_TextEmbedding_Stream }, diff --git a/providers/ollama/src/ai/common/Ollama_JobRunFns.ts b/providers/ollama/src/ai/common/Ollama_JobRunFns.ts index 3dd42cdf6..54c6cff1f 100644 --- a/providers/ollama/src/ai/common/Ollama_JobRunFns.ts +++ b/providers/ollama/src/ai/common/Ollama_JobRunFns.ts @@ -7,6 +7,7 @@ import type { AiProviderRunFnRegistration } from "@workglow/ai"; import { toTextFlatMessages } from "@workglow/ai/worker"; import { + OLLAMA_JSON_MODE, OLLAMA_MODEL_INFO, OLLAMA_MODEL_SEARCH, OLLAMA_TEXT_EMBEDDING, @@ -19,6 +20,7 @@ import { getClient } from "./Ollama_Client"; import { createOllamaModelInfoStream } from "./Ollama_ModelInfo"; import type { OllamaModelConfig } from "./Ollama_ModelSchema"; import { createOllamaModelSearchStream } from "./Ollama_ModelSearch"; +import { createOllamaStructuredGenerationStream } from "./Ollama_StructuredGeneration"; import { createOllamaTextEmbeddingStream } from "./Ollama_TextEmbedding"; import { createOllamaTextGenerationStream } from "./Ollama_TextGeneration"; import { createOllamaTextRewriterStream } from "./Ollama_TextRewriter"; @@ -26,6 +28,7 @@ import { createOllamaTextSummaryStream } from "./Ollama_TextSummary"; import { createOllamaToolCallingStream } from "./Ollama_ToolCalling"; export const Ollama_TextGeneration_Stream = createOllamaTextGenerationStream(getClient); +export const Ollama_StructuredGeneration_Stream = createOllamaStructuredGenerationStream(getClient); export const Ollama_TextRewriter_Stream = createOllamaTextRewriterStream(getClient); export const Ollama_TextSummary_Stream = createOllamaTextSummaryStream(getClient); export const Ollama_TextEmbedding_Stream = createOllamaTextEmbeddingStream(getClient); @@ -47,6 +50,7 @@ export const Ollama_ModelSearch_Stream = createOllamaModelSearchStream(getClient export const OLLAMA_RUN_FNS: readonly AiProviderRunFnRegistration[] = [ { serves: OLLAMA_TEXT_GENERATION, runFn: Ollama_TextGeneration_Stream }, { serves: OLLAMA_TOOL_USE, runFn: Ollama_ToolCalling_Stream }, + { serves: OLLAMA_JSON_MODE, runFn: Ollama_StructuredGeneration_Stream }, { serves: OLLAMA_TEXT_REWRITER, runFn: Ollama_TextRewriter_Stream }, { serves: OLLAMA_TEXT_SUMMARY, runFn: Ollama_TextSummary_Stream }, { serves: OLLAMA_TEXT_EMBEDDING, runFn: Ollama_TextEmbedding_Stream }, diff --git a/providers/ollama/src/ai/common/Ollama_StructuredGeneration.ts b/providers/ollama/src/ai/common/Ollama_StructuredGeneration.ts new file mode 100644 index 000000000..804483204 --- /dev/null +++ b/providers/ollama/src/ai/common/Ollama_StructuredGeneration.ts @@ -0,0 +1,82 @@ +/** + * @license + * Copyright 2026 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { + AiProviderRunFn, + StructuredGenerationTaskInput, + StructuredGenerationTaskOutput, +} from "@workglow/ai"; +import { parsePartialJson } from "@workglow/util/worker"; +import type { OllamaModelConfig } from "./Ollama_ModelSchema"; +import { getOllamaModelName } from "./Ollama_ModelUtil"; + +type GetClient = (model: OllamaModelConfig | undefined) => Promise; + +/** + * Streaming run-fn factory for the `["text.generation", "json-mode"]` + * capability. Ollama constrains output by passing the JSON schema as the chat + * API's `format` parameter. Emits `object-delta` events with progressively + * parsed partial JSON on the `object` port. + * + * Per the structured-generation streaming-convention exception, the final + * `finish` event MUST carry the parsed object in `finish.data.object` so the + * {@link StructuredGenerationTask} consumer can read it without running a JSON + * streaming parser. + */ +export function createOllamaStructuredGenerationStream( + getClient: GetClient +): AiProviderRunFn< + StructuredGenerationTaskInput, + StructuredGenerationTaskOutput, + OllamaModelConfig +> { + return async (input, model, signal, emit, outputSchema) => { + signal?.throwIfAborted?.(); + const client = await getClient(model); + const modelName = getOllamaModelName(model); + + const schema = input.outputSchema ?? outputSchema; + + const stream = await client.chat({ + model: modelName, + messages: [{ role: "user", content: input.prompt }], + format: schema, + options: { + temperature: input.temperature, + num_predict: input.maxTokens, + }, + stream: true, + }); + + const onAbort = (): void => stream.abort(); + signal?.addEventListener("abort", onAbort, { once: true }); + let accumulatedJson = ""; + try { + if (signal?.aborted) stream.abort(); + signal?.throwIfAborted?.(); + for await (const chunk of stream) { + const delta = chunk.message.content; + if (delta) { + accumulatedJson += delta; + const partial = parsePartialJson(accumulatedJson); + if (partial !== undefined) { + emit({ type: "object-delta", port: "object", objectDelta: partial }); + } + } + } + } finally { + signal?.removeEventListener("abort", onAbort); + } + + let finalObject: Record; + try { + finalObject = JSON.parse(accumulatedJson); + } catch { + finalObject = parsePartialJson(accumulatedJson) ?? {}; + } + emit({ type: "finish", data: { object: finalObject } as StructuredGenerationTaskOutput }); + }; +} diff --git a/providers/ollama/src/ai/common/Ollama_ToolCalling.ts b/providers/ollama/src/ai/common/Ollama_ToolCalling.ts index 30d407ead..73255b69d 100644 --- a/providers/ollama/src/ai/common/Ollama_ToolCalling.ts +++ b/providers/ollama/src/ai/common/Ollama_ToolCalling.ts @@ -59,20 +59,18 @@ export function createOllamaToolCallingStream( const onAbort = (): void => stream.abort(); signal?.addEventListener("abort", onAbort, { once: true }); - let accumulatedText = ""; - const toolCalls: ToolCalls = []; let callIndex = 0; try { for await (const chunk of stream) { const delta = chunk.message.content; if (delta) { - accumulatedText += delta; emit({ type: "text-delta", port: "text", textDelta: delta }); } const chunkToolCalls = chunk.message.tool_calls; if (Array.isArray(chunkToolCalls) && chunkToolCalls.length > 0) { + const parsed: ToolCalls = []; for (const tc of chunkToolCalls) { let parsedInput: Record = {}; const fnArgs = tc.function.arguments; @@ -86,21 +84,28 @@ export function createOllamaToolCallingStream( } else if (fnArgs != null) { parsedInput = fnArgs as Record; } - const id = `call_${callIndex++}`; - toolCalls.push({ - id, + parsed.push({ + id: `call_${callIndex++}`, name: tc.function.name as string, input: sanitizeToolArgs(parsedInput) as Record, }); } - emit({ type: "object-delta", port: "toolCalls", objectDelta: [...toolCalls] }); + // Filter hallucinated tool names at the delta, not at finish: the + // consumer accumulates deltas and the streaming finish carries only a + // structural default, so unvalidated calls must never be emitted. + const valid = filterValidToolCalls(parsed, input.tools); + if (valid.length > 0) { + emit({ type: "object-delta", port: "toolCalls", objectDelta: valid }); + } } } - const validToolCalls = filterValidToolCalls(toolCalls, input.tools); + // Static default scaffold only — the run-fn does not accumulate. The + // consumer's accumulated deltas take precedence; this supplies the + // required output ports when the model streamed neither text nor calls. emit({ type: "finish", - data: { text: accumulatedText, toolCalls: validToolCalls } as ToolCallingTaskOutput, + data: { text: "", toolCalls: [] } as ToolCallingTaskOutput, }); } finally { signal?.removeEventListener("abort", onAbort); diff --git a/providers/ollama/src/ai/runtime.ts b/providers/ollama/src/ai/runtime.ts index 618384fcb..f394dee74 100644 --- a/providers/ollama/src/ai/runtime.ts +++ b/providers/ollama/src/ai/runtime.ts @@ -15,5 +15,6 @@ export * from "./common/Ollama_Client"; export * from "./common/Ollama_TextGeneration"; +export * from "./common/Ollama_StructuredGeneration"; export * from "./registerOllamaInline"; export * from "./registerOllamaWorker";