Skip to content

Commit 8988ac7

Browse files
committed
improvement(providers): harden OpenAI-compatible providers + add tests
1 parent 8daca91 commit 8988ac7

15 files changed

Lines changed: 1911 additions & 318 deletions

File tree

Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
/**
2+
* @vitest-environment node
3+
*/
4+
import { beforeEach, describe, expect, it, vi } from 'vitest'
5+
6+
const {
7+
mockCreate,
8+
mockSupportsNativeStructuredOutputs,
9+
mockPrepareToolsWithUsageControl,
10+
mockExecuteTool,
11+
} = vi.hoisted(() => ({
12+
mockCreate: vi.fn(),
13+
mockSupportsNativeStructuredOutputs: vi.fn(),
14+
mockPrepareToolsWithUsageControl: vi.fn(),
15+
mockExecuteTool: vi.fn(),
16+
}))
17+
18+
vi.mock('openai', () => ({
19+
default: vi.fn().mockImplementation(() => ({
20+
chat: { completions: { create: mockCreate } },
21+
})),
22+
}))
23+
24+
vi.mock('@/providers', () => ({ MAX_TOOL_ITERATIONS: 5 }))
25+
26+
vi.mock('@/providers/models', () => ({
27+
getProviderModels: vi.fn().mockReturnValue([]),
28+
getProviderDefaultModel: vi.fn().mockReturnValue('llama-v3p1-70b-instruct'),
29+
}))
30+
31+
vi.mock('@/providers/attachments', () => ({
32+
formatMessagesForProvider: vi.fn((messages) => messages),
33+
}))
34+
35+
vi.mock('@/providers/fireworks/utils', () => ({
36+
supportsNativeStructuredOutputs: mockSupportsNativeStructuredOutputs,
37+
createReadableStreamFromOpenAIStream: vi.fn(() => ({}) as ReadableStream),
38+
checkForForcedToolUsage: vi.fn(() => ({ hasUsedForcedTool: false, usedForcedTools: [] })),
39+
}))
40+
41+
vi.mock('@/providers/trace-enrichment', () => ({
42+
enrichLastModelSegmentFromChatCompletions: vi.fn(),
43+
}))
44+
45+
vi.mock('@/providers/utils', () => ({
46+
calculateCost: vi.fn().mockReturnValue({ input: 0, output: 0, total: 0 }),
47+
generateSchemaInstructions: vi.fn(() => 'SCHEMA_INSTRUCTIONS'),
48+
prepareToolExecution: vi.fn(() => ({ toolParams: { x: 1 }, executionParams: { x: 1 } })),
49+
prepareToolsWithUsageControl: mockPrepareToolsWithUsageControl,
50+
sumToolCosts: vi.fn().mockReturnValue(0),
51+
}))
52+
53+
vi.mock('@/tools', () => ({ executeTool: mockExecuteTool }))
54+
55+
import { fireworksProvider } from '@/providers/fireworks/index'
56+
import { ProviderError } from '@/providers/types'
57+
58+
const textResponse = (content: string) => ({
59+
choices: [{ message: { content, tool_calls: [] } }],
60+
usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 },
61+
})
62+
63+
const toolCallResponse = () => ({
64+
choices: [
65+
{
66+
message: {
67+
content: null,
68+
tool_calls: [
69+
{ id: 'call_1', type: 'function', function: { name: 'my_tool', arguments: '{"x":1}' } },
70+
],
71+
},
72+
},
73+
],
74+
usage: { prompt_tokens: 8, completion_tokens: 4, total_tokens: 12 },
75+
})
76+
77+
const toolDef = {
78+
id: 'my_tool',
79+
name: 'my_tool',
80+
description: '',
81+
params: {},
82+
parameters: { type: 'object', properties: {}, required: [] },
83+
}
84+
85+
const callBody = (index: number) => mockCreate.mock.calls[index][0]
86+
const lastCallBody = () => mockCreate.mock.calls.at(-1)?.[0]
87+
88+
describe('fireworksProvider', () => {
89+
beforeEach(() => {
90+
vi.clearAllMocks()
91+
mockSupportsNativeStructuredOutputs.mockResolvedValue(true)
92+
mockPrepareToolsWithUsageControl.mockImplementation((tools) => ({
93+
tools,
94+
toolChoice: 'auto',
95+
forcedTools: [],
96+
}))
97+
mockExecuteTool.mockResolvedValue({ success: true, output: { ok: true } })
98+
})
99+
100+
const baseRequest = {
101+
model: 'fireworks/llama-v3p1-70b-instruct',
102+
systemPrompt: 'You are helpful.',
103+
messages: [{ role: 'user' as const, content: 'Hello' }],
104+
apiKey: 'fw-test-key',
105+
}
106+
107+
it('throws when the API key is missing', async () => {
108+
await expect(
109+
fireworksProvider.executeRequest({ ...baseRequest, apiKey: undefined })
110+
).rejects.toThrow('API key is required for Fireworks')
111+
})
112+
113+
it('returns content and token usage for a simple request', async () => {
114+
mockCreate.mockResolvedValueOnce(textResponse('hi there'))
115+
116+
const result = await fireworksProvider.executeRequest(baseRequest)
117+
118+
expect(result).toMatchObject({
119+
content: 'hi there',
120+
model: 'llama-v3p1-70b-instruct',
121+
tokens: { input: 10, output: 5, total: 15 },
122+
})
123+
})
124+
125+
it('wraps API errors in a ProviderError', async () => {
126+
mockCreate.mockRejectedValueOnce(new Error('boom'))
127+
128+
await expect(fireworksProvider.executeRequest(baseRequest)).rejects.toBeInstanceOf(
129+
ProviderError
130+
)
131+
})
132+
133+
it('streams directly when there are no tools', async () => {
134+
mockCreate.mockResolvedValueOnce({})
135+
136+
const result = await fireworksProvider.executeRequest({ ...baseRequest, stream: true })
137+
138+
expect(lastCallBody()).toMatchObject({ stream: true, stream_options: { include_usage: true } })
139+
expect(result).toHaveProperty('stream')
140+
expect(result).toHaveProperty('execution')
141+
})
142+
143+
it('sends a json_schema response_format with no strict field', async () => {
144+
mockCreate.mockResolvedValueOnce(textResponse('{}'))
145+
146+
await fireworksProvider.executeRequest({
147+
...baseRequest,
148+
responseFormat: { name: 'my_schema', schema: { type: 'object' }, strict: true },
149+
})
150+
151+
expect(lastCallBody().response_format).toEqual({
152+
type: 'json_schema',
153+
json_schema: { name: 'my_schema', schema: { type: 'object' } },
154+
})
155+
expect(lastCallBody().response_format.json_schema).not.toHaveProperty('strict')
156+
})
157+
158+
it('falls back to json_object with prompt instructions when native is unsupported', async () => {
159+
mockSupportsNativeStructuredOutputs.mockResolvedValue(false)
160+
mockCreate.mockResolvedValueOnce(textResponse('{}'))
161+
162+
await fireworksProvider.executeRequest({
163+
...baseRequest,
164+
responseFormat: { name: 'my_schema', schema: { type: 'object' } },
165+
})
166+
167+
expect(lastCallBody().response_format).toEqual({ type: 'json_object' })
168+
expect(lastCallBody().messages.at(-1)).toEqual({
169+
role: 'user',
170+
content: 'SCHEMA_INSTRUCTIONS',
171+
})
172+
})
173+
174+
it('defers response_format to a final call when tools are active', async () => {
175+
mockCreate
176+
.mockResolvedValueOnce(textResponse('intermediate'))
177+
.mockResolvedValueOnce(textResponse('{"done":true}'))
178+
179+
await fireworksProvider.executeRequest({
180+
...baseRequest,
181+
responseFormat: { name: 'my_schema', schema: { type: 'object' } },
182+
tools: [toolDef],
183+
})
184+
185+
expect(mockCreate).toHaveBeenCalledTimes(2)
186+
expect(callBody(0).response_format).toBeUndefined()
187+
expect(callBody(0).tools).toBeDefined()
188+
expect(callBody(1).response_format).toEqual({
189+
type: 'json_schema',
190+
json_schema: { name: 'my_schema', schema: { type: 'object' } },
191+
})
192+
expect(callBody(1).tools).toBeUndefined()
193+
})
194+
195+
it('runs the tool loop and threads tool results back into the conversation', async () => {
196+
mockCreate
197+
.mockResolvedValueOnce(toolCallResponse())
198+
.mockResolvedValueOnce(textResponse('final answer'))
199+
200+
const result = await fireworksProvider.executeRequest({ ...baseRequest, tools: [toolDef] })
201+
202+
expect(mockExecuteTool).toHaveBeenCalledWith('my_tool', { x: 1 }, expect.anything())
203+
expect(result).toMatchObject({ content: 'final answer' })
204+
expect((result as { toolCalls?: unknown[] }).toolCalls).toHaveLength(1)
205+
206+
const followUpMessages = callBody(1).messages
207+
expect(followUpMessages).toContainEqual(
208+
expect.objectContaining({ role: 'assistant', tool_calls: expect.any(Array) })
209+
)
210+
expect(followUpMessages).toContainEqual(
211+
expect.objectContaining({ role: 'tool', tool_call_id: 'call_1' })
212+
)
213+
})
214+
215+
it("forces tool_choice 'none' on the final streaming call after tools run", async () => {
216+
mockCreate
217+
.mockResolvedValueOnce(toolCallResponse())
218+
.mockResolvedValueOnce(textResponse('done'))
219+
.mockResolvedValueOnce({})
220+
221+
await fireworksProvider.executeRequest({ ...baseRequest, stream: true, tools: [toolDef] })
222+
223+
expect(mockCreate).toHaveBeenCalledTimes(3)
224+
expect(lastCallBody()).toMatchObject({ tool_choice: 'none', stream: true })
225+
})
226+
})

