Skip to content

Commit 94223b9

Browse files
committed
improvement(mothership): abort path race preventing persistence
1 parent 08eeecb commit 94223b9

4 files changed

Lines changed: 307 additions & 150 deletions

File tree

apps/sim/app/api/copilot/chat/stop/route.test.ts

Lines changed: 70 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,49 @@ const {
1010
mockFrom,
1111
mockWhereSelect,
1212
mockLimit,
13+
mockForUpdate,
1314
mockUpdate,
1415
mockSet,
1516
mockWhereUpdate,
1617
mockReturning,
1718
mockPublishStatusChanged,
1819
mockSql,
19-
} = vi.hoisted(() => ({
20-
mockSelect: vi.fn(),
21-
mockFrom: vi.fn(),
22-
mockWhereSelect: vi.fn(),
23-
mockLimit: vi.fn(),
24-
mockUpdate: vi.fn(),
25-
mockSet: vi.fn(),
26-
mockWhereUpdate: vi.fn(),
27-
mockReturning: vi.fn(),
28-
mockPublishStatusChanged: vi.fn(),
29-
mockSql: vi.fn((strings: TemplateStringsArray, ...values: unknown[]) => ({ strings, values })),
30-
}))
20+
mockTransaction,
21+
} = vi.hoisted(() => {
22+
const mockSelect = vi.fn()
23+
const mockFrom = vi.fn()
24+
const mockWhereSelect = vi.fn()
25+
const mockLimit = vi.fn()
26+
const mockForUpdate = vi.fn()
27+
const mockUpdate = vi.fn()
28+
const mockSet = vi.fn()
29+
const mockWhereUpdate = vi.fn()
30+
const mockReturning = vi.fn()
31+
const mockPublishStatusChanged = vi.fn()
32+
const mockSql = vi.fn((strings: TemplateStringsArray, ...values: unknown[]) => ({
33+
strings,
34+
values,
35+
}))
36+
const mockTransaction = vi.fn(
37+
(callback: (tx: { select: typeof mockSelect; update: typeof mockUpdate }) => unknown) =>
38+
callback({ select: mockSelect, update: mockUpdate })
39+
)
40+
41+
return {
42+
mockSelect,
43+
mockFrom,
44+
mockWhereSelect,
45+
mockLimit,
46+
mockForUpdate,
47+
mockUpdate,
48+
mockSet,
49+
mockWhereUpdate,
50+
mockReturning,
51+
mockPublishStatusChanged,
52+
mockSql,
53+
mockTransaction,
54+
}
55+
})
3156

