Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 146 additions & 0 deletions apps/sim/app/api/auth/oauth2/authorize/route.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import { db } from '@sim/db'
import { pendingCredentialDraft, user } from '@sim/db/schema'
import { createLogger } from '@sim/logger'
import { generateId } from '@sim/utils/id'
import { and, eq, lt } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { authorizeOAuth2Contract } from '@/lib/api/contracts/oauth-connections'
import { parseRequest } from '@/lib/api/server'
import { auth, getSession } from '@/lib/auth/auth'
import { getBaseUrl } from '@/lib/core/utils/urls'
import { withRouteHandler } from '@/lib/core/utils/with-route-handler'
import { getAllOAuthServices } from '@/lib/oauth/utils'
import { checkWorkspaceAccess } from '@/lib/workspaces/permissions/utils'

const logger = createLogger('OAuth2Authorize')

export const dynamic = 'force-dynamic'

const DRAFT_TTL_MS = 15 * 60 * 1000

/**
* Creates the pending credential draft at click time so its TTL starts when the
* user actually initiates the connect. Better Auth's `account.create.after` hook
* consumes this draft to materialize the real credential after the OAuth
* callback; starting the clock here guarantees the draft outlives the (≤5 min)
* OAuth round-trip rather than expiring mid-flow and silently producing no
* credential.
*/
async function createConnectDraft(params: {
userId: string
workspaceId: string
providerId: string
}): Promise<void> {
const { userId, workspaceId, providerId } = params

const service = getAllOAuthServices().find((s) => s.providerId === providerId)
const serviceName = service?.name ?? providerId

let displayName = serviceName
try {
const [row] = await db.select({ name: user.name }).from(user).where(eq(user.id, userId))
if (row?.name) {
displayName = `${row.name}'s ${serviceName}`
}
} catch {
// Fall back to service name only
}

const now = new Date()
const expiresAt = new Date(now.getTime() + DRAFT_TTL_MS)
await db
.delete(pendingCredentialDraft)
.where(
and(eq(pendingCredentialDraft.userId, userId), lt(pendingCredentialDraft.expiresAt, now))
)
await db
.insert(pendingCredentialDraft)
.values({
id: generateId(),
userId,
workspaceId,
providerId,
displayName,
expiresAt,
createdAt: now,
})
.onConflictDoUpdate({
target: [
pendingCredentialDraft.userId,
pendingCredentialDraft.providerId,
pendingCredentialDraft.workspaceId,
],
set: { displayName, expiresAt, createdAt: now },
})

logger.info('Created OAuth connect credential draft', { userId, workspaceId, providerId })
}