apps/sim/providers/fireworks/index.ts

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ const logger = createLogger('FireworksProvider')
3434

3535
/**
3636
* Applies structured output configuration to a payload based on model capabilities.
37-
* Uses json_schema with strict mode for supported models, falls back to json_object with prompt instructions.
37+
* Uses native json_schema for supported models, falls back to json_object with prompt instructions.
3838
*/
3939
async function applyResponseFormat(
4040
targetPayload: any,
@@ -51,7 +51,6 @@ async function applyResponseFormat(
5151
json_schema: {
5252
name: responseFormat.name || 'response_schema',
5353
schema: responseFormat.schema || responseFormat,
54-
strict: responseFormat.strict !== false,
5554
},
5655
}
5756
return messages
@@ -469,7 +468,7 @@ export const fireworksProvider: ProviderConfig = {
469468
const streamingParams: ChatCompletionCreateParamsStreaming = {
470469
...payload,
471470
messages: [...currentMessages],
472-
tool_choice: 'auto',
471+
tool_choice: 'none',
473472
stream: true,
474473
stream_options: { include_usage: true },
475474
}
@@ -652,8 +651,3 @@ export const fireworksProvider: ProviderConfig = {
652651
}
653652
},
654653
}
655-
656-
/**
657-
* Enriches the last model segment with per-iteration content from a Chat
658-
* Completions response: assistant text, tool calls, finish reason, token usage.
659-
*/

