diff --git a/apps/sim/app/api/mcp/oauth/start/route.test.ts b/apps/sim/app/api/mcp/oauth/start/route.test.ts index fa8ebf26dc7..e4a5132a4fa 100644 --- a/apps/sim/app/api/mcp/oauth/start/route.test.ts +++ b/apps/sim/app/api/mcp/oauth/start/route.test.ts @@ -17,8 +17,10 @@ import { import { NextRequest } from 'next/server' import { beforeEach, describe, expect, it, vi } from 'vitest' -const { mockMcpAuth } = vi.hoisted(() => ({ +const { mockMcpAuth, mockCreateSsrfGuardedMcpFetch, mockGuardedFetch } = vi.hoisted(() => ({ mockMcpAuth: vi.fn(), + mockCreateSsrfGuardedMcpFetch: vi.fn(), + mockGuardedFetch: vi.fn(), })) vi.mock('@sim/db', () => dbChainMock) @@ -31,6 +33,9 @@ vi.mock('drizzle-orm', () => ({ vi.mock('@modelcontextprotocol/sdk/client/auth.js', () => ({ auth: mockMcpAuth, })) +vi.mock('@/lib/mcp/pinned-fetch', () => ({ + createSsrfGuardedMcpFetch: mockCreateSsrfGuardedMcpFetch, +})) vi.mock('@/lib/auth/hybrid', () => hybridAuthMock) vi.mock('@/lib/workspaces/permissions/utils', () => permissionsMock) vi.mock('@/lib/mcp/oauth', () => mcpOauthMock) @@ -73,6 +78,21 @@ describe('MCP OAuth start route', () => { }) mcpOauthMockFns.mockLoadPreregisteredClient.mockResolvedValue(undefined) mockMcpAuth.mockRejectedValue(new McpOauthRedirectRequiredMock('https://mcp.exa.ai/authorize')) + mockCreateSsrfGuardedMcpFetch.mockReturnValue(mockGuardedFetch) + }) + + it('routes OAuth discovery through the SSRF-guarded fetch', async () => { + const request = new NextRequest( + 'http://localhost:3000/api/mcp/oauth/start?workspaceId=workspace-1&serverId=server-1' + ) + + await GET(request) + + expect(mockCreateSsrfGuardedMcpFetch).toHaveBeenCalledTimes(1) + expect(mockMcpAuth).toHaveBeenCalledWith( + expect.anything(), + expect.objectContaining({ serverUrl: 'https://mcp.exa.ai/mcp', fetchFn: mockGuardedFetch }) + ) }) it('requires workspace write permission via MCP auth middleware', async () => { diff --git a/apps/sim/app/api/mcp/oauth/start/route.ts b/apps/sim/app/api/mcp/oauth/start/route.ts index c7619b9d555..3b228bd95c0 100644 --- a/apps/sim/app/api/mcp/oauth/start/route.ts +++ b/apps/sim/app/api/mcp/oauth/start/route.ts @@ -20,6 +20,7 @@ import { SimMcpOauthProvider, setOauthRowUser, } from '@/lib/mcp/oauth' +import { createSsrfGuardedMcpFetch } from '@/lib/mcp/pinned-fetch' import { createMcpErrorResponse } from '@/lib/mcp/utils' const logger = createLogger('McpOauthStartAPI') @@ -129,7 +130,10 @@ export const GET = withRouteHandler( const provider = new SimMcpOauthProvider({ row, preregistered }) try { - const result = await mcpAuth(provider, { serverUrl: server.url }) + const result = await mcpAuth(provider, { + serverUrl: server.url, + fetchFn: createSsrfGuardedMcpFetch(), + }) if (result === 'AUTHORIZED') { return NextResponse.json({ status: 'already_authorized' }) }