Skip to content

Commit 0bd2c7c

Browse files
committed
refactor(mcp): tighten OAuth callback contract and registration metadata
- Validate callback query params via mcpOauthCallbackContract instead of raw searchParams.get, matching the rest of the MCP route surface. - Drop non-RFC-7591 application_type field from dynamic client registration to avoid rejection by strict authorization servers. - Collapse the pre-lock OAuth row load in createClient — the row is now loaded exclusively inside withMcpOauthRefreshLock, removing a redundant query and a stale-snapshot path.
1 parent f620a1b commit 0bd2c7c

4 files changed

Lines changed: 35 additions & 20 deletions

File tree

apps/sim/app/api/mcp/oauth/callback/route.ts

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import { toError } from '@sim/utils/errors'
66
import { and, eq, isNull } from 'drizzle-orm'
77
import type { NextRequest } from 'next/server'
88
import { NextResponse } from 'next/server'
9+
import { mcpOauthCallbackContract } from '@/lib/api/contracts/mcp'
10+
import { parseRequest } from '@/lib/api/server'
911
import { getSession } from '@/lib/auth'
1012
import { withRouteHandler } from '@/lib/core/utils/with-route-handler'
1113
import {
@@ -55,10 +57,11 @@ function htmlClose(
5557
}
5658

5759
export const GET = withRouteHandler(async (request: NextRequest) => {
58-
const url = new URL(request.url)
59-
const state = url.searchParams.get('state')
60-
const code = url.searchParams.get('code')
61-
const errorParam = url.searchParams.get('error')
60+
const parsed = await parseRequest(mcpOauthCallbackContract, request, {})
61+
if (!parsed.success) {
62+
return htmlClose('Malformed authorization callback.', false, 'missing_params')
63+
}
64+
const { state, code, error: errorParam } = parsed.data.query
6265

6366
const initialRow = state ? await loadOauthRowByState(state).catch(() => null) : null
6467
const stateRowServerId = initialRow?.mcpServerId

apps/sim/lib/api/contracts/mcp.ts

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,27 @@ export const startMcpOauthContract = defineRouteContract({
447447
},
448448
})
449449

450+
/**
451+
* Provider can return any subset depending on the outcome:
452+
* - success: `state` + `code`
453+
* - provider error: `error` + optional `error_description` + optional `state`
454+
* - malformed callback: nothing
455+
* All fields are optional so the route can render an HTML error page itself.
456+
*/
457+
export const mcpOauthCallbackQuerySchema = z.object({
458+
state: z.string().optional(),
459+
code: z.string().optional(),
460+
error: z.string().optional(),
461+
error_description: z.string().optional(),
462+
})
463+
464+
export const mcpOauthCallbackContract = defineRouteContract({
465+
method: 'GET',
466+
path: '/api/mcp/oauth/callback',
467+
query: mcpOauthCallbackQuerySchema,
468+
response: { mode: 'text' },
469+
})
470+
450471
export const getAllowedMcpDomainsContract = defineRouteContract({
451472
method: 'GET',
452473
path: '/api/settings/allowed-mcp-domains',

apps/sim/lib/mcp/oauth/provider.ts

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,12 @@ export class SimMcpOauthProvider implements OAuthClientProvider {
6464
}
6565

6666
get clientMetadata(): OAuthClientMetadata {
67-
const meta: OAuthClientMetadata & { application_type?: string } = {
67+
const meta: OAuthClientMetadata = {
6868
client_name: 'Sim',
6969
redirect_uris: [this.redirectUrl],
7070
grant_types: ['authorization_code', 'refresh_token'],
7171
response_types: ['code'],
7272
token_endpoint_auth_method: this.preregistered?.clientSecret ? 'client_secret_post' : 'none',
73-
application_type: 'web',
7473
}
7574
if (this.scope) meta.scope = this.scope
7675
return meta

apps/sim/lib/mcp/service.ts

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -195,20 +195,12 @@ class McpService {
195195
throw new Error('OAuth MCP server requires both userId and workspaceId')
196196
}
197197

198-
const initialRow = await getOrCreateOauthRow({
199-
mcpServerId: config.id,
200-
userId,
201-
workspaceId: config.workspaceId,
202-
})
203-
if (!initialRow.tokens) {
204-
throw new McpOauthAuthorizationRequiredError(config.id, config.name)
205-
}
206-
207-
// Re-read the row inside the lock so concurrent callers observe tokens
208-
// written by a predecessor refresh, rather than the stale snapshot loaded
209-
// before lock acquisition. Without this, the second caller's provider holds
210-
// a rotated-out refresh token and the SDK trips `invalid_grant`.
211-
return withMcpOauthRefreshLock(initialRow.id, async () => {
198+
// Load the row inside the refresh lock so concurrent callers observe tokens
199+
// written by a predecessor refresh, rather than a stale snapshot. Without
200+
// this, the second caller's provider would hold a rotated-out refresh token
201+
// and the SDK would trip `invalid_grant`. The lock is keyed on serverId
202+
// since the row is per-server.
203+
return withMcpOauthRefreshLock(config.id, async () => {
212204
const row = await getOrCreateOauthRow({
213205
mcpServerId: config.id,
214206
userId,

0 commit comments

Comments
 (0)