apps/sim/providers/fireworks/utils.ts

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,18 @@ import type { ChatCompletionChunk } from 'openai/resources/chat/completions'
22
import type { CompletionUsage } from 'openai/resources/completions'
33
import { checkForForcedToolUsageOpenAI, createOpenAICompatibleStream } from '@/providers/utils'
44

5-
/**
6-
* Checks if a model supports native structured outputs (json_schema).
7-
* Fireworks AI supports structured outputs across their inference API.
8-
*/
5+
/** Fireworks supports native json_schema structured outputs for all models on its inference API. */
96
export async function supportsNativeStructuredOutputs(_modelId: string): Promise<boolean> {
107
return true
118
}
129

13-
/**
14-
* Creates a ReadableStream from a Fireworks streaming response.
15-
* Uses the shared OpenAI-compatible streaming utility.
16-
*/
1710
export function createReadableStreamFromOpenAIStream(
1811
openaiStream: AsyncIterable<ChatCompletionChunk>,
1912
onComplete?: (content: string, usage: CompletionUsage) => void
2013
): ReadableStream<Uint8Array> {
2114
return createOpenAICompatibleStream(openaiStream, 'Fireworks', onComplete)
2215
}
2316

24-
/**
25-
* Checks if a forced tool was used in a Fireworks response.
26-
* Uses the shared OpenAI-compatible forced tool usage helper.
27-
*/
2817
export function checkForForcedToolUsage(
2918
response: any,
3019
toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any },

0 commit comments

Comments
 (0)