3257
vi.mock('@sim/db/schema', () => ({
3358
copilotChats: {
@@ -41,8 +66,7 @@ vi.mock('@sim/db/schema', () => ({
4166

4267
vi.mock('@sim/db', () => ({
4368
db: {
44-
select: mockSelect,
45-
update: mockUpdate,
69+
transaction: mockTransaction,
4670
},
4771
}))
4872

@@ -78,9 +102,11 @@ describe('copilot chat stop route', () => {
78102
{
79103
workspaceId: 'ws-1',
80104
messages: [{ id: 'stream-1', role: 'user', content: 'hello' }],
105+
conversationId: 'stream-1',
81106
},
82107
])
83-
mockWhereSelect.mockReturnValue({ limit: mockLimit })
108+
mockForUpdate.mockReturnValue({ limit: mockLimit })
109+
mockWhereSelect.mockReturnValue({ for: mockForUpdate })
84110
mockFrom.mockReturnValue({ where: mockWhereSelect })
85111
mockSelect.mockReturnValue({ from: mockFrom })
86112

@@ -153,4 +179,33 @@ describe('copilot chat stop route', () => {
153179
streamId: 'stream-1',
154180
})
155181
})
182+
183+
it('appends a stopped assistant message if the stream marker was already cleared', async () => {
184+
mockLimit.mockResolvedValueOnce([
185+
{
186+
workspaceId: 'ws-1',
187+
messages: [{ id: 'stream-1', role: 'user', content: 'hello' }],
188+
conversationId: null,
189+
},
190+
])
191+
192+
const response = await POST(
193+
createRequest({
194+
chatId: 'chat-1',
195+
streamId: 'stream-1',
196+
content: 'partial',
197+
})
198+
)
199+
200+
expect(response.status).toBe(200)
201+
expect(await response.json()).toEqual({ success: true })
202+
203+
const setArg = mockSet.mock.calls[0]?.[0]
204+
expect(setArg.messages).toBeTruthy()
205+
const appendedPayload = JSON.parse(setArg.messages.values[1] as string)
206+
expect(appendedPayload[0]).toMatchObject({
207+
role: 'assistant',
208+
content: 'partial',
209+
})
210+
})
156211
})

apps/sim/app/api/copilot/chat/stop/route.ts

Lines changed: 24 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
1-
import { db } from '@sim/db'
2-
import { copilotChats } from '@sim/db/schema'
31
import { createLogger } from '@sim/logger'
42
import { generateId } from '@sim/utils/id'
5-
import { and, eq, sql } from 'drizzle-orm'
63
import { type NextRequest, NextResponse } from 'next/server'
74
import { copilotChatStopContract } from '@/lib/api/contracts/copilot'
85
import { parseRequest } from '@/lib/api/server'
96
import { getSession } from '@/lib/auth'
107
import { normalizeMessage, type PersistedMessage } from '@/lib/copilot/chat/persisted-message'
8+
import { finalizeAssistantTurn } from '@/lib/copilot/chat/terminal-state'
119
import { CopilotStopOutcome } from '@/lib/copilot/generated/trace-attribute-values-v1'
1210
import { TraceAttr } from '@/lib/copilot/generated/trace-attributes-v1'
1311
import { TraceSpan } from '@/lib/copilot/generated/trace-spans-v1'
@@ -44,71 +42,33 @@ export const POST = withRouteHandler((req: NextRequest) =>
4442
...(requestId ? { [TraceAttr.RequestId]: requestId } : {}),
4543
})
4644

47-
const [row] = await db
48-
.select({
49-
workspaceId: copilotChats.workspaceId,
50-
messages: copilotChats.messages,
51-
})
52-
.from(copilotChats)
53-
.where(and(eq(copilotChats.id, chatId), eq(copilotChats.userId, session.user.id)))
54-
.limit(1)
55-
56-
if (!row) {
57-
span.setAttribute(TraceAttr.CopilotStopOutcome, CopilotStopOutcome.ChatNotFound)
58-
return NextResponse.json({ success: true })
59-
}
60-
61-
const messages: Record<string, unknown>[] = Array.isArray(row.messages) ? row.messages : []
62-
const userIdx = messages.findIndex((message) => message.id === streamId)
63-
const alreadyHasResponse =
64-
userIdx >= 0 &&
65-
userIdx + 1 < messages.length &&
66-
(messages[userIdx + 1] as Record<string, unknown>)?.role === 'assistant'
67-
const canAppendAssistant =
68-
userIdx >= 0 && userIdx === messages.length - 1 && !alreadyHasResponse
69-
70-
const updateWhere = and(
71-
eq(copilotChats.id, chatId),
72-
eq(copilotChats.userId, session.user.id),
73-
eq(copilotChats.conversationId, streamId)
74-
)
75-
76-
const setClause: Record<string, unknown> = {
77-
conversationId: null,
78-
updatedAt: new Date(),
79-
}
80-
8145
const hasContent = content.trim().length > 0
8246
const hasBlocks = Array.isArray(contentBlocks) && contentBlocks.length > 0
8347
const synthesizedStoppedBlocks = hasBlocks
8448
? contentBlocks
8549
: hasContent
8650
? [{ type: 'text', channel: 'assistant', content }, { type: 'stopped' }]
8751
: [{ type: 'stopped' }]
88-
if (canAppendAssistant) {
89-
const normalized = normalizeMessage({
90-
id: generateId(),
91-
role: 'assistant',
92-
content,
93-
timestamp: new Date().toISOString(),
94-
contentBlocks: synthesizedStoppedBlocks,
95-
// Persist so the UI copy-request-id button survives refetch.
96-
...(requestId ? { requestId } : {}),
97-
})
98-
const assistantMessage: PersistedMessage = normalized
99-
setClause.messages = sql`${copilotChats.messages} || ${JSON.stringify([assistantMessage])}::jsonb`
100-
}
101-
span.setAttribute(TraceAttr.CopilotStopAppendedAssistant, canAppendAssistant)
102-
103-
const [updated] = await db
104-
.update(copilotChats)
105-
.set(setClause)
106-
.where(updateWhere)
107-
.returning({ workspaceId: copilotChats.workspaceId })
52+
const assistantMessage: PersistedMessage = normalizeMessage({
53+
id: generateId(),
54+
role: 'assistant',
55+
content,
56+
timestamp: new Date().toISOString(),
57+
contentBlocks: synthesizedStoppedBlocks,
58+
...(requestId ? { requestId } : {}),
59+
})
60+
const result = await finalizeAssistantTurn({
61+
chatId,
62+
userId: session.user.id,
63+
userMessageId: streamId,
64+
assistantMessage,
65+
streamMarkerPolicy: 'active-or-cleared',
66+
})
67+
span.setAttribute(TraceAttr.CopilotStopAppendedAssistant, result.appendedAssistant)
10868

109-
if (updated?.workspaceId) {
69+
if (result.updated && result.workspaceId) {
11070
taskPubSub?.publishStatusChanged({
111-
workspaceId: updated.workspaceId,
71+
workspaceId: result.workspaceId,
11272
chatId,
11373
type: 'completed',
11474
streamId,
@@ -117,7 +77,11 @@ export const POST = withRouteHandler((req: NextRequest) =>
11777

11878
span.setAttribute(
11979
TraceAttr.CopilotStopOutcome,
120-
updated ? CopilotStopOutcome.Persisted : CopilotStopOutcome.NoMatchingRow
80+
result.found
81+
? result.updated
82+
? CopilotStopOutcome.Persisted
83+
: CopilotStopOutcome.NoMatchingRow
84+
: CopilotStopOutcome.ChatNotFound
12185
)
12286
return NextResponse.json({ success: true })
12387
} catch (error) {

0 commit comments

Comments
 (0)