diff --git a/apps/sim/app/api/mcp/oauth/callback/route.test.ts b/apps/sim/app/api/mcp/oauth/callback/route.test.ts new file mode 100644 index 00000000000..d79c43e4a76 --- /dev/null +++ b/apps/sim/app/api/mcp/oauth/callback/route.test.ts @@ -0,0 +1,85 @@ +/** + * @vitest-environment node + */ +import { + authMockFns, + dbChainMock, + dbChainMockFns, + mcpOauthMock, + mcpOauthMockFns, + resetDbChainMock, + schemaMock, +} from '@sim/testing' +import { NextRequest } from 'next/server' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const { mockMcpAuth, mockCreateSsrfGuardedMcpFetch, mockGuardedFetch, mockDiscoverServerTools } = + vi.hoisted(() => ({ + mockMcpAuth: vi.fn(), + mockCreateSsrfGuardedMcpFetch: vi.fn(), + mockGuardedFetch: vi.fn(), + mockDiscoverServerTools: vi.fn(), + })) + +vi.mock('@sim/db', () => dbChainMock) +vi.mock('@sim/db/schema', () => schemaMock) +vi.mock('drizzle-orm', () => ({ + and: vi.fn(), + eq: vi.fn(), + isNull: vi.fn(), +})) +vi.mock('@modelcontextprotocol/sdk/client/auth.js', () => ({ + auth: mockMcpAuth, +})) +vi.mock('@/lib/mcp/oauth', () => mcpOauthMock) +vi.mock('@/lib/mcp/pinned-fetch', () => ({ + createSsrfGuardedMcpFetch: mockCreateSsrfGuardedMcpFetch, +})) +vi.mock('@/lib/mcp/service', () => ({ + mcpService: { discoverServerTools: mockDiscoverServerTools }, +})) + +import { GET } from './route' + +describe('MCP OAuth callback route', () => { + beforeEach(() => { + vi.clearAllMocks() + resetDbChainMock() + mockCreateSsrfGuardedMcpFetch.mockReturnValue(mockGuardedFetch) + authMockFns.mockGetSession.mockResolvedValue({ user: { id: 'user-1' } }) + mcpOauthMockFns.mockLoadOauthRowByState.mockResolvedValue({ + id: 'oauth-row-1', + mcpServerId: 'server-1', + userId: 'user-1', + workspaceId: 'workspace-1', + }) + dbChainMockFns.limit.mockResolvedValue([ + { + id: 'server-1', + url: 'https://mcp.example.com/mcp', + workspaceId: 'workspace-1', + }, + ]) + mcpOauthMockFns.mockLoadPreregisteredClient.mockResolvedValue(undefined) + mockMcpAuth.mockResolvedValue('AUTHORIZED') + mockDiscoverServerTools.mockResolvedValue(undefined) + }) + + it('performs the token exchange through the SSRF-guarded fetch', async () => { + const request = new NextRequest( + 'http://localhost:3000/api/mcp/oauth/callback?state=state-1&code=auth-code-1' + ) + + await GET(request) + + expect(mockCreateSsrfGuardedMcpFetch).toHaveBeenCalledTimes(1) + expect(mockMcpAuth).toHaveBeenCalledWith( + expect.anything(), + expect.objectContaining({ + serverUrl: 'https://mcp.example.com/mcp', + authorizationCode: 'auth-code-1', + fetchFn: mockGuardedFetch, + }) + ) + }) +}) diff --git a/apps/sim/app/api/mcp/oauth/callback/route.ts b/apps/sim/app/api/mcp/oauth/callback/route.ts index 08171fbc97c..0acfff85778 100644 --- a/apps/sim/app/api/mcp/oauth/callback/route.ts +++ b/apps/sim/app/api/mcp/oauth/callback/route.ts @@ -19,6 +19,7 @@ import { type McpOauthCallbackReason, SimMcpOauthProvider, } from '@/lib/mcp/oauth' +import { createSsrfGuardedMcpFetch } from '@/lib/mcp/pinned-fetch' import { mcpService } from '@/lib/mcp/service' const logger = createLogger('McpOauthCallbackAPI') @@ -149,6 +150,7 @@ export const GET = withRouteHandler(async (request: NextRequest) => { result = await mcpAuth(provider, { serverUrl: server.url, authorizationCode: code, + fetchFn: createSsrfGuardedMcpFetch(), }) } catch (e) { logger.error('Token exchange failed during MCP OAuth callback', e)