diff --git a/apps/sim/app/api/guardrails/validate/route.test.ts b/apps/sim/app/api/guardrails/validate/route.test.ts new file mode 100644 index 00000000000..5c4f91b9109 --- /dev/null +++ b/apps/sim/app/api/guardrails/validate/route.test.ts @@ -0,0 +1,159 @@ +/** + * @vitest-environment node + */ +import { createMockRequest, hybridAuthMockFns, workflowAuthzMockFns } from '@sim/testing' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const { mockAuthorizeCredentialUse, mockCheckActorUsageLimits, mockValidateHallucination } = + vi.hoisted(() => ({ + mockAuthorizeCredentialUse: vi.fn(), + mockCheckActorUsageLimits: vi.fn(), + mockValidateHallucination: vi.fn(), + })) + +vi.mock('@/lib/auth/credential-access', () => ({ + authorizeCredentialUse: mockAuthorizeCredentialUse, +})) + +vi.mock('@/lib/billing/calculations/usage-monitor', () => ({ + checkActorUsageLimits: mockCheckActorUsageLimits, +})) + +vi.mock('@/lib/guardrails/validate_hallucination', () => ({ + validateHallucination: mockValidateHallucination, +})) + +vi.mock('@/lib/guardrails/validate_json', () => ({ + validateJson: vi.fn(() => ({ passed: true })), +})) + +vi.mock('@/lib/guardrails/validate_pii', () => ({ + validatePII: vi.fn(() => ({ passed: true })), +})) + +vi.mock('@/lib/guardrails/validate_regex', () => ({ + validateRegex: vi.fn(() => ({ passed: true })), +})) + +vi.mock('@/ee/access-control/utils/permission-check', () => ({ + assertPermissionsAllowed: vi.fn(), + ModelNotAllowedError: class ModelNotAllowedError extends Error {}, + ProviderNotAllowedError: class ProviderNotAllowedError extends Error {}, +})) + +import { POST } from '@/app/api/guardrails/validate/route' + +describe('POST /api/guardrails/validate', () => { + beforeEach(() => { + vi.clearAllMocks() + hybridAuthMockFns.mockCheckSessionOrInternalAuth.mockResolvedValue({ + success: true, + userId: 'user-1', + authType: 'session', + }) + workflowAuthzMockFns.mockAuthorizeWorkflowByWorkspacePermission.mockResolvedValue({ + allowed: true, + workflow: { id: 'wf-1', workspaceId: 'ws-1' }, + }) + mockCheckActorUsageLimits.mockResolvedValue({ isExceeded: false }) + mockValidateHallucination.mockResolvedValue({ passed: true, score: 8 }) + }) + + it('rejects a vertexCredential the caller does not have access to before calling validateHallucination', async () => { + mockAuthorizeCredentialUse.mockResolvedValue({ + ok: false, + error: 'You do not have access to this credential.', + }) + + const res = await POST( + createMockRequest('POST', { + validationType: 'hallucination', + input: 'test input', + knowledgeBaseId: 'kb-1', + model: 'vertex/gemini-2.5-pro', + workflowId: 'wf-1', + vertexCredential: 'someone-elses-account-id', + }) + ) + + expect(res.status).toBe(401) + expect(mockAuthorizeCredentialUse).toHaveBeenCalledWith( + expect.anything(), + expect.objectContaining({ + credentialId: 'someone-elses-account-id', + workflowId: 'wf-1', + requireWorkflowIdForInternal: false, + }) + ) + expect(mockValidateHallucination).not.toHaveBeenCalled() + }) + + it('proceeds with hallucination validation when the caller has access to the vertexCredential', async () => { + mockAuthorizeCredentialUse.mockResolvedValue({ ok: true }) + + const res = await POST( + createMockRequest('POST', { + validationType: 'hallucination', + input: 'test input', + knowledgeBaseId: 'kb-1', + model: 'vertex/gemini-2.5-pro', + workflowId: 'wf-1', + vertexCredential: 'my-own-account-id', + }) + ) + + expect(res.status).toBe(200) + const json = await res.json() + expect(json.output.passed).toBe(true) + expect(mockAuthorizeCredentialUse).toHaveBeenCalledWith( + expect.anything(), + expect.objectContaining({ credentialId: 'my-own-account-id' }) + ) + expect(mockValidateHallucination).toHaveBeenCalled() + }) + + it('does not gate on vertexCredential for non-hallucination validation types', async () => { + const res = await POST( + createMockRequest('POST', { + validationType: 'json', + input: '{"a":1}', + }) + ) + + expect(res.status).toBe(200) + expect(mockAuthorizeCredentialUse).not.toHaveBeenCalled() + }) + + it('does not gate hallucination validation when no vertexCredential is supplied', async () => { + const res = await POST( + createMockRequest('POST', { + validationType: 'hallucination', + input: 'test input', + knowledgeBaseId: 'kb-1', + model: 'gpt-4o', + workflowId: 'wf-1', + }) + ) + + expect(res.status).toBe(200) + expect(mockAuthorizeCredentialUse).not.toHaveBeenCalled() + expect(mockValidateHallucination).toHaveBeenCalled() + }) + + it('does not gate on a leftover vertexCredential when the resolved model is not vertex', async () => { + const res = await POST( + createMockRequest('POST', { + validationType: 'hallucination', + input: 'test input', + knowledgeBaseId: 'kb-1', + model: 'gpt-4o', + workflowId: 'wf-1', + vertexCredential: 'someone-elses-account-id', + }) + ) + + expect(res.status).toBe(200) + expect(mockAuthorizeCredentialUse).not.toHaveBeenCalled() + expect(mockValidateHallucination).toHaveBeenCalled() + }) +}) diff --git a/apps/sim/app/api/guardrails/validate/route.ts b/apps/sim/app/api/guardrails/validate/route.ts index ca50b20b2d2..f4dea3d3367 100644 --- a/apps/sim/app/api/guardrails/validate/route.ts +++ b/apps/sim/app/api/guardrails/validate/route.ts @@ -3,6 +3,7 @@ import { authorizeWorkflowByWorkspacePermission } from '@sim/platform-authz/work import { type NextRequest, NextResponse } from 'next/server' import { guardrailsValidateContract } from '@/lib/api/contracts' import { parseRequest } from '@/lib/api/server' +import { authorizeCredentialUse } from '@/lib/auth/credential-access' import { checkSessionOrInternalAuth } from '@/lib/auth/hybrid' import { checkActorUsageLimits } from '@/lib/billing/calculations/usage-monitor' import { generateRequestId } from '@/lib/core/utils/request' @@ -16,6 +17,7 @@ import { ModelNotAllowedError, ProviderNotAllowedError, } from '@/ee/access-control/utils/permission-check' +import { getProviderFromModel } from '@/providers/utils' const logger = createLogger('GuardrailsValidateAPI') @@ -187,6 +189,24 @@ export const POST = withRouteHandler(async (request: NextRequest) => { { status: 402 } ) } + + if (vertexCredential && getProviderFromModel(model) === 'vertex') { + const vertexCredAccess = await authorizeCredentialUse(request, { + credentialId: vertexCredential, + workflowId, + requireWorkflowIdForInternal: false, + }) + if (!vertexCredAccess.ok) { + logger.warn(`[${requestId}] Vertex credential access denied`, { + error: vertexCredAccess.error, + credentialId: vertexCredential, + }) + return NextResponse.json( + { error: vertexCredAccess.error || 'Unauthorized' }, + { status: 401 } + ) + } + } } const inputStr = convertInputToString(input)