/**
* Browser-initiated entrypoint for linking a generic OAuth2 account.
*/
export const GET = withRouteHandler(async (request: NextRequest) => {
const baseUrl = getBaseUrl()

const session = await getSession()
if (!session?.user?.id) {
const loginUrl = new URL('/login', baseUrl)
loginUrl.searchParams.set('callbackUrl', request.nextUrl.pathname + request.nextUrl.search)
return NextResponse.redirect(loginUrl.toString())
}
const userId = session.user.id

const parsed = await parseRequest(authorizeOAuth2Contract, request, {})
if (!parsed.success) return parsed.response
const { providerId, workspaceId, callbackURL: requestedCallback } = parsed.data.query

const callbackURL = requestedCallback?.startsWith(`${baseUrl}/`)
? requestedCallback
: `${baseUrl}/workspace`

try {
const access = await checkWorkspaceAccess(workspaceId, userId)
if (!access.canWrite) {
logger.warn('Workspace write access denied for OAuth2 authorize', {
userId,
workspaceId,
providerId,
})
return NextResponse.redirect(`${baseUrl}/workspace?error=workspace_access_denied`)
}

// Create the draft before initiating the link so it is guaranteed to exist
// (and freshly clocked) when the OAuth callback's `account.create.after`
// hook runs. If this throws, we never start the OAuth flow.
await createConnectDraft({ userId, workspaceId, providerId })

const linkResponse = await auth.api.oAuth2LinkAccount({
body: { providerId, callbackURL },
headers: request.headers,
asResponse: true,
})

const payload = (await linkResponse.json().catch(() => null)) as { url?: string } | null
if (!linkResponse.ok || !payload?.url) {
logger.error('oAuth2LinkAccount did not return an authorization URL', {
providerId,
status: linkResponse.status,
})
return NextResponse.redirect(`${baseUrl}/workspace?error=oauth_link_failed`)
}

const response = NextResponse.redirect(payload.url)
// Forward the signed `state` cookie Better Auth set so it lands in the user's
// browser and is present when the provider redirects back to the callback.
const linkHeaders = linkResponse.headers as Headers & {
getSetCookie?: () => string[]
}
for (const cookie of linkHeaders.getSetCookie?.() ?? []) {
response.headers.append('set-cookie', cookie)
}
Comment thread
icecrasher321 marked this conversation as resolved.
return response
} catch (error) {
logger.error('Failed to initiate OAuth2 authorization', { providerId, error })
return NextResponse.redirect(`${baseUrl}/workspace?error=oauth_link_failed`)
}
})
14 changes: 14 additions & 0 deletions apps/sim/lib/api/contracts/oauth-connections.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { z } from 'zod'
import { workspaceIdSchema } from '@/lib/api/contracts/primitives'
import type {
ContractBody,
ContractBodyInput,
Expand Down Expand Up @@ -190,6 +191,19 @@ export const trelloCallbackContract = defineRouteContract({
response: { mode: 'text' },
})

export const authorizeOAuth2QuerySchema = z.object({
providerId: z.string().min(1, 'providerId is required'),
workspaceId: workspaceIdSchema,
callbackURL: z.string().min(1).optional(),
})

export const authorizeOAuth2Contract = defineRouteContract({
method: 'GET',
path: '/api/auth/oauth2/authorize',
query: authorizeOAuth2QuerySchema,
response: { mode: 'redirect' },
})

export type StoreTrelloTokenBody = ContractBody<typeof storeTrelloTokenContract>
export type StoreTrelloTokenBodyInput = ContractBodyInput<typeof storeTrelloTokenContract>
export type StoreTrelloTokenResponse = ContractJsonResponse<typeof storeTrelloTokenContract>
85 changes: 22 additions & 63 deletions apps/sim/lib/copilot/tools/handlers/oauth.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
import { db } from '@sim/db'
import { pendingCredentialDraft, user } from '@sim/db/schema'
import { toError } from '@sim/utils/errors'
import { generateId } from '@sim/utils/id'
import { and, eq, lt } from 'drizzle-orm'
import type { ExecutionContext, ToolCallResult } from '@/lib/copilot/request/types'
import { ensureWorkspaceAccess } from '@/lib/copilot/tools/handlers/access'
import { getBaseUrl } from '@/lib/core/utils/urls'
Expand All @@ -20,7 +16,6 @@ export async function executeOAuthGetAuthLink(
}
await ensureWorkspaceAccess(context.workspaceId, context.userId, 'write')
const result = await generateOAuthLink(
context.userId,
context.workspaceId,
context.workflowId,
context.chatId,
Expand Down Expand Up @@ -69,14 +64,16 @@ export async function executeOAuthRequestAccess(
}

/**
* Resolves a human-friendly provider name to a providerId and generates the
* actual OAuth authorization URL via Better Auth's server-side API.
* Resolves a human-friendly provider name to a providerId and returns a
* browser-initiated authorize URL the user opens to connect the service.
*
* Steps: resolve provider → create credential draft → look up user session →
* call auth.api.oAuth2LinkAccount → return the real authorization URL.
* Steps: resolve provider → return the Sim `/api/auth/oauth2/authorize` URL.
* That endpoint (not this server-side handler) creates the credential draft and
* calls Better Auth, so the draft's TTL starts at click and the signed `state`
* cookie is planted in the user's browser and the OAuth callback's state check
* passes.
*/
async function generateOAuthLink(
userId: string,
workspaceId: string | undefined,
workflowId: string | undefined,
chatId: string | undefined,
Expand Down Expand Up @@ -127,58 +124,20 @@ async function generateOAuthLink(
}
}

let displayName = serviceName
try {
const [row] = await db.select({ name: user.name }).from(user).where(eq(user.id, userId))
if (row?.name) {
displayName = `${row.name}'s ${serviceName}`
}
} catch {
// Fall back to service name only
}

const now = new Date()
await db
.delete(pendingCredentialDraft)
.where(
and(eq(pendingCredentialDraft.userId, userId), lt(pendingCredentialDraft.expiresAt, now))
)
await db
.insert(pendingCredentialDraft)
.values({
id: generateId(),
userId,
workspaceId,
providerId,
displayName,
expiresAt: new Date(now.getTime() + 15 * 60 * 1000),
createdAt: now,
})
.onConflictDoUpdate({
target: [
pendingCredentialDraft.userId,
pendingCredentialDraft.providerId,
pendingCredentialDraft.workspaceId,
],
set: {
displayName,
expiresAt: new Date(now.getTime() + 15 * 60 * 1000),
createdAt: now,
},
})

const { auth } = await import('@/lib/auth/auth')
const { headers: getHeaders } = await import('next/headers')
const reqHeaders = await getHeaders()

const data = (await auth.api.oAuth2LinkAccount({
body: { providerId, callbackURL },
headers: reqHeaders,
})) as { url?: string; redirect?: boolean }

if (!data?.url) {
throw new Error('oAuth2LinkAccount did not return an authorization URL')
}
// Hand back a browser-initiated authorize URL rather than calling
// oAuth2LinkAccount here. Generating the link server-side would set Better
// Auth's signed `state` cookie on this server-to-server response instead of the
// user's browser, so the OAuth callback would fail with `state_mismatch`. The
// authorize endpoint runs the link inside the user's browser, planting the
// cookie correctly while keeping the callback's state check enabled.
//
// The pending credential draft is created by that authorize endpoint at click
// time (not here), so the draft's TTL starts when the user actually initiates
// the connect and reliably outlives the OAuth round-trip.
const authorizeUrl = new URL(`${baseUrl}/api/auth/oauth2/authorize`)
authorizeUrl.searchParams.set('providerId', providerId)
authorizeUrl.searchParams.set('workspaceId', workspaceId)
authorizeUrl.searchParams.set('callbackURL', callbackURL)

return { url: data.url, providerId, serviceName }
return { url: authorizeUrl.toString(), providerId, serviceName }
}
4 changes: 2 additions & 2 deletions scripts/check-api-validation-contracts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ const QUERY_HOOKS_DIR = path.join(ROOT, 'apps/sim/hooks/queries')
const SELECTOR_HOOKS_DIR = path.join(ROOT, 'apps/sim/hooks/selectors')

const BASELINE = {
totalRoutes: 761,
zodRoutes: 761,
totalRoutes: 762,
zodRoutes: 762,
nonZodRoutes: 0,
} as const

Expand Down
Loading