diff --git a/apps/sim/app/api/mcp/serve/[serverId]/route.test.ts b/apps/sim/app/api/mcp/serve/[serverId]/route.test.ts index cd9c5523231..31d4defca0a 100644 --- a/apps/sim/app/api/mcp/serve/[serverId]/route.test.ts +++ b/apps/sim/app/api/mcp/serve/[serverId]/route.test.ts @@ -20,12 +20,17 @@ const { mockGenerateInternalToken, fetchMock } = vi.hoisted(() => ({ })) const mockGetUserEntityPermissions = permissionsMockFns.mockGetUserEntityPermissions +const MCP_BYTE_LIMIT = 10 * 1024 * 1024 +const MCP_TOOLS_LIST_LIMIT = 100 vi.mock('@sim/db', () => dbChainMock) vi.mock('drizzle-orm', () => ({ and: vi.fn(), + asc: vi.fn(), eq: vi.fn(), + gt: vi.fn(), isNull: vi.fn(), + sql: vi.fn(), })) vi.mock('@/lib/workspaces/permissions/utils', () => permissionsMock) @@ -43,7 +48,7 @@ vi.mock('@/lib/core/execution-limits', () => ({ getMaxExecutionTimeout: () => 10_000, })) -import { GET, POST } from '@/app/api/mcp/serve/[serverId]/route' +import { DELETE, GET, POST } from '@/app/api/mcp/serve/[serverId]/route' describe('MCP Serve Route', () => { beforeEach(() => { @@ -101,7 +106,103 @@ describe('MCP Serve Route', () => { expect(response.status).toBe(401) }) - it('forwards X-API-Key for private server api_key auth', async () => { + it('allows unauthenticated GET metadata for public servers', async () => { + dbChainMockFns.limit.mockResolvedValueOnce([ + { + id: 'server-1', + name: 'Public Server', + workspaceId: 'ws-1', + isPublic: true, + createdBy: 'owner-1', + }, + ]) + + const req = new NextRequest('http://localhost:3000/api/mcp/serve/server-1') + const response = await GET(req, { params: Promise.resolve({ serverId: 'server-1' }) }) + const body = await response.json() + + expect(response.status).toBe(200) + expect(body.name).toBe('Public Server') + expect(hybridAuthMockFns.mockCheckHybridAuth).not.toHaveBeenCalled() + }) + + it('authenticates private SSE-style GET before returning unsupported transport', async () => { + dbChainMockFns.limit.mockResolvedValueOnce([ + { + id: 'server-1', + name: 'Private Server', + workspaceId: 'ws-1', + isPublic: false, + createdBy: 'owner-1', + }, + ]) + hybridAuthMockFns.mockCheckHybridAuth.mockResolvedValueOnce({ + success: false, + error: 'Unauthorized', + }) + + const req = new NextRequest('http://localhost:3000/api/mcp/serve/server-1', { + headers: { accept: 'text/event-stream' }, + }) + + const response = await GET(req, { params: Promise.resolve({ serverId: 'server-1' }) }) + + expect(response.status).toBe(401) + }) + + it('returns 405 for authorized SSE-style GET', async () => { + dbChainMockFns.limit.mockResolvedValueOnce([ + { + id: 'server-1', + name: 'Private Server', + workspaceId: 'ws-1', + isPublic: false, + createdBy: 'owner-1', + }, + ]) + hybridAuthMockFns.mockCheckHybridAuth.mockResolvedValueOnce({ + success: true, + userId: 'user-1', + authType: 'session', + }) + mockGetUserEntityPermissions.mockResolvedValueOnce('read') + + const req = new NextRequest('http://localhost:3000/api/mcp/serve/server-1', { + headers: { accept: 'text/event-stream' }, + }) + + const response = await GET(req, { params: Promise.resolve({ serverId: 'server-1' }) }) + const body = await response.json() + + expect(response.status).toBe(405) + expect(response.headers.get('allow')).toBe('GET, POST, DELETE') + expect(body.error.code).toBe('unsupported_transport') + }) + + it('requires authentication for DELETE even on public servers', async () => { + dbChainMockFns.limit.mockResolvedValueOnce([ + { + id: 'server-1', + name: 'Public Server', + workspaceId: 'ws-1', + isPublic: true, + createdBy: 'owner-1', + }, + ]) + hybridAuthMockFns.mockCheckHybridAuth.mockResolvedValueOnce({ + success: false, + error: 'Unauthorized', + }) + + const req = new NextRequest('http://localhost:3000/api/mcp/serve/server-1', { + method: 'DELETE', + }) + const response = await DELETE(req, { params: Promise.resolve({ serverId: 'server-1' }) }) + + expect(response.status).toBe(401) + }) + + it('uses an internal bridge token for private server api_key auth', async () => { dbChainMockFns.limit .mockResolvedValueOnce([ { @@ -122,6 +223,7 @@ describe('MCP Serve Route', () => { apiKeyType: 'personal', }) mockGetUserEntityPermissions.mockResolvedValueOnce('write') + mockGenerateInternalToken.mockResolvedValueOnce('internal-token-user-1') fetchMock.mockResolvedValueOnce( new Response(JSON.stringify({ output: { ok: true } }), { status: 200, @@ -145,9 +247,10 @@ describe('MCP Serve Route', () => { expect(fetchMock).toHaveBeenCalledTimes(1) const fetchOptions = fetchMock.mock.calls[0][1] as RequestInit const headers = fetchOptions.headers as Record - expect(headers['X-API-Key']).toBe('pk_test_123') - expect(headers.Authorization).toBeUndefined() - expect(mockGenerateInternalToken).not.toHaveBeenCalled() + expect(headers.Authorization).toBe('Bearer internal-token-user-1') + expect(headers['X-Sim-MCP-Tool-Actor']).toBe('authenticated-user') + expect(headers['X-API-Key']).toBeUndefined() + expect(mockGenerateInternalToken).toHaveBeenCalledWith('user-1') }) it('forwards internal token for private server session auth', async () => { @@ -194,10 +297,586 @@ describe('MCP Serve Route', () => { const fetchOptions = fetchMock.mock.calls[0][1] as RequestInit const headers = fetchOptions.headers as Record expect(headers.Authorization).toBe('Bearer internal-token-user-1') + expect(headers['X-Sim-MCP-Tool-Actor']).toBeUndefined() expect(headers['X-API-Key']).toBeUndefined() expect(mockGenerateInternalToken).toHaveBeenCalledWith('user-1') }) + it('rejects oversized MCP request bodies before parsing JSON', async () => { + dbChainMockFns.limit.mockResolvedValueOnce([ + { + id: 'server-1', + name: 'Public Server', + workspaceId: 'ws-1', + isPublic: true, + createdBy: 'owner-1', + }, + ]) + + const req = new NextRequest('http://localhost:3000/api/mcp/serve/server-1', { + method: 'POST', + headers: { 'content-length': String(MCP_BYTE_LIMIT + 1) }, + body: JSON.stringify({ jsonrpc: '2.0', id: 1, method: 'ping' }), + }) + + const response = await POST(req, { params: Promise.resolve({ serverId: 'server-1' }) }) + const body = await response.json() + + expect(response.status).toBe(413) + expect(body.error.message).toContain('MCP request body exceeds maximum size') + expect(fetchMock).not.toHaveBeenCalled() + }) + + it('rejects streamed MCP request bodies that exceed the cap without content-length', async () => { + const cancelSpy = vi.fn() + dbChainMockFns.limit.mockResolvedValueOnce([ + { + id: 'server-1', + name: 'Public Server', + workspaceId: 'ws-1', + isPublic: true, + createdBy: 'owner-1', + }, + ]) + + const stream = new ReadableStream({ + start(controller) { + controller.enqueue(new Uint8Array(MCP_BYTE_LIMIT)) + controller.enqueue(new Uint8Array(1)) + }, + cancel: cancelSpy, + }) + const request = new Request('http://localhost:3000/api/mcp/serve/server-1', { + method: 'POST', + body: stream, + duplex: 'half', + } as RequestInit & { duplex: 'half' }) + const req = new NextRequest(request) + + const response = await POST(req, { params: Promise.resolve({ serverId: 'server-1' }) }) + const body = await response.json() + + expect(response.status).toBe(413) + expect(body.error.message).toContain('MCP request body') + expect(cancelSpy).toHaveBeenCalled() + expect(fetchMock).not.toHaveBeenCalled() + }) + + it('rejects oversized tools/call arguments before internal fetch', async () => { + dbChainMockFns.limit.mockResolvedValueOnce([ + { + id: 'server-1', + name: 'Public Server', + workspaceId: 'ws-1', + isPublic: true, + createdBy: 'owner-1', + }, + ]) + + const req = new NextRequest('http://localhost:3000/api/mcp/serve/server-1', { + method: 'POST', + body: JSON.stringify({ + jsonrpc: '2.0', + id: 1, + method: 'tools/call', + params: { name: 'tool_a', arguments: { payload: 'x'.repeat(MCP_BYTE_LIMIT) } }, + }), + }) + + const response = await POST(req, { params: Promise.resolve({ serverId: 'server-1' }) }) + const body = await response.json() + + expect(response.status).toBe(413) + expect(body.error.message).toContain('MCP request body') + expect(fetchMock).not.toHaveBeenCalled() + }) + + it('cancels and rejects oversized workflow execution responses', async () => { + const cancelSpy = vi.fn() + dbChainMockFns.limit + .mockResolvedValueOnce([ + { + id: 'server-1', + name: 'Public Server', + workspaceId: 'ws-1', + isPublic: true, + createdBy: 'owner-1', + }, + ]) + .mockResolvedValueOnce([{ toolName: 'tool_a', workflowId: 'wf-1' }]) + .mockResolvedValueOnce([{ isDeployed: true }]) + fetchMock.mockResolvedValueOnce( + new Response( + new ReadableStream({ + cancel: cancelSpy, + }), + { + status: 200, + headers: { 'content-length': String(MCP_BYTE_LIMIT + 1) }, + } + ) + ) + + const req = new NextRequest('http://localhost:3000/api/mcp/serve/server-1', { + method: 'POST', + body: JSON.stringify({ + jsonrpc: '2.0', + id: 1, + method: 'tools/call', + params: { name: 'tool_a', arguments: { q: 'test' } }, + }), + }) + + const response = await POST(req, { params: Promise.resolve({ serverId: 'server-1' }) }) + const body = await response.json() + + expect(response.status).toBe(413) + expect(body.error.message).toContain('MCP workflow execution response') + expect(cancelSpy).toHaveBeenCalled() + }) + + it('cancels and rejects streamed workflow responses that exceed the cap', async () => { + const cancelSpy = vi.fn() + dbChainMockFns.limit + .mockResolvedValueOnce([ + { + id: 'server-1', + name: 'Public Server', + workspaceId: 'ws-1', + isPublic: true, + createdBy: 'owner-1', + }, + ]) + .mockResolvedValueOnce([{ toolName: 'tool_a', workflowId: 'wf-1' }]) + .mockResolvedValueOnce([{ isDeployed: true }]) + fetchMock.mockResolvedValueOnce( + new Response( + new ReadableStream({ + start(controller) { + controller.enqueue(new Uint8Array(MCP_BYTE_LIMIT)) + controller.enqueue(new Uint8Array(1)) + }, + cancel: cancelSpy, + }), + { + status: 200, + headers: { 'content-length': '1' }, + } + ) + ) + + const req = new NextRequest('http://localhost:3000/api/mcp/serve/server-1', { + method: 'POST', + body: JSON.stringify({ + jsonrpc: '2.0', + id: 1, + method: 'tools/call', + params: { name: 'tool_a', arguments: { q: 'test' } }, + }), + }) + + const response = await POST(req, { params: Promise.resolve({ serverId: 'server-1' }) }) + const body = await response.json() + + expect(response.status).toBe(413) + expect(body.error.message).toContain('MCP workflow execution response') + expect(cancelSpy).toHaveBeenCalled() + }) + + it('preserves recoverable workflow execution statuses through the MCP bridge', async () => { + dbChainMockFns.limit + .mockResolvedValueOnce([ + { + id: 'server-1', + name: 'Public Server', + workspaceId: 'ws-1', + isPublic: true, + createdBy: 'owner-1', + }, + ]) + .mockResolvedValueOnce([{ toolName: 'tool_a', workflowId: 'wf-1' }]) + .mockResolvedValueOnce([{ isDeployed: true }]) + fetchMock.mockResolvedValueOnce( + new Response( + JSON.stringify({ + success: false, + error: 'Workflow execution request body exceeds maximum size', + }), + { + status: 413, + headers: { 'Content-Type': 'application/json' }, + } + ) + ) + + const req = new NextRequest('http://localhost:3000/api/mcp/serve/server-1', { + method: 'POST', + body: JSON.stringify({ + jsonrpc: '2.0', + id: 1, + method: 'tools/call', + params: { name: 'tool_a', arguments: { q: 'test' } }, + }), + }) + + const response = await POST(req, { params: Promise.resolve({ serverId: 'server-1' }) }) + const body = await response.json() + + expect(response.status).toBe(413) + expect(body.error.code).toBe(-32600) + expect(body.error.data.httpStatus).toBe(413) + const fetchOptions = fetchMock.mock.calls[0][1] as RequestInit + const headers = fetchOptions.headers as Record + expect(headers['X-Sim-MCP-Tool-Call']).toBe('true') + }) + + it('preserves upstream error status when workflow response is not JSON', async () => { + dbChainMockFns.limit + .mockResolvedValueOnce([ + { + id: 'server-1', + name: 'Public Server', + workspaceId: 'ws-1', + isPublic: true, + createdBy: 'owner-1', + }, + ]) + .mockResolvedValueOnce([{ toolName: 'tool_a', workflowId: 'wf-1' }]) + .mockResolvedValueOnce([{ isDeployed: true }]) + fetchMock.mockResolvedValueOnce(new Response('gateway timeout', { status: 408 })) + + const req = new NextRequest('http://localhost:3000/api/mcp/serve/server-1', { + method: 'POST', + body: JSON.stringify({ + jsonrpc: '2.0', + id: 1, + method: 'tools/call', + params: { name: 'tool_a', arguments: { q: 'test' } }, + }), + }) + + const response = await POST(req, { params: Promise.resolve({ serverId: 'server-1' }) }) + const body = await response.json() + + expect(response.status).toBe(408) + expect(body.error.data.httpStatus).toBe(408) + expect(body.error.data.retryable).toBe(true) + }) + + it('preserves falsy workflow outputs in MCP tool results', async () => { + dbChainMockFns.limit + .mockResolvedValueOnce([ + { + id: 'server-1', + name: 'Public Server', + workspaceId: 'ws-1', + isPublic: true, + createdBy: 'owner-1', + }, + ]) + .mockResolvedValueOnce([{ toolName: 'tool_a', workflowId: 'wf-1' }]) + .mockResolvedValueOnce([{ isDeployed: true }]) + fetchMock.mockResolvedValueOnce( + new Response(JSON.stringify({ success: true, output: false }), { + status: 200, + headers: { 'Content-Type': 'application/json' }, + }) + ) + + const req = new NextRequest('http://localhost:3000/api/mcp/serve/server-1', { + method: 'POST', + body: JSON.stringify({ + jsonrpc: '2.0', + id: 1, + method: 'tools/call', + params: { name: 'tool_a', arguments: { q: 'test' } }, + }), + }) + + const response = await POST(req, { params: Promise.resolve({ serverId: 'server-1' }) }) + const body = await response.json() + + expect(response.status).toBe(200) + expect(body.result.content[0].text).toBe('false') + }) + + it('serializes missing workflow output without failing the MCP tool call', async () => { + dbChainMockFns.limit + .mockResolvedValueOnce([ + { + id: 'server-1', + name: 'Public Server', + workspaceId: 'ws-1', + isPublic: true, + createdBy: 'owner-1', + }, + ]) + .mockResolvedValueOnce([{ toolName: 'tool_a', workflowId: 'wf-1' }]) + .mockResolvedValueOnce([{ isDeployed: true }]) + fetchMock.mockResolvedValueOnce( + new Response(JSON.stringify({ success: true }), { + status: 200, + headers: { 'Content-Type': 'application/json' }, + }) + ) + + const req = new NextRequest('http://localhost:3000/api/mcp/serve/server-1', { + method: 'POST', + body: JSON.stringify({ + jsonrpc: '2.0', + id: 1, + method: 'tools/call', + params: { name: 'tool_a', arguments: { q: 'test' } }, + }), + }) + + const response = await POST(req, { params: Promise.resolve({ serverId: 'server-1' }) }) + const body = await response.json() + + expect(response.status).toBe(200) + expect(body.result.content[0].text).toContain('"success": true') + }) + + it('serializes non-object workflow JSON responses from response blocks', async () => { + dbChainMockFns.limit + .mockResolvedValueOnce([ + { + id: 'server-1', + name: 'Private Server', + workspaceId: 'ws-1', + isPublic: false, + createdBy: 'owner-1', + }, + ]) + .mockResolvedValueOnce([{ toolName: 'tool_a', workflowId: 'wf-1' }]) + .mockResolvedValueOnce([{ isDeployed: true }]) + hybridAuthMockFns.mockCheckHybridAuth.mockResolvedValueOnce({ + success: true, + userId: 'user-1', + authType: 'api_key', + apiKeyType: 'personal', + }) + mockGetUserEntityPermissions.mockResolvedValueOnce('write') + fetchMock.mockResolvedValueOnce( + new Response(JSON.stringify(['a', 'b']), { + status: 200, + headers: { 'Content-Type': 'application/json' }, + }) + ) + + const req = new NextRequest('http://localhost:3000/api/mcp/serve/server-1', { + method: 'POST', + headers: { 'X-API-Key': 'pk_test_123' }, + body: JSON.stringify({ + jsonrpc: '2.0', + id: 1, + method: 'tools/call', + params: { name: 'tool_a', arguments: { q: 'test' } }, + }), + }) + + const response = await POST(req, { params: Promise.resolve({ serverId: 'server-1' }) }) + const body = await response.json() + + expect(response.status).toBe(200) + expect(body.result.content[0].text).toBe(JSON.stringify(['a', 'b'], null, 2)) + }) + + it('rejects duplicate tool names instead of choosing an arbitrary workflow', async () => { + dbChainMockFns.limit + .mockResolvedValueOnce([ + { + id: 'server-1', + name: 'Public Server', + workspaceId: 'ws-1', + isPublic: true, + createdBy: 'owner-1', + }, + ]) + .mockResolvedValueOnce([ + { toolName: 'tool_a', workflowId: 'wf-1' }, + { toolName: 'tool_a', workflowId: 'wf-2' }, + ]) + + const req = new NextRequest('http://localhost:3000/api/mcp/serve/server-1', { + method: 'POST', + body: JSON.stringify({ + jsonrpc: '2.0', + id: 1, + method: 'tools/call', + params: { name: 'tool_a', arguments: { q: 'test' } }, + }), + }) + + const response = await POST(req, { params: Promise.resolve({ serverId: 'server-1' }) }) + const body = await response.json() + + expect(response.status).toBe(409) + expect(body.error.data.code).toBe('duplicate_tool_name') + expect(fetchMock).not.toHaveBeenCalled() + }) + + it('aborts the internal workflow fetch when the MCP client disconnects', async () => { + const requestAbortController = new AbortController() + dbChainMockFns.limit + .mockResolvedValueOnce([ + { + id: 'server-1', + name: 'Public Server', + workspaceId: 'ws-1', + isPublic: true, + createdBy: 'owner-1', + }, + ]) + .mockResolvedValueOnce([{ toolName: 'tool_a', workflowId: 'wf-1' }]) + .mockResolvedValueOnce([{ isDeployed: true }]) + fetchMock.mockImplementationOnce((_url, init: RequestInit) => { + const signal = init.signal as AbortSignal + return new Promise((_resolve, reject) => { + signal.addEventListener( + 'abort', + () => { + reject(Object.assign(new Error('aborted'), { name: 'AbortError' })) + }, + { once: true } + ) + requestAbortController.abort() + }) + }) + + const req = new NextRequest( + new Request('http://localhost:3000/api/mcp/serve/server-1', { + method: 'POST', + body: JSON.stringify({ + jsonrpc: '2.0', + id: 1, + method: 'tools/call', + params: { name: 'tool_a', arguments: { q: 'test' } }, + }), + signal: requestAbortController.signal, + }) + ) + + const response = await POST(req, { params: Promise.resolve({ serverId: 'server-1' }) }) + + expect(response.status).toBe(499) + }) + + it('paginates tools/list by tool count', async () => { + const pageRows = Array.from({ length: MCP_TOOLS_LIST_LIMIT + 1 }, (_, index) => ({ + id: `tool-id-${String(index).padStart(3, '0')}`, + toolNameBytes: 10 + index, + toolDescriptionBytes: 0, + parameterSchemaBytes: 32, + })) + dbChainMockFns.limit + .mockResolvedValueOnce([ + { + id: 'server-1', + name: 'Public Server', + workspaceId: 'ws-1', + isPublic: true, + createdBy: 'owner-1', + }, + ]) + .mockResolvedValueOnce([]) + .mockResolvedValueOnce(pageRows) + .mockResolvedValueOnce( + pageRows.map((row, index) => ({ + id: row.id, + toolName: `tool_${index}`, + toolDescription: null, + parameterSchema: { type: 'object', properties: {} }, + })) + ) + + const req = new NextRequest('http://localhost:3000/api/mcp/serve/server-1', { + method: 'POST', + body: JSON.stringify({ jsonrpc: '2.0', id: 1, method: 'tools/list' }), + }) + + const response = await POST(req, { params: Promise.resolve({ serverId: 'server-1' }) }) + const body = await response.json() + + expect(response.status).toBe(200) + expect(body.result.tools).toHaveLength(MCP_TOOLS_LIST_LIMIT) + expect(body.result.nextCursor).toBe('tool-id-099') + }) + + it('bounds tools/list by stored metadata estimate', async () => { + dbChainMockFns.limit + .mockResolvedValueOnce([ + { + id: 'server-1', + name: 'Public Server', + workspaceId: 'ws-1', + isPublic: true, + createdBy: 'owner-1', + }, + ]) + .mockResolvedValueOnce([]) + .mockResolvedValueOnce([ + { + id: 'tool-id-1', + toolNameBytes: 6, + toolDescriptionBytes: MCP_BYTE_LIMIT + 1, + parameterSchemaBytes: 32, + }, + ]) + + const req = new NextRequest('http://localhost:3000/api/mcp/serve/server-1', { + method: 'POST', + body: JSON.stringify({ jsonrpc: '2.0', id: 1, method: 'tools/list' }), + }) + + const response = await POST(req, { params: Promise.resolve({ serverId: 'server-1' }) }) + const body = await response.json() + + expect(response.status).toBe(413) + expect(body.error.message).toContain('tools/list response is too large') + }) + + it('bounds tools/list by final serialized response size', async () => { + dbChainMockFns.limit + .mockResolvedValueOnce([ + { + id: 'server-1', + name: 'Public Server', + workspaceId: 'ws-1', + isPublic: true, + createdBy: 'owner-1', + }, + ]) + .mockResolvedValueOnce([]) + .mockResolvedValueOnce([ + { + id: 'tool-id-1', + toolNameBytes: 6, + toolDescriptionBytes: 1, + parameterSchemaBytes: 32, + }, + ]) + .mockResolvedValueOnce([ + { + id: 'tool-id-1', + toolName: 'tool_a', + toolDescription: 'x'.repeat(MCP_BYTE_LIMIT), + parameterSchema: { type: 'object', properties: {} }, + }, + ]) + + const req = new NextRequest('http://localhost:3000/api/mcp/serve/server-1', { + method: 'POST', + body: JSON.stringify({ jsonrpc: '2.0', id: 1, method: 'tools/list' }), + }) + + const response = await POST(req, { params: Promise.resolve({ serverId: 'server-1' }) }) + const body = await response.json() + + expect(response.status).toBe(413) + expect(body.error.message).toContain('tools/list response is too large') + }) + describe('initialize protocol version negotiation', () => { async function callInitialize(protocolVersion?: string) { dbChainMockFns.limit.mockResolvedValueOnce([ diff --git a/apps/sim/app/api/mcp/serve/[serverId]/route.ts b/apps/sim/app/api/mcp/serve/[serverId]/route.ts index d876dcd0ef2..8f910764b3d 100644 --- a/apps/sim/app/api/mcp/serve/[serverId]/route.ts +++ b/apps/sim/app/api/mcp/serve/[serverId]/route.ts @@ -20,7 +20,7 @@ import { import { db } from '@sim/db' import { workflow, workflowMcpServer, workflowMcpTool, workspace } from '@sim/db/schema' import { createLogger } from '@sim/logger' -import { and, eq, isNull } from 'drizzle-orm' +import { and, asc, eq, gt, isNull, sql } from 'drizzle-orm' import { type NextRequest, NextResponse } from 'next/server' import { mcpJsonRpcNotificationSchema, @@ -28,15 +28,36 @@ import { mcpServeRouteParamsSchema, mcpToolCallParamsSchema, } from '@/lib/api/contracts/mcp' -import { type AuthResult, AuthType, checkHybridAuth } from '@/lib/auth/hybrid' +import { AuthType, checkHybridAuth } from '@/lib/auth/hybrid' import { generateInternalToken } from '@/lib/auth/internal' import { getMaxExecutionTimeout } from '@/lib/core/execution-limits' +import { + assertContentLengthWithinLimit, + assertKnownSizeWithinLimit, + isPayloadSizeLimitError, + readResponseTextWithLimit, + readStreamToBufferWithLimit, +} from '@/lib/core/utils/stream-limits' import { getInternalApiBaseUrl } from '@/lib/core/utils/urls' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' import { SIM_VIA_HEADER } from '@/lib/execution/call-chain' +import { + MAX_MCP_PARAMETER_SCHEMA_BYTES, + MAX_MCP_TOOLS_LIST_RESPONSE_BYTES, + MAX_MCP_TOOLS_PER_SERVER, + MAX_MCP_WORKFLOW_RESPONSE_BYTES, + MCP_TOOL_BRIDGE_ACTOR_HEADER, + MCP_TOOL_BRIDGE_HEADER, +} from '@/lib/mcp/constants' import { getUserEntityPermissions } from '@/lib/workspaces/permissions/utils' const logger = createLogger('WorkflowMcpServeAPI') +const MAX_MCP_SERVE_BODY_BYTES = 10 * 1024 * 1024 +const MAX_MCP_WORKFLOW_REQUEST_BYTES = 10 * 1024 * 1024 +const MAX_MCP_TOOL_RESULT_TEXT_BYTES = 10 * 1024 * 1024 +const MAX_MCP_TOOLS_LIST_COUNT = MAX_MCP_TOOLS_PER_SERVER +const MAX_MCP_TOOLS_LIST_SCHEMA_BYTES = MAX_MCP_PARAMETER_SCHEMA_BYTES +const MB = 1024 * 1024 function negotiateProtocolVersion(rpcParams: unknown): string { const requested = @@ -56,9 +77,8 @@ interface RouteParams { } interface ExecuteAuthContext { - authType?: AuthResult['authType'] userId: string - apiKey?: string | null + useAuthenticatedUserAsActor: boolean } function createResponse(id: RequestId, result: unknown): JSONRPCResultResponse { @@ -69,12 +89,197 @@ function createResponse(id: RequestId, result: unknown): JSONRPCResultResponse { } } -function createError(id: RequestId, code: ErrorCode | number, message: string): JSONRPCError { +function createError( + id: RequestId, + code: ErrorCode | number, + message: string, + data?: unknown +): JSONRPCError { return { jsonrpc: '2.0', id, - error: { code, message }, + error: { code, message, ...(data !== undefined && { data }) }, + } +} + +function clientCancelledJsonRpcResponse(id: RequestId): NextResponse { + return NextResponse.json( + createError(id, ErrorCode.ConnectionClosed, 'Client cancelled request'), + { + status: 499, + } + ) +} + +function callerAbortedJsonRpcResponse( + id: RequestId, + abortSignal?: ManagedAbortSignal | null +): NextResponse | null { + return abortSignal?.isCallerAborted() ? clientCancelledJsonRpcResponse(id) : null +} + +function limitMessage(label: string, maxBytes: number): string { + return `${label} exceeds maximum size of ${Math.round(maxBytes / MB)}MB` +} + +async function readJsonRpcBody(request: NextRequest): Promise { + assertContentLengthWithinLimit(request.headers, MAX_MCP_SERVE_BODY_BYTES, 'MCP request body') + const buffer = await readStreamToBufferWithLimit(request.body, { + maxBytes: MAX_MCP_SERVE_BODY_BYTES, + label: 'MCP request body', + signal: request.signal, + }) + return JSON.parse(buffer.toString('utf-8')) +} + +interface ManagedAbortSignal { + signal: AbortSignal + cleanup: () => void + isCallerAborted: () => boolean + isTimedOut: () => boolean +} + +function createManagedAbortSignal( + parentSignal: AbortSignal, + timeoutMs: number +): ManagedAbortSignal { + const controller = new AbortController() + let callerAborted = false + let timedOut = false + + const timeoutId = setTimeout(() => { + timedOut = true + controller.abort(new Error(`MCP workflow execution timed out after ${timeoutMs}ms`)) + }, timeoutMs) + + const abortFromParent = () => { + callerAborted = true + controller.abort(parentSignal.reason ?? new Error('MCP client disconnected')) + } + + if (parentSignal.aborted) { + abortFromParent() + } else { + parentSignal.addEventListener('abort', abortFromParent, { once: true }) + } + + return { + signal: controller.signal, + cleanup: () => { + clearTimeout(timeoutId) + parentSignal.removeEventListener('abort', abortFromParent) + }, + isCallerAborted: () => callerAborted || parentSignal.aborted, + isTimedOut: () => timedOut, + } +} + +function serializeToolText(value: unknown): string { + const text = JSON.stringify(value, null, 2) ?? 'null' + assertKnownSizeWithinLimit( + Buffer.byteLength(text, 'utf-8'), + MAX_MCP_TOOL_RESULT_TEXT_BYTES, + 'MCP tool result text' + ) + return text +} + +function createJsonRpcResponseWithLimit( + id: RequestId, + result: unknown, + maxBytes: number, + label: string +): NextResponse { + const responseBody = createResponse(id, result) + const responseText = JSON.stringify(responseBody) + assertKnownSizeWithinLimit(Buffer.byteLength(responseText, 'utf-8'), maxBytes, label) + return new NextResponse(responseText, { + status: 200, + headers: { 'Content-Type': 'application/json' }, + }) +} + +function toToolInputSchema(schema: unknown): Partial { + if (!schema || typeof schema !== 'object' || Array.isArray(schema)) return {} + + const candidate = schema as Record + const properties = + candidate.properties && + typeof candidate.properties === 'object' && + !Array.isArray(candidate.properties) + ? (candidate.properties as Tool['inputSchema']['properties']) + : {} + const required = Array.isArray(candidate.required) + ? candidate.required.filter((entry): entry is string => typeof entry === 'string') + : undefined + + return { + properties, + ...(required && required.length > 0 && { required }), + } +} + +function isJsonObject(value: unknown): value is Record { + return value !== null && typeof value === 'object' && !Array.isArray(value) +} + +function parseJsonValue(text: string): { success: true; value: unknown } | { success: false } { + if (!text) return { success: true, value: {} } + try { + return { success: true, value: JSON.parse(text) } + } catch { + return { success: false } + } +} + +function hasResponseField(value: Record, property: string): boolean { + return Object.hasOwn(value, property) +} + +function getWorkflowErrorStatus(status: number): number { + return [400, 401, 403, 404, 408, 409, 413, 429, 499, 503].includes(status) ? status : 500 +} + +function getWorkflowErrorCode(status: number, executeResult: Record): ErrorCode { + if (status === 499) return ErrorCode.ConnectionClosed + if (status === 400) return ErrorCode.InvalidParams + if (status === 413 && executeResult.code !== 'workflow_response_too_large') { + return ErrorCode.InvalidRequest } + return ErrorCode.InternalError +} + +function getToolsListCursor(rpcParams: unknown): string | undefined { + if (!rpcParams || typeof rpcParams !== 'object' || !('cursor' in rpcParams)) return undefined + const cursor = (rpcParams as { cursor?: unknown }).cursor + return typeof cursor === 'string' && cursor.length > 0 ? cursor : undefined +} + +async function getDuplicateToolName(serverId: string): Promise { + const [duplicate] = await db + .select({ toolName: workflowMcpTool.toolName }) + .from(workflowMcpTool) + .where(and(eq(workflowMcpTool.serverId, serverId), isNull(workflowMcpTool.archivedAt))) + .groupBy(workflowMcpTool.toolName) + .having(sql`count(*) > 1`) + .limit(1) + + return duplicate?.toolName ?? null +} + +async function readWorkflowExecutionResult( + response: Response, + signal: AbortSignal +): Promise { + const text = await readResponseTextWithLimit(response, { + maxBytes: MAX_MCP_WORKFLOW_RESPONSE_BYTES, + label: 'MCP workflow execution response', + signal, + }) + const parsed = parseJsonValue(text) + if (parsed.success) return parsed.value + if (!response.ok) return { error: response.statusText || 'Workflow execution failed' } + throw new Error('Invalid workflow execution response') } async function getServer(serverId: string) { @@ -100,6 +305,64 @@ async function getServer(serverId: string) { return server } +type WorkflowMcpServeServer = NonNullable>> + +async function authorizeMcpServeRequest( + request: NextRequest, + server: WorkflowMcpServeServer, + options: { requireAuthForPublic?: boolean } = {} +): Promise<{ response?: NextResponse; executeAuthContext?: ExecuteAuthContext }> { + if (server.isPublic && !options.requireAuthForPublic) return {} + + const auth = await checkHybridAuth(request, { requireWorkflowId: false }) + if (!auth.success || !auth.userId) { + return { response: NextResponse.json({ error: 'Unauthorized' }, { status: 401 }) } + } + + if (server.isPublic) return {} + + if (auth.apiKeyType === 'workspace' && auth.workspaceId !== server.workspaceId) { + return { response: NextResponse.json({ error: 'Forbidden' }, { status: 403 }) } + } + + const workspacePermission = await getUserEntityPermissions( + auth.userId, + 'workspace', + server.workspaceId + ) + if (workspacePermission === null) { + return { response: NextResponse.json({ error: 'Forbidden' }, { status: 403 }) } + } + + return { + executeAuthContext: { + userId: auth.userId, + useAuthenticatedUserAsActor: + auth.authType === AuthType.API_KEY && auth.apiKeyType === 'personal', + }, + } +} + +function unsupportedSseTransportResponse(): NextResponse { + return NextResponse.json( + { + error: { + code: 'unsupported_transport', + message: 'SSE transport is not supported for workflow MCP servers', + supportedTransports: ['streamable-http'], + allowedMethods: ['GET', 'POST', 'DELETE'], + }, + }, + { + status: 405, + headers: { + Allow: 'GET, POST, DELETE', + 'X-MCP-Supported-Transport': 'streamable-http', + }, + } + ) +} + export const GET = withRouteHandler( async (request: NextRequest, { params }: { params: Promise }) => { try { @@ -109,24 +372,11 @@ export const GET = withRouteHandler( return NextResponse.json({ error: 'Server not found' }, { status: 404 }) } - if (!server.isPublic) { - const auth = await checkHybridAuth(request, { requireWorkflowId: false }) - if (!auth.success || !auth.userId) { - return NextResponse.json({ error: 'Unauthorized' }, { status: 401 }) - } + const authResult = await authorizeMcpServeRequest(request, server) + if (authResult.response) return authResult.response - if (auth.apiKeyType === 'workspace' && auth.workspaceId !== server.workspaceId) { - return NextResponse.json({ error: 'Forbidden' }, { status: 403 }) - } - - const workspacePermission = await getUserEntityPermissions( - auth.userId, - 'workspace', - server.workspaceId - ) - if (workspacePermission === null) { - return NextResponse.json({ error: 'Forbidden' }, { status: 403 }) - } + if (request.headers.get('accept')?.includes('text/event-stream')) { + return unsupportedSseTransportResponse() } return NextResponse.json({ @@ -152,36 +402,29 @@ export const POST = withRouteHandler( } let executeAuthContext: ExecuteAuthContext | null = null - if (!server.isPublic) { - const auth = await checkHybridAuth(request, { requireWorkflowId: false }) - if (!auth.success || !auth.userId) { - return NextResponse.json({ error: 'Unauthorized' }, { status: 401 }) - } - - if (auth.apiKeyType === 'workspace' && auth.workspaceId !== server.workspaceId) { - return NextResponse.json({ error: 'Forbidden' }, { status: 403 }) - } - - const workspacePermission = await getUserEntityPermissions( - auth.userId, - 'workspace', - server.workspaceId - ) - if (workspacePermission === null) { - return NextResponse.json({ error: 'Forbidden' }, { status: 403 }) - } - - executeAuthContext = { - authType: auth.authType, - userId: auth.userId, - apiKey: auth.authType === AuthType.API_KEY ? request.headers.get('X-API-Key') : null, - } - } + const authResult = await authorizeMcpServeRequest(request, server) + if (authResult.response) return authResult.response + executeAuthContext = authResult.executeAuthContext ?? null let body: unknown try { - body = await request.json() - } catch { + body = await readJsonRpcBody(request) + } catch (error) { + if (isPayloadSizeLimitError(error)) { + logger.warn('MCP request body exceeded size limit', { + maxBytes: error.maxBytes, + observedBytes: error.observedBytes, + }) + return NextResponse.json( + createError( + 0, + ErrorCode.InvalidRequest, + limitMessage('MCP request body', MAX_MCP_SERVE_BODY_BYTES) + ), + { status: 413 } + ) + } + if (request.signal.aborted) return clientCancelledJsonRpcResponse(0) return NextResponse.json(createError(0, ErrorCode.ParseError, 'Invalid JSON body'), { status: 400, }) @@ -238,7 +481,7 @@ export const POST = withRouteHandler( return NextResponse.json(createResponse(id, {})) case 'tools/list': - return handleToolsList(id, serverId) + return handleToolsList(id, serverId, rpcParams) case 'tools/call': { const paramsValidation = mcpToolCallParamsSchema.safeParse(rpcParams) @@ -257,7 +500,8 @@ export const POST = withRouteHandler( paramsValidation.data, executeAuthContext, server.isPublic ? server.createdBy : undefined, - request.headers.get(SIM_VIA_HEADER) + request.headers.get(SIM_VIA_HEADER), + request.signal ) } @@ -278,20 +522,98 @@ export const POST = withRouteHandler( } ) -async function handleToolsList(id: RequestId, serverId: string): Promise { +async function handleToolsList( + id: RequestId, + serverId: string, + rpcParams: unknown +): Promise { try { + const duplicateToolName = await getDuplicateToolName(serverId) + if (duplicateToolName) { + return NextResponse.json( + createError(id, ErrorCode.InvalidRequest, 'MCP server has duplicate tool names', { + code: 'duplicate_tool_name', + toolName: duplicateToolName, + recovery: 'Rename or remove duplicate workflow MCP tools before listing this server', + }), + { status: 409 } + ) + } + + const cursor = getToolsListCursor(rpcParams) + const pageCondition = cursor ? gt(workflowMcpTool.id, cursor) : undefined + const toolSizes = await db + .select({ + id: workflowMcpTool.id, + toolNameBytes: sql`octet_length(${workflowMcpTool.toolName})`, + toolDescriptionBytes: sql`coalesce(octet_length(${workflowMcpTool.toolDescription}), 0)`, + parameterSchemaBytes: sql`octet_length(${workflowMcpTool.parameterSchema}::text)`, + }) + .from(workflowMcpTool) + .where( + and( + eq(workflowMcpTool.serverId, serverId), + isNull(workflowMcpTool.archivedAt), + pageCondition + ) + ) + .orderBy(asc(workflowMcpTool.id)) + .limit(MAX_MCP_TOOLS_LIST_COUNT + 1) + + const pageSizes = toolSizes.slice(0, MAX_MCP_TOOLS_LIST_COUNT) + + let estimatedSchemaBytes = 0 + let estimatedMetadataBytes = 0 + for (const toolSize of pageSizes) { + estimatedSchemaBytes += Number(toolSize.parameterSchemaBytes) || 0 + estimatedMetadataBytes += + (Number(toolSize.toolNameBytes) || 0) + + (Number(toolSize.toolDescriptionBytes) || 0) + + (Number(toolSize.parameterSchemaBytes) || 0) + assertKnownSizeWithinLimit( + estimatedSchemaBytes, + MAX_MCP_TOOLS_LIST_SCHEMA_BYTES, + 'MCP tools/list schemas' + ) + assertKnownSizeWithinLimit( + estimatedMetadataBytes, + MAX_MCP_TOOLS_LIST_RESPONSE_BYTES, + 'MCP tools/list stored metadata' + ) + } + const tools = await db .select({ + id: workflowMcpTool.id, toolName: workflowMcpTool.toolName, toolDescription: workflowMcpTool.toolDescription, parameterSchema: workflowMcpTool.parameterSchema, }) .from(workflowMcpTool) - .where(and(eq(workflowMcpTool.serverId, serverId), isNull(workflowMcpTool.archivedAt))) + .where( + and( + eq(workflowMcpTool.serverId, serverId), + isNull(workflowMcpTool.archivedAt), + pageCondition + ) + ) + .orderBy(asc(workflowMcpTool.id)) + .limit(MAX_MCP_TOOLS_LIST_COUNT + 1) + const hasNextPage = tools.length > MAX_MCP_TOOLS_LIST_COUNT + const pageTools = tools.slice(0, MAX_MCP_TOOLS_LIST_COUNT) + const nextCursor = hasNextPage ? pageTools.at(-1)?.id : undefined + let schemaBytes = 0 const result: ListToolsResult = { - tools: tools.map((tool) => { - const schema = tool.parameterSchema as Partial | null + tools: pageTools.map((tool) => { + const schema = toToolInputSchema(tool.parameterSchema) + const schemaByteLength = Buffer.byteLength(JSON.stringify(schema ?? {}), 'utf-8') + schemaBytes += schemaByteLength + assertKnownSizeWithinLimit( + schemaBytes, + MAX_MCP_TOOLS_LIST_SCHEMA_BYTES, + 'MCP tools/list schemas' + ) return { name: tool.toolName, description: tool.toolDescription || `Execute workflow: ${tool.toolName}`, @@ -302,10 +624,32 @@ async function handleToolsList(id: RequestId, serverId: string): Promise } | undefined, executeAuthContext?: ExecuteAuthContext | null, publicServerOwnerId?: string, - simViaHeader?: string | null + simViaHeader?: string | null, + requestSignal?: AbortSignal ): Promise { + let abortSignal: ManagedAbortSignal | null = null try { if (!params?.name) { return NextResponse.json(createError(id, ErrorCode.InvalidParams, 'Tool name required'), { status: 400, }) } + abortSignal = createManagedAbortSignal( + requestSignal ?? new AbortController().signal, + getMaxExecutionTimeout() + ) + const abortedBeforeToolLookup = callerAbortedJsonRpcResponse(id, abortSignal) + if (abortedBeforeToolLookup) return abortedBeforeToolLookup - const [tool] = await db + const matchingTools = await db .select({ toolName: workflowMcpTool.toolName, workflowId: workflowMcpTool.workflowId, @@ -341,7 +693,21 @@ async function handleToolsCall( isNull(workflowMcpTool.archivedAt) ) ) - .limit(1) + .orderBy(asc(workflowMcpTool.id)) + .limit(2) + const abortedAfterToolLookup = callerAbortedJsonRpcResponse(id, abortSignal) + if (abortedAfterToolLookup) return abortedAfterToolLookup + if (matchingTools.length > 1) { + return NextResponse.json( + createError(id, ErrorCode.InvalidRequest, `Duplicate tool name: ${params.name}`, { + code: 'duplicate_tool_name', + toolName: params.name, + recovery: 'Rename or remove duplicate workflow MCP tools before calling this tool', + }), + { status: 409 } + ) + } + const [tool] = matchingTools if (!tool) { return NextResponse.json( createError(id, ErrorCode.InvalidParams, `Tool not found: ${params.name}`), @@ -356,6 +722,8 @@ async function handleToolsCall( .from(workflow) .where(and(eq(workflow.id, tool.workflowId), isNull(workflow.archivedAt))) .limit(1) + const abortedAfterWorkflowLookup = callerAbortedJsonRpcResponse(id, abortSignal) + if (abortedAfterWorkflowLookup) return abortedAfterWorkflowLookup if (!wf?.isDeployed) { return NextResponse.json( @@ -367,17 +735,22 @@ async function handleToolsCall( } const executeUrl = `${getInternalApiBaseUrl()}/api/workflows/${tool.workflowId}/execute` - const headers: Record = { 'Content-Type': 'application/json' } + const headers: Record = { + 'Content-Type': 'application/json', + [MCP_TOOL_BRIDGE_HEADER]: 'true', + } + + const abortedBeforeExecute = callerAbortedJsonRpcResponse(id, abortSignal) + if (abortedBeforeExecute) return abortedBeforeExecute if (publicServerOwnerId) { const internalToken = await generateInternalToken(publicServerOwnerId) headers.Authorization = `Bearer ${internalToken}` } else if (executeAuthContext) { - if (executeAuthContext.authType === AuthType.API_KEY && executeAuthContext.apiKey) { - headers['X-API-Key'] = executeAuthContext.apiKey - } else { - const internalToken = await generateInternalToken(executeAuthContext.userId) - headers.Authorization = `Bearer ${internalToken}` + const internalToken = await generateInternalToken(executeAuthContext.userId) + headers.Authorization = `Bearer ${internalToken}` + if (executeAuthContext.useAuthenticatedUserAsActor) { + headers[MCP_TOOL_BRIDGE_ACTOR_HEADER] = 'authenticated-user' } } @@ -387,39 +760,111 @@ async function handleToolsCall( logger.info(`Executing workflow ${tool.workflowId} via MCP tool ${params.name}`) + const workflowRequestBody = JSON.stringify({ + input: params.arguments || {}, + triggerType: 'mcp', + includeFileBase64: false, + }) + assertKnownSizeWithinLimit( + Buffer.byteLength(workflowRequestBody, 'utf-8'), + MAX_MCP_WORKFLOW_REQUEST_BYTES, + 'MCP workflow execution request body' + ) const response = await fetch(executeUrl, { method: 'POST', headers, - body: JSON.stringify({ input: params.arguments || {}, triggerType: 'mcp' }), - signal: AbortSignal.timeout(getMaxExecutionTimeout()), + body: workflowRequestBody, + signal: abortSignal.signal, }) - const executeResult = await response.json() + const executeResult = await readWorkflowExecutionResult(response, abortSignal.signal) + const executeResultObject = isJsonObject(executeResult) ? executeResult : null if (!response.ok) { + const errorMessage = + typeof executeResultObject?.error === 'string' + ? executeResultObject.error + : 'Workflow execution failed' + const status = getWorkflowErrorStatus(response.status) + const responseHeaders: Record = {} + const retryAfter = response.headers.get('retry-after') + if (retryAfter) responseHeaders['Retry-After'] = retryAfter return NextResponse.json( createError( id, - ErrorCode.InternalError, - executeResult.error || 'Workflow execution failed' + getWorkflowErrorCode(response.status, executeResultObject ?? {}), + errorMessage, + { + httpStatus: response.status, + retryable: [408, 429, 503].includes(response.status), + code: + typeof executeResultObject?.code === 'string' ? executeResultObject.code : undefined, + } ), - { status: 500 } + { status, headers: responseHeaders } ) } + const toolOutput = + executeResultObject?.success === false + ? executeResult + : executeResultObject && hasResponseField(executeResultObject, 'output') + ? executeResultObject.output + : executeResult const result: CallToolResult = { - content: [ - { type: 'text', text: JSON.stringify(executeResult.output || executeResult, null, 2) }, - ], - isError: executeResult.success === false, + content: [{ type: 'text', text: serializeToolText(toolOutput) }], + isError: executeResultObject?.success === false, } - return NextResponse.json(createResponse(id, result)) + return createJsonRpcResponseWithLimit( + id, + result, + MAX_MCP_WORKFLOW_RESPONSE_BYTES, + 'MCP tool call response' + ) } catch (error) { + if (abortSignal?.isTimedOut()) { + return NextResponse.json( + createError(id, ErrorCode.InternalError, 'Tool execution timed out', { + code: 'timeout', + retryable: true, + }), + { + status: 408, + } + ) + } + const abortedAfterExecute = callerAbortedJsonRpcResponse(id, abortSignal) + if (abortedAfterExecute) return abortedAfterExecute + if (isPayloadSizeLimitError(error)) { + logger.warn('MCP tool call exceeded size limit', { + maxBytes: error.maxBytes, + observedBytes: error.observedBytes, + label: error.label, + }) + return NextResponse.json( + createError( + id, + error.label === 'MCP workflow execution request body' + ? ErrorCode.InvalidParams + : ErrorCode.InternalError, + limitMessage(error.label, error.maxBytes), + { + code: 'payload_too_large', + maxBytes: error.maxBytes, + observedBytes: error.observedBytes, + retryable: false, + } + ), + { status: 413 } + ) + } logger.error('Error calling tool:', error) return NextResponse.json(createError(id, ErrorCode.InternalError, 'Tool execution failed'), { status: 500, }) + } finally { + abortSignal?.cleanup() } } @@ -432,21 +877,10 @@ export const DELETE = withRouteHandler( return NextResponse.json({ error: 'Server not found' }, { status: 404 }) } - const auth = await checkHybridAuth(request, { requireWorkflowId: false }) - if (!auth.success || !auth.userId) { - return NextResponse.json({ error: 'Unauthorized' }, { status: 401 }) - } - - if (!server.isPublic) { - const workspacePermission = await getUserEntityPermissions( - auth.userId, - 'workspace', - server.workspaceId - ) - if (workspacePermission === null) { - return NextResponse.json({ error: 'Forbidden' }, { status: 403 }) - } - } + const authResult = await authorizeMcpServeRequest(request, server, { + requireAuthForPublic: true, + }) + if (authResult.response) return authResult.response logger.info(`MCP session terminated for server ${serverId}`) return new NextResponse(null, { status: 204 }) diff --git a/apps/sim/app/api/mcp/servers/[id]/route.ts b/apps/sim/app/api/mcp/servers/[id]/route.ts index 4242fdef119..b008e8ae971 100644 --- a/apps/sim/app/api/mcp/servers/[id]/route.ts +++ b/apps/sim/app/api/mcp/servers/[id]/route.ts @@ -3,7 +3,11 @@ import { toError } from '@sim/utils/errors' import type { NextRequest } from 'next/server' import { updateMcpServerBodySchema } from '@/lib/api/contracts/mcp' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' -import { getParsedBody, withMcpAuth } from '@/lib/mcp/middleware' +import { + mcpBodyReadErrorResponse, + readMcpJsonBodyWithLimit, + withMcpAuth, +} from '@/lib/mcp/middleware' import { performUpdateMcpServer } from '@/lib/mcp/orchestration' import { createMcpErrorResponse, @@ -28,7 +32,7 @@ export const PATCH = withRouteHandler( try { const { id: serverId } = await params - const rawBody = getParsedBody(request) ?? (await request.json()) + const rawBody = await readMcpJsonBodyWithLimit(request) const parsedBody = updateMcpServerBodySchema.safeParse(rawBody) if (!parsedBody.success) { @@ -82,6 +86,8 @@ export const PATCH = withRouteHandler( server: { ...rest, hasOauthClientSecret: !!_secret }, }) } catch (error) { + const bodyErrorResponse = mcpBodyReadErrorResponse(error, request) + if (bodyErrorResponse) return bodyErrorResponse logger.error(`[${requestId}] Error updating MCP server:`, error) return createMcpErrorResponse(toError(error), 'Failed to update MCP server', 500) } diff --git a/apps/sim/app/api/mcp/servers/route.ts b/apps/sim/app/api/mcp/servers/route.ts index 1d02caeef74..cfcc099eb85 100644 --- a/apps/sim/app/api/mcp/servers/route.ts +++ b/apps/sim/app/api/mcp/servers/route.ts @@ -7,7 +7,11 @@ import type { NextRequest } from 'next/server' import { createMcpServerBodySchema, deleteMcpServerByQuerySchema } from '@/lib/api/contracts/mcp' import { validationErrorResponse } from '@/lib/api/server' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' -import { getParsedBody, withMcpAuth } from '@/lib/mcp/middleware' +import { + mcpBodyReadErrorResponse, + readMcpJsonBodyWithLimit, + withMcpAuth, +} from '@/lib/mcp/middleware' import { performCreateMcpServer, performDeleteMcpServer } from '@/lib/mcp/orchestration' import { createMcpErrorResponse, @@ -55,7 +59,7 @@ export const POST = withRouteHandler( withMcpAuth('write')( async (request: NextRequest, { userId, userName, userEmail, workspaceId, requestId }) => { try { - const rawBody = getParsedBody(request) ?? (await request.json()) + const rawBody = await readMcpJsonBodyWithLimit(request) const parsedBody = createMcpServerBodySchema.safeParse(rawBody) if (!parsedBody.success) { @@ -120,6 +124,8 @@ export const POST = withRouteHandler( result.updated ? 200 : 201 ) } catch (error) { + const bodyErrorResponse = mcpBodyReadErrorResponse(error, request) + if (bodyErrorResponse) return bodyErrorResponse logger.error(`[${requestId}] Error registering MCP server:`, error) return createMcpErrorResponse(toError(error), 'Failed to register MCP server', 500) } diff --git a/apps/sim/app/api/mcp/servers/test-connection/route.ts b/apps/sim/app/api/mcp/servers/test-connection/route.ts index c017de7a34c..b570f8f8ff0 100644 --- a/apps/sim/app/api/mcp/servers/test-connection/route.ts +++ b/apps/sim/app/api/mcp/servers/test-connection/route.ts @@ -11,7 +11,11 @@ import { validateMcpDomain, validateMcpServerSsrf, } from '@/lib/mcp/domain-check' -import { getParsedBody, withMcpAuth } from '@/lib/mcp/middleware' +import { + mcpBodyReadErrorResponse, + readMcpJsonBodyWithLimit, + withMcpAuth, +} from '@/lib/mcp/middleware' import { detectMcpAuthType } from '@/lib/mcp/oauth' import { resolveMcpConfigEnvVars } from '@/lib/mcp/resolve-config' import type { McpAuthType, McpTransport } from '@/lib/mcp/types' @@ -64,7 +68,7 @@ function sanitizeConnectionError(error: unknown): string { export const POST = withRouteHandler( withMcpAuth('write')(async (request: NextRequest, { userId, workspaceId, requestId }) => { try { - const rawBody = getParsedBody(request) ?? (await request.json()) + const rawBody = await readMcpJsonBodyWithLimit(request) const parsedBody = mcpServerTestBodySchema.safeParse(rawBody) if (!parsedBody.success) { @@ -235,6 +239,8 @@ export const POST = withRouteHandler( return createMcpSuccessResponse(result, result.success ? 200 : 400) } catch (error) { + const bodyErrorResponse = mcpBodyReadErrorResponse(error, request) + if (bodyErrorResponse) return bodyErrorResponse logger.error(`[${requestId}] Error testing MCP server connection:`, error) return createMcpErrorResponse(toError(error), 'Failed to test server connection', 500) } diff --git a/apps/sim/app/api/mcp/tools/discover/route.ts b/apps/sim/app/api/mcp/tools/discover/route.ts index 612788b4875..84acdad0b3d 100644 --- a/apps/sim/app/api/mcp/tools/discover/route.ts +++ b/apps/sim/app/api/mcp/tools/discover/route.ts @@ -1,18 +1,55 @@ import { UnauthorizedError } from '@modelcontextprotocol/sdk/client/auth.js' import { createLogger } from '@sim/logger' +import { getErrorMessage } from '@sim/utils/errors' import type { NextRequest } from 'next/server' import { mcpToolDiscoveryQuerySchema, refreshMcpToolsBodySchema } from '@/lib/api/contracts/mcp' import { validationErrorResponse } from '@/lib/api/server' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' -import { getParsedBody, withMcpAuth } from '@/lib/mcp/middleware' +import { + mcpBodyReadErrorResponse, + readMcpJsonBodyWithLimit, + withMcpAuth, +} from '@/lib/mcp/middleware' import { mcpService } from '@/lib/mcp/service' import { McpOauthAuthorizationRequiredError, type McpToolDiscoveryResponse } from '@/lib/mcp/types' import { categorizeError, createMcpErrorResponse, createMcpSuccessResponse } from '@/lib/mcp/utils' const logger = createLogger('McpToolDiscoveryAPI') +const MCP_REFRESH_DISCOVERY_CONCURRENCY = 5 export const dynamic = 'force-dynamic' +async function settleWithConcurrency( + items: T[], + concurrency: number, + task: (item: T) => Promise +): Promise>> { + const results: Array | undefined> = new Array(items.length) + let nextIndex = 0 + + const workers = Array.from({ length: Math.min(concurrency, items.length) }, async () => { + while (nextIndex < items.length) { + const index = nextIndex + nextIndex += 1 + try { + results[index] = { status: 'fulfilled', value: await task(items[index]) } + } catch (reason) { + results[index] = { status: 'rejected', reason } + } + } + }) + + await Promise.all(workers) + + return results.map( + (result) => + result ?? { + status: 'rejected', + reason: new Error('MCP refresh discovery task did not run'), + } + ) +} + export const GET = withRouteHandler( withMcpAuth('read')(async (request: NextRequest, { userId, workspaceId, requestId }) => { try { @@ -63,7 +100,7 @@ export const GET = withRouteHandler( export const POST = withRouteHandler( withMcpAuth('read')(async (request: NextRequest, { userId, workspaceId, requestId }) => { try { - const rawBody = getParsedBody(request) ?? (await request.json()) + const rawBody = await readMcpJsonBodyWithLimit(request) const parsedBody = refreshMcpToolsBodySchema.safeParse(rawBody) if (!parsedBody.success) { @@ -74,11 +111,13 @@ export const POST = withRouteHandler( logger.info(`[${requestId}] Refreshing tools for ${serverIds.length} servers`) - const results = await Promise.allSettled( - serverIds.map(async (serverId: string) => { + const results = await settleWithConcurrency( + serverIds, + MCP_REFRESH_DISCOVERY_CONCURRENCY, + async (serverId: string) => { const tools = await mcpService.discoverServerTools(userId, serverId, workspaceId, true) return { serverId, toolCount: tools.length } - }) + } ) const successes: Array<{ serverId: string; toolCount: number }> = [] @@ -91,7 +130,7 @@ export const POST = withRouteHandler( } else { failures.push({ serverId, - error: result.reason instanceof Error ? result.reason.message : 'Unknown error', + error: getErrorMessage(result.reason, 'Unknown error'), }) } }) @@ -107,6 +146,8 @@ export const POST = withRouteHandler( }, }) } catch (error) { + const bodyErrorResponse = mcpBodyReadErrorResponse(error, request) + if (bodyErrorResponse) return bodyErrorResponse if ( error instanceof McpOauthAuthorizationRequiredError || error instanceof UnauthorizedError diff --git a/apps/sim/app/api/mcp/tools/execute/route.ts b/apps/sim/app/api/mcp/tools/execute/route.ts index 8599a5fcadf..bb4e3650fc7 100644 --- a/apps/sim/app/api/mcp/tools/execute/route.ts +++ b/apps/sim/app/api/mcp/tools/execute/route.ts @@ -8,7 +8,11 @@ import { getExecutionTimeout } from '@/lib/core/execution-limits' import type { SubscriptionPlan } from '@/lib/core/rate-limiter/types' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' import { SIM_VIA_HEADER } from '@/lib/execution/call-chain' -import { getParsedBody, withMcpAuth } from '@/lib/mcp/middleware' +import { + mcpBodyReadErrorResponse, + readMcpJsonBodyWithLimit, + withMcpAuth, +} from '@/lib/mcp/middleware' import { McpOauthRedirectRequired } from '@/lib/mcp/oauth' import { mcpService } from '@/lib/mcp/service' import { @@ -53,7 +57,7 @@ export const POST = withRouteHandler( withMcpAuth('read')(async (request: NextRequest, { userId, workspaceId, requestId }) => { let serverId: string | undefined try { - const rawBody = getParsedBody(request) ?? (await request.json()) + const rawBody = await readMcpJsonBodyWithLimit(request) const parsedBody = mcpToolExecutionBodySchema.safeParse(rawBody) if (!parsedBody.success) { @@ -235,6 +239,8 @@ export const POST = withRouteHandler( return createMcpSuccessResponse(transformedResult) } catch (error) { + const bodyErrorResponse = mcpBodyReadErrorResponse(error, request) + if (bodyErrorResponse) return bodyErrorResponse if ( error instanceof McpOauthAuthorizationRequiredError || error instanceof McpOauthRedirectRequired || diff --git a/apps/sim/app/api/mcp/workflow-servers/[id]/route.ts b/apps/sim/app/api/mcp/workflow-servers/[id]/route.ts index 803e9879e70..718b65aa0bb 100644 --- a/apps/sim/app/api/mcp/workflow-servers/[id]/route.ts +++ b/apps/sim/app/api/mcp/workflow-servers/[id]/route.ts @@ -9,7 +9,11 @@ import { workflowMcpServerParamsSchema, } from '@/lib/api/contracts/workflow-mcp-servers' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' -import { getParsedBody, withMcpAuth } from '@/lib/mcp/middleware' +import { + mcpBodyReadErrorResponse, + readMcpJsonBodyWithLimit, + withMcpAuth, +} from '@/lib/mcp/middleware' import { performDeleteWorkflowMcpServer, performUpdateWorkflowMcpServer, @@ -94,7 +98,7 @@ export const PATCH = withRouteHandler( ) => { try { const { id: serverId } = workflowMcpServerParamsSchema.parse(await params) - const rawBody = getParsedBody(request) ?? (await request.json()) + const rawBody = await readMcpJsonBodyWithLimit(request) const parsedBody = updateWorkflowMcpServerBodySchema.safeParse(rawBody) if (!parsedBody.success) { @@ -130,6 +134,8 @@ export const PATCH = withRouteHandler( return createMcpSuccessResponse({ server: updatedServer }) } catch (error) { + const bodyErrorResponse = mcpBodyReadErrorResponse(error, request) + if (bodyErrorResponse) return bodyErrorResponse logger.error(`[${requestId}] Error updating workflow MCP server:`, error) return createMcpErrorResponse(toError(error), 'Failed to update workflow MCP server', 500) } diff --git a/apps/sim/app/api/mcp/workflow-servers/[id]/tools/[toolId]/route.ts b/apps/sim/app/api/mcp/workflow-servers/[id]/tools/[toolId]/route.ts index 95e54946ded..3d6c65dfe32 100644 --- a/apps/sim/app/api/mcp/workflow-servers/[id]/tools/[toolId]/route.ts +++ b/apps/sim/app/api/mcp/workflow-servers/[id]/tools/[toolId]/route.ts @@ -9,9 +9,17 @@ import { workflowMcpToolParamsSchema, } from '@/lib/api/contracts/workflow-mcp-servers' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' -import { getParsedBody, withMcpAuth } from '@/lib/mcp/middleware' +import { + mcpBodyReadErrorResponse, + readMcpJsonBodyWithLimit, + withMcpAuth, +} from '@/lib/mcp/middleware' import { performDeleteWorkflowMcpTool, performUpdateWorkflowMcpTool } from '@/lib/mcp/orchestration' -import { createMcpErrorResponse, createMcpSuccessResponse } from '@/lib/mcp/utils' +import { + createMcpErrorResponse, + createMcpSuccessResponse, + mcpOrchestrationStatus, +} from '@/lib/mcp/utils' const logger = createLogger('WorkflowMcpToolAPI') @@ -86,7 +94,7 @@ export const PATCH = withRouteHandler( ) => { try { const { id: serverId, toolId } = workflowMcpToolParamsSchema.parse(await params) - const rawBody = getParsedBody(request) ?? (await request.json()) + const rawBody = await readMcpJsonBodyWithLimit(request) const parsedBody = updateWorkflowMcpToolBodySchema.safeParse(rawBody) if (!parsedBody.success) { @@ -109,12 +117,10 @@ export const PATCH = withRouteHandler( parameterSchema: body.parameterSchema, }) if (!result.success || !result.tool) { - const status = - result.errorCode === 'not_found' ? 404 : result.errorCode === 'validation' ? 400 : 500 return createMcpErrorResponse( new Error(result.error || 'Failed to update tool'), result.error || 'Failed to update tool', - status + mcpOrchestrationStatus(result.errorCode) ) } @@ -124,6 +130,8 @@ export const PATCH = withRouteHandler( return createMcpSuccessResponse({ tool: updatedTool }) } catch (error) { + const bodyErrorResponse = mcpBodyReadErrorResponse(error, request) + if (bodyErrorResponse) return bodyErrorResponse logger.error(`[${requestId}] Error updating tool:`, error) return createMcpErrorResponse(toError(error), 'Failed to update tool', 500) } @@ -158,7 +166,7 @@ export const DELETE = withRouteHandler( return createMcpErrorResponse( new Error(result.error || 'Tool not found'), result.error || 'Tool not found', - result.errorCode === 'not_found' ? 404 : 500 + mcpOrchestrationStatus(result.errorCode) ) } const deletedTool = result.tool diff --git a/apps/sim/app/api/mcp/workflow-servers/[id]/tools/route.ts b/apps/sim/app/api/mcp/workflow-servers/[id]/tools/route.ts index 4d87728cc2e..e3fa659cd14 100644 --- a/apps/sim/app/api/mcp/workflow-servers/[id]/tools/route.ts +++ b/apps/sim/app/api/mcp/workflow-servers/[id]/tools/route.ts @@ -9,7 +9,11 @@ import { workflowMcpServerParamsSchema, } from '@/lib/api/contracts/workflow-mcp-servers' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' -import { getParsedBody, withMcpAuth } from '@/lib/mcp/middleware' +import { + mcpBodyReadErrorResponse, + readMcpJsonBodyWithLimit, + withMcpAuth, +} from '@/lib/mcp/middleware' import { performCreateWorkflowMcpTool } from '@/lib/mcp/orchestration' import { createMcpErrorResponse, @@ -96,7 +100,7 @@ export const POST = withRouteHandler( ) => { try { const { id: serverId } = workflowMcpServerParamsSchema.parse(await params) - const rawBody = getParsedBody(request) ?? (await request.json()) + const rawBody = await readMcpJsonBodyWithLimit(request) const parsedBody = createWorkflowMcpToolBodySchema.safeParse(rawBody) if (!parsedBody.success) { @@ -136,6 +140,8 @@ export const POST = withRouteHandler( return createMcpSuccessResponse({ tool }, 201) } catch (error) { + const bodyErrorResponse = mcpBodyReadErrorResponse(error, request) + if (bodyErrorResponse) return bodyErrorResponse logger.error(`[${requestId}] Error adding tool:`, error) return createMcpErrorResponse(toError(error), 'Failed to add tool', 500) } diff --git a/apps/sim/app/api/mcp/workflow-servers/route.ts b/apps/sim/app/api/mcp/workflow-servers/route.ts index 4356592e4c8..10398e6eeb4 100644 --- a/apps/sim/app/api/mcp/workflow-servers/route.ts +++ b/apps/sim/app/api/mcp/workflow-servers/route.ts @@ -6,7 +6,11 @@ import { and, eq, inArray, isNull, sql } from 'drizzle-orm' import type { NextRequest } from 'next/server' import { createWorkflowMcpServerBodySchema } from '@/lib/api/contracts/workflow-mcp-servers' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' -import { getParsedBody, withMcpAuth } from '@/lib/mcp/middleware' +import { + mcpBodyReadErrorResponse, + readMcpJsonBodyWithLimit, + withMcpAuth, +} from '@/lib/mcp/middleware' import { performCreateWorkflowMcpServer } from '@/lib/mcp/orchestration' import { createMcpErrorResponse, @@ -96,7 +100,7 @@ export const POST = withRouteHandler( withMcpAuth('write')( async (request: NextRequest, { userId, userName, userEmail, workspaceId, requestId }) => { try { - const rawBody = getParsedBody(request) ?? (await request.json()) + const rawBody = await readMcpJsonBodyWithLimit(request) const parsedBody = createWorkflowMcpServerBodySchema.safeParse(rawBody) if (!parsedBody.success) { @@ -138,6 +142,8 @@ export const POST = withRouteHandler( return createMcpSuccessResponse({ server, addedTools }, 201) } catch (error) { + const bodyErrorResponse = mcpBodyReadErrorResponse(error, request) + if (bodyErrorResponse) return bodyErrorResponse logger.error(`[${requestId}] Error creating workflow MCP server:`, error) return createMcpErrorResponse(toError(error), 'Failed to create workflow MCP server', 500) } diff --git a/apps/sim/app/api/workflows/[id]/execute/route.async.test.ts b/apps/sim/app/api/workflows/[id]/execute/route.async.test.ts index 5b8debdc366..2c7b9aa7064 100644 --- a/apps/sim/app/api/workflows/[id]/execute/route.async.test.ts +++ b/apps/sim/app/api/workflows/[id]/execute/route.async.test.ts @@ -10,14 +10,21 @@ import { loggingSessionMock, requestUtilsMockFns, workflowAuthzMockFns, + workflowsPersistenceUtilsMock, + workflowsPersistenceUtilsMockFns, workflowsUtilsMock, workflowsUtilsMockFns, } from '@sim/testing' +import { NextRequest } from 'next/server' import { beforeEach, describe, expect, it, vi } from 'vitest' -const { mockEnqueue } = vi.hoisted(() => ({ - mockEnqueue: vi.fn().mockResolvedValue('job-123'), -})) +const { mockEnqueue, mockExecuteWorkflowCore, mockHandlePostExecutionPauseState } = vi.hoisted( + () => ({ + mockEnqueue: vi.fn().mockResolvedValue('job-123'), + mockExecuteWorkflowCore: vi.fn(), + mockHandlePostExecutionPauseState: vi.fn(), + }) +) const mockCheckHybridAuth = hybridAuthMockFns.mockCheckHybridAuth const mockPreprocessExecution = executionPreprocessingMockFns.mockPreprocessExecution @@ -29,6 +36,26 @@ vi.mock('@/lib/workflows/utils', () => workflowsUtilsMock) vi.mock('@/lib/execution/preprocessing', () => executionPreprocessingMock) +vi.mock('@/lib/workflows/persistence/utils', () => workflowsPersistenceUtilsMock) + +vi.mock('@/lib/workflows/executor/execution-core', () => ({ + executeWorkflowCore: mockExecuteWorkflowCore, +})) + +vi.mock('@/lib/workflows/executor/pause-persistence', () => ({ + handlePostExecutionPauseState: mockHandlePostExecutionPauseState, +})) + +vi.mock('@/lib/execution/payloads/store', () => ({ + storeLargeValue: vi.fn(async (_value, _json, size: number) => ({ + __simLargeValueRef: true, + version: 1, + id: 'lv_abcdefghijkl', + kind: 'string', + size, + })), +})) + vi.mock('@/lib/core/async-jobs', () => ({ getJobQueue: vi.fn().mockResolvedValue({ enqueue: mockEnqueue, @@ -65,6 +92,7 @@ vi.mock('@sim/utils/id', () => ({ ), })) +import { storeLargeValue } from '@/lib/execution/payloads/store' import { POST } from './route' describe('workflow execute async route', () => { @@ -99,6 +127,19 @@ describe('workflow execute async route', () => { workspaceId: 'workspace-1', }, }) + workflowsPersistenceUtilsMockFns.mockLoadDeployedWorkflowState.mockResolvedValue(null) + workflowsPersistenceUtilsMockFns.mockLoadWorkflowFromNormalizedTables.mockResolvedValue(null) + mockExecuteWorkflowCore.mockResolvedValue({ + success: true, + status: 'completed', + output: { ok: true }, + metadata: { + duration: 100, + startTime: '2026-01-01T00:00:00Z', + endTime: '2026-01-01T00:00:01Z', + }, + }) + mockHandlePostExecutionPauseState.mockResolvedValue(undefined) }) it('queues async execution with matching correlation metadata', async () => { @@ -112,7 +153,7 @@ describe('workflow execute async route', () => { ) const params = Promise.resolve({ id: 'workflow-1' }) - const response = await POST(req as any, { params }) + const response = await POST(req, { params }) const body = await response.json() expect(response.status).toBe(202) @@ -143,4 +184,243 @@ describe('workflow execute async route', () => { }) ) }) + + it('rejects oversized request bodies before authorization work', async () => { + const req = createMockRequest( + 'POST', + { input: { hello: 'world' } }, + { + 'Content-Type': 'application/json', + 'Content-Length': String(10 * 1024 * 1024 + 1), + } + ) + const params = Promise.resolve({ id: 'workflow-1' }) + + const response = await POST(req, { params }) + const body = await response.json() + + expect(response.status).toBe(413) + expect(body.error).toContain('Workflow execution request body') + expect(mockAuthorizeWorkflowByWorkspacePermission).not.toHaveBeenCalled() + }) + + it('authenticates before rejecting oversized request bodies', async () => { + mockCheckHybridAuth.mockResolvedValueOnce({ + success: false, + error: 'Unauthorized', + authType: 'api_key', + }) + const req = createMockRequest( + 'POST', + { input: { hello: 'world' } }, + { + 'Content-Type': 'application/json', + 'Content-Length': String(10 * 1024 * 1024 + 1), + 'X-API-Key': 'invalid', + } + ) + const params = Promise.resolve({ id: 'workflow-1' }) + + const response = await POST(req, { params }) + const body = await response.json() + + expect(response.status).toBe(401) + expect(body.error).toBe('Unauthorized') + expect(mockCheckHybridAuth).toHaveBeenCalled() + }) + + it('returns 499 when a non-SSE execution is cancelled by client disconnect', async () => { + const abortController = new AbortController() + mockExecuteWorkflowCore.mockImplementationOnce( + async ({ abortSignal }: { abortSignal: AbortSignal }) => { + abortController.abort() + expect(abortSignal.aborted).toBe(true) + return { + success: false, + status: 'cancelled', + output: { partial: true }, + metadata: { + duration: 100, + startTime: '2026-01-01T00:00:00Z', + endTime: '2026-01-01T00:00:01Z', + }, + } + } + ) + const req = new NextRequest('http://localhost:3000/api/workflows/workflow-1/execute', { + method: 'POST', + body: JSON.stringify({ input: { hello: 'world' } }), + signal: abortController.signal, + }) + const params = Promise.resolve({ id: 'workflow-1' }) + + const response = await POST(req, { params }) + const body = await response.json() + + expect(response.status).toBe(499) + expect(body.error).toBe('Client cancelled request') + }) + + it('rejects large MCP bridge outputs instead of returning large-value refs', async () => { + mockCheckHybridAuth.mockResolvedValueOnce({ + success: true, + userId: 'internal-user-1', + authType: 'internal_jwt', + }) + mockExecuteWorkflowCore.mockResolvedValueOnce({ + success: true, + status: 'completed', + output: 'x'.repeat(10 * 1024 * 1024 + 1), + metadata: { + duration: 100, + startTime: '2026-01-01T00:00:00Z', + endTime: '2026-01-01T00:00:01Z', + }, + }) + const req = createMockRequest( + 'POST', + { input: { hello: 'world' } }, + { + 'Content-Type': 'application/json', + 'X-Sim-MCP-Tool-Call': 'true', + } + ) + const params = Promise.resolve({ id: 'workflow-1' }) + + const response = await POST(req, { params }) + const body = await response.json() + + expect(response.status).toBe(413) + expect(body.error).toContain('Workflow execution response') + expect(storeLargeValue).not.toHaveBeenCalled() + }) + + it('does not trust client-spoofed MCP bridge headers on API key executions', async () => { + mockCheckHybridAuth.mockResolvedValueOnce({ + success: true, + userId: 'api-user-1', + authType: 'api_key', + apiKeyType: 'personal', + }) + workflowsUtilsMockFns.mockWorkflowHasResponseBlock.mockReturnValueOnce(true) + workflowsUtilsMockFns.mockCreateHttpResponseFromBlock.mockResolvedValueOnce( + Response.json({ response: 'plain text body' }) + ) + mockExecuteWorkflowCore.mockResolvedValueOnce({ + success: true, + status: 'completed', + output: { response: 'plain text body' }, + metadata: { + duration: 100, + startTime: '2026-01-01T00:00:00Z', + endTime: '2026-01-01T00:00:01Z', + }, + }) + const req = createMockRequest( + 'POST', + { input: { hello: 'world' } }, + { + 'Content-Type': 'application/json', + 'X-API-Key': 'valid', + 'X-Sim-MCP-Tool-Call': 'true', + } + ) + const params = Promise.resolve({ id: 'workflow-1' }) + + const response = await POST(req, { params }) + const body = await response.json() + + expect(response.status).toBe(200) + expect(body).toEqual({ response: 'plain text body' }) + expect(workflowsUtilsMockFns.mockCreateHttpResponseFromBlock).toHaveBeenCalled() + }) + + it('keeps trusted internal MCP bridge executions on the JSON envelope path', async () => { + mockCheckHybridAuth.mockResolvedValueOnce({ + success: true, + userId: 'internal-user-1', + authType: 'internal_jwt', + }) + workflowsUtilsMockFns.mockWorkflowHasResponseBlock.mockReturnValueOnce(true) + mockExecuteWorkflowCore.mockResolvedValueOnce({ + success: true, + status: 'completed', + output: { response: 'plain text body' }, + metadata: { + duration: 100, + startTime: '2026-01-01T00:00:00Z', + endTime: '2026-01-01T00:00:01Z', + }, + }) + const req = createMockRequest( + 'POST', + { input: { hello: 'world' } }, + { + 'Content-Type': 'application/json', + 'X-Sim-MCP-Tool-Call': 'true', + } + ) + const params = Promise.resolve({ id: 'workflow-1' }) + + const response = await POST(req, { params }) + const body = await response.json() + + expect(response.status).toBe(200) + expect(body).toMatchObject({ + success: true, + output: { response: 'plain text body' }, + }) + expect(workflowsUtilsMockFns.mockCreateHttpResponseFromBlock).not.toHaveBeenCalled() + expect(mockExecuteWorkflowCore).toHaveBeenCalledWith( + expect.objectContaining({ + snapshot: expect.objectContaining({ + input: { hello: 'world' }, + }), + }) + ) + }) + + it('preserves authenticated-user actor semantics for trusted MCP bridge calls', async () => { + mockCheckHybridAuth.mockResolvedValueOnce({ + success: true, + userId: 'api-user-1', + authType: 'internal_jwt', + }) + mockExecuteWorkflowCore.mockResolvedValueOnce({ + success: true, + status: 'completed', + output: { ok: true }, + metadata: { + duration: 100, + startTime: '2026-01-01T00:00:00Z', + endTime: '2026-01-01T00:00:01Z', + }, + }) + const req = createMockRequest( + 'POST', + { input: { hello: 'world' } }, + { + 'Content-Type': 'application/json', + 'X-Sim-MCP-Tool-Call': 'true', + 'X-Sim-MCP-Tool-Actor': 'authenticated-user', + } + ) + const params = Promise.resolve({ id: 'workflow-1' }) + + const response = await POST(req, { params }) + + expect(response.status).toBe(200) + expect(mockPreprocessExecution).toHaveBeenCalledWith( + expect.objectContaining({ + userId: 'api-user-1', + useAuthenticatedUserAsActor: true, + }) + ) + const executionCall = mockExecuteWorkflowCore.mock.calls[0][0] + const snapshot = + typeof executionCall.snapshot === 'string' + ? JSON.parse(executionCall.snapshot) + : executionCall.snapshot + expect(snapshot.metadata.enforceCredentialAccess).toBe(true) + }) }) diff --git a/apps/sim/app/api/workflows/[id]/execute/route.ts b/apps/sim/app/api/workflows/[id]/execute/route.ts index 3ec162fc299..7a110caa13a 100644 --- a/apps/sim/app/api/workflows/[id]/execute/route.ts +++ b/apps/sim/app/api/workflows/[id]/execute/route.ts @@ -17,6 +17,12 @@ import { } from '@/lib/core/execution-limits' import { generateRequestId } from '@/lib/core/utils/request' import { SSE_HEADERS } from '@/lib/core/utils/sse' +import { + assertContentLengthWithinLimit, + isPayloadSizeLimitError, + PayloadSizeLimitError, + readStreamToBufferWithLimit, +} from '@/lib/core/utils/stream-limits' import { getBaseUrl } from '@/lib/core/utils/urls' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' import { @@ -36,9 +42,15 @@ import { registerManualExecutionAborter, unregisterManualExecutionAborter, } from '@/lib/execution/manual-cancellation' +import { containsLargeValueRef } from '@/lib/execution/payloads/large-value-ref' import { compactBlockLogs, compactExecutionPayload } from '@/lib/execution/payloads/serializer' import { preprocessExecution } from '@/lib/execution/preprocessing' import { LoggingSession } from '@/lib/logs/execution/logging-session' +import { + MAX_MCP_WORKFLOW_RESPONSE_BYTES, + MCP_TOOL_BRIDGE_ACTOR_HEADER, + MCP_TOOL_BRIDGE_HEADER, +} from '@/lib/mcp/constants' import { cleanupExecutionBase64Cache, hydrateUserFilesWithBase64, @@ -72,6 +84,7 @@ import { Serializer } from '@/serializer' import { CORE_TRIGGER_TYPES, type CoreTriggerType } from '@/stores/logs/filters/types' const logger = createLogger('WorkflowExecuteAPI') +const MAX_WORKFLOW_EXECUTE_BODY_BYTES = 10 * 1024 * 1024 export const runtime = 'nodejs' export const dynamic = 'force-dynamic' @@ -85,11 +98,73 @@ async function compactRoutePayload( userId?: string preserveUserFileBase64?: boolean preserveRoot?: boolean + rejectLargeValues?: boolean + rejectLargeValueLabel?: string + thresholdBytes?: number } ): Promise { return compactExecutionPayload(value, { ...context, requireDurable: true }) } +async function compactWorkflowResponseOutput( + value: T, + context: { + workspaceId?: string + workflowId?: string + executionId?: string + userId?: string + rejectLargeInlineOutput: boolean + } +): Promise { + const compacted = await compactRoutePayload(value, { + workspaceId: context.workspaceId, + workflowId: context.workflowId, + executionId: context.executionId, + userId: context.userId, + preserveUserFileBase64: true, + preserveRoot: !context.rejectLargeInlineOutput, + rejectLargeValues: context.rejectLargeInlineOutput, + rejectLargeValueLabel: 'Workflow execution response', + thresholdBytes: context.rejectLargeInlineOutput ? MAX_MCP_WORKFLOW_RESPONSE_BYTES : undefined, + }) + + if (context.rejectLargeInlineOutput && containsLargeValueRef(compacted)) { + throw new PayloadSizeLimitError({ + label: 'Workflow execution response', + maxBytes: MAX_MCP_WORKFLOW_RESPONSE_BYTES, + observedBytes: MAX_MCP_WORKFLOW_RESPONSE_BYTES + 1, + }) + } + + return compacted +} + +async function readExecuteRequestBody(req: NextRequest): Promise { + assertContentLengthWithinLimit( + req.headers, + MAX_WORKFLOW_EXECUTE_BODY_BYTES, + 'Workflow execution request body' + ) + const buffer = await readStreamToBufferWithLimit(req.body, { + maxBytes: MAX_WORKFLOW_EXECUTE_BODY_BYTES, + label: 'Workflow execution request body', + signal: req.signal, + }) + if (buffer.byteLength === 0) return {} + return JSON.parse(buffer.toString('utf-8')) +} + +function clientCancelledResponse(): NextResponse { + return NextResponse.json({ success: false, error: 'Client cancelled request' }, { status: 499 }) +} + +function payloadTooLargeResponse(message = 'Workflow execution response exceeds maximum size') { + return NextResponse.json( + { success: false, error: message, code: 'workflow_response_too_large' }, + { status: 413 } + ) +} + function resolveOutputIds( selectedOutputs: string[] | undefined, blocks: Record @@ -143,6 +218,28 @@ function resolveOutputIds( }) } +function bindRequestAbort( + requestSignal: AbortSignal, + timeoutController: ReturnType +): { isRequestAborted: () => boolean; cleanup: () => void } { + let requestAborted = false + const abortFromRequest = () => { + requestAborted = true + timeoutController.abort() + } + + if (requestSignal.aborted) { + abortFromRequest() + } else { + requestSignal.addEventListener('abort', abortFromRequest, { once: true }) + } + + return { + isRequestAborted: () => requestAborted || requestSignal.aborted, + cleanup: () => requestSignal.removeEventListener('abort', abortFromRequest), + } +} + type AsyncExecutionParams = { requestId: string workflowId: string @@ -279,6 +376,10 @@ async function handleExecutePost( try { const auth = await checkHybridAuth(req, { requireWorkflowId: false }) + const isMcpBridgeRequest = + auth.authType === AuthType.INTERNAL_JWT && req.headers.get(MCP_TOOL_BRIDGE_HEADER) === 'true' + const useMcpBridgeAuthenticatedUserAsActor = + isMcpBridgeRequest && req.headers.get(MCP_TOOL_BRIDGE_ACTOR_HEADER) === 'authenticated-user' let userId: string let isPublicApiAccess = false @@ -321,14 +422,24 @@ async function handleExecutePost( } let body: any = {} - const text = await req.text() - if (text) { - try { - body = JSON.parse(text) - } catch (error) { - reqLogger.warn('Failed to parse request body', { error: toError(error).message }) - return NextResponse.json({ error: 'Invalid JSON in request body' }, { status: 400 }) + try { + body = await readExecuteRequestBody(req) + } catch (error) { + if (isPayloadSizeLimitError(error)) { + reqLogger.warn('Workflow execution request body exceeded size limit', { + maxBytes: error.maxBytes, + observedBytes: error.observedBytes, + }) + return NextResponse.json( + { error: 'Workflow execution request body exceeds maximum size' }, + { status: 413 } + ) } + if (req.signal.aborted) { + return clientCancelledResponse() + } + reqLogger.warn('Failed to parse request body', { error: toError(error).message }) + return NextResponse.json({ error: 'Invalid JSON in request body' }, { status: 400 }) } const validation = executeWorkflowBodySchema.safeParse(body) @@ -461,10 +572,11 @@ async function handleExecutePost( // For API key and internal JWT auth, the entire body is the input (except for our control fields) // For session auth, the input is explicitly provided in the input field - const input = - isPublicApiAccess || - auth.authType === AuthType.API_KEY || - auth.authType === AuthType.INTERNAL_JWT + const input = isMcpBridgeRequest + ? validatedInput + : isPublicApiAccess || + auth.authType === AuthType.API_KEY || + auth.authType === AuthType.INTERNAL_JWT ? (() => { const { selectedOutputs, @@ -500,6 +612,10 @@ async function handleExecutePost( useDraftState || workflowStateOverride || rawRunFromBlock ) + if (req.signal.aborted) { + return clientCancelledResponse() + } + if ( isAsyncMode && (body.useDraftState !== undefined || @@ -544,7 +660,9 @@ async function handleExecutePost( // Client-side sessions and personal API keys bill/permission-check the // authenticated user, not the workspace billed account. const useAuthenticatedUserAsActor = - isClientSession || (auth.authType === AuthType.API_KEY && auth.apiKeyType === 'personal') + isClientSession || + (auth.authType === AuthType.API_KEY && auth.apiKeyType === 'personal') || + useMcpBridgeAuthenticatedUserAsActor // Authorization fetches the full workflow record and checks workspace permissions. // Run it first so we can pass the record to preprocessing (eliminates a duplicate DB query). @@ -560,6 +678,11 @@ async function handleExecutePost( ) } + if (req.signal.aborted) { + return clientCancelledResponse() + } + + // Pass the pre-fetched workflow record to skip the redundant Step 1 DB query in preprocessing. const preprocessResult = await preprocessExecution({ workflowId, userId, @@ -580,6 +703,10 @@ async function handleExecutePost( ) } + if (req.signal.aborted) { + return clientCancelledResponse() + } + const actorUserId = preprocessResult.actorUserId! const workflow = preprocessResult.workflowRecord! @@ -623,10 +750,17 @@ async function handleExecutePost( let processedInput = input try { + if (req.signal.aborted) { + return clientCancelledResponse() + } const workflowData = shouldUseDraftState ? await loadWorkflowFromNormalizedTables(workflowId) : await loadDeployedWorkflowState(workflowId, workspaceId) + if (req.signal.aborted) { + return clientCancelledResponse() + } + if (workflowData) { const deployedVariables = !shouldUseDraftState && 'variables' in workflowData @@ -731,6 +865,15 @@ async function handleExecutePost( const timeoutController = createTimeoutAbortController( preprocessResult.executionTimeout?.sync ) + const requestAbort = bindRequestAbort(req.signal, timeoutController) + const shouldRejectLargeInlineOutput = isMcpBridgeRequest + const workflowResponseCompaction = { + workspaceId, + workflowId, + executionId, + userId: actorUserId, + rejectLargeInlineOutput: shouldRejectLargeInlineOutput, + } try { const snapshot = new ExecutionSnapshot( @@ -753,14 +896,16 @@ async function handleExecutePost( }) await handlePostExecutionPauseState({ result, workflowId, executionId, loggingSession }) - const compactResultOutput = await compactRoutePayload(result.output, { - workspaceId, - workflowId, - executionId, - userId: actorUserId, - preserveUserFileBase64: true, - preserveRoot: true, - }) + + if ( + result.status === 'cancelled' && + requestAbort.isRequestAborted() && + !timeoutController.isTimedOut() + ) { + reqLogger.info('Non-SSE execution cancelled by client disconnect') + await loggingSession.markAsFailed('Client cancelled request') + return clientCancelledResponse() + } if ( result.status === 'cancelled' && @@ -772,6 +917,10 @@ async function handleExecutePost( timeoutMs: timeoutController.timeoutMs, }) await loggingSession.markAsFailed(timeoutErrorMessage) + const compactResultOutput = await compactWorkflowResponseOutput( + result.output, + workflowResponseCompaction + ) return NextResponse.json( { @@ -793,31 +942,32 @@ async function handleExecutePost( const outputLargeValueKeys = result.metadata?.largeValueKeys ?? largeValueKeys const outputFileKeys = result.metadata?.fileKeys ?? fileKeys - const outputWithBase64 = includeFileBase64 - ? ((await hydrateUserFilesWithBase64(result.output, { - requestId, - workspaceId, - workflowId, - executionId, - largeValueExecutionIds, - largeValueKeys: outputLargeValueKeys, - fileKeys: outputFileKeys, - allowLargeValueWorkflowScope, - userId: actorUserId, - maxBytes: base64MaxBytes, - preserveLargeValueMetadata: true, - })) as NormalizedBlockOutput) - : result.output + const outputWithBase64 = + includeFileBase64 && !shouldRejectLargeInlineOutput + ? ((await hydrateUserFilesWithBase64(result.output, { + requestId, + workspaceId, + workflowId, + executionId, + largeValueExecutionIds, + largeValueKeys: outputLargeValueKeys, + fileKeys: outputFileKeys, + allowLargeValueWorkflowScope, + userId: actorUserId, + maxBytes: base64MaxBytes, + preserveLargeValueMetadata: true, + })) as NormalizedBlockOutput) + : result.output - if (auth.authType !== AuthType.INTERNAL_JWT && workflowHasResponseBlock(result)) { - const compactResponseBlockOutput = await compactRoutePayload(outputWithBase64, { - workspaceId, - workflowId, - executionId, - userId: actorUserId, - preserveUserFileBase64: true, - preserveRoot: true, - }) + if ( + !isMcpBridgeRequest && + auth.authType !== AuthType.INTERNAL_JWT && + workflowHasResponseBlock(result) + ) { + const compactResponseBlockOutput = await compactWorkflowResponseOutput( + outputWithBase64, + workflowResponseCompaction + ) return await createHttpResponseFromBlock( { ...result, output: compactResponseBlockOutput }, { @@ -833,14 +983,10 @@ async function handleExecutePost( ) } - const compactOutput = await compactRoutePayload(outputWithBase64, { - workspaceId, - workflowId, - executionId, - userId: actorUserId, - preserveUserFileBase64: true, - preserveRoot: true, - }) + const compactOutput = await compactWorkflowResponseOutput( + outputWithBase64, + workflowResponseCompaction + ) const filteredResult = { success: result.success, @@ -860,21 +1006,40 @@ async function handleExecutePost( } catch (error: unknown) { const errorMessage = getErrorMessage(error, 'Unknown error') + if (requestAbort.isRequestAborted() && !timeoutController.isTimedOut()) { + reqLogger.info('Non-SSE execution aborted after client disconnect') + return clientCancelledResponse() + } + if ( + isPayloadSizeLimitError(error) && + shouldRejectLargeInlineOutput && + error.label === 'Workflow execution response' + ) { + return payloadTooLargeResponse() + } + reqLogger.error(`Non-SSE execution failed: ${errorMessage}`) const executionResult = hasExecutionResult(error) ? error.executionResult : undefined const status = getExecutionErrorStatus(error) - const compactErrorOutput = executionResult?.output - ? await compactRoutePayload(executionResult.output, { - workspaceId, - workflowId, - executionId, - userId: actorUserId, - preserveUserFileBase64: true, - preserveRoot: true, - }) - : undefined - + let compactErrorOutput: NormalizedBlockOutput | undefined + if (executionResult && Object.hasOwn(executionResult, 'output')) { + try { + compactErrorOutput = await compactWorkflowResponseOutput( + executionResult.output, + workflowResponseCompaction + ) + } catch (compactError) { + if ( + isPayloadSizeLimitError(compactError) && + shouldRejectLargeInlineOutput && + compactError.label === 'Workflow execution response' + ) { + return payloadTooLargeResponse() + } + throw compactError + } + } return NextResponse.json( { success: false, @@ -891,6 +1056,7 @@ async function handleExecutePost( { status } ) } finally { + requestAbort.cleanup() timeoutController.cleanup() if (executionId) { void cleanupExecutionBase64Cache(executionId).catch((error) => { @@ -1187,10 +1353,18 @@ async function handleExecutePost( const reader = streamingExec.stream.getReader() const decoder = new TextDecoder() + const cancelReader = () => { + void reader.cancel(timeoutController.signal.reason).catch(() => {}) + } try { + if (timeoutController.signal.aborted || isStreamClosed) return + timeoutController.signal.addEventListener('abort', cancelReader, { once: true }) + while (true) { + if (timeoutController.signal.aborted || isStreamClosed) break const { done, value } = await reader.read() + if (timeoutController.signal.aborted || isStreamClosed) break if (done) break const chunk = decoder.decode(value, { stream: true }) @@ -1203,16 +1377,21 @@ async function handleExecutePost( }) } - await sendEvent({ - type: 'stream:done', - timestamp: new Date().toISOString(), - executionId, - workflowId, - data: { blockId }, - }) + if (!timeoutController.signal.aborted && !isStreamClosed) { + await sendEvent({ + type: 'stream:done', + timestamp: new Date().toISOString(), + executionId, + workflowId, + data: { blockId }, + }) + } } catch (error) { - reqLogger.error('Error streaming block content:', error) + if (!timeoutController.signal.aborted && !isStreamClosed) { + reqLogger.error('Error streaming block content:', error) + } } finally { + timeoutController.signal.removeEventListener('abort', cancelReader) try { await reader.cancel().catch(() => {}) } catch {} @@ -1517,6 +1696,7 @@ async function handleExecutePost( }, cancel() { isStreamClosed = true + timeoutController.abort() reqLogger.info('Client disconnected from SSE stream') }, }) diff --git a/apps/sim/lib/api/contracts/mcp.ts b/apps/sim/lib/api/contracts/mcp.ts index 2a6fc9b0888..dfcddac85ad 100644 --- a/apps/sim/lib/api/contracts/mcp.ts +++ b/apps/sim/lib/api/contracts/mcp.ts @@ -2,6 +2,8 @@ import { z } from 'zod' import { type ContractJsonResponse, defineRouteContract } from '@/lib/api/contracts/types' import type { McpToolSchema, McpToolSchemaProperty } from '@/lib/mcp/types' +const MAX_MCP_REFRESH_SERVER_IDS = 100 + const dateStringSchema = z.preprocess( (value) => (value instanceof Date ? value.toISOString() : value), z.string() @@ -160,7 +162,17 @@ export const discoverMcpToolsQuerySchema = mcpWorkspaceQuerySchema.extend({ }) export const refreshMcpToolsBodySchema = z.object({ - serverIds: z.array(z.string()), + serverIds: z + .array(z.string().min(1)) + .transform((serverIds) => [...new Set(serverIds)]) + .pipe( + z + .array(z.string()) + .max( + MAX_MCP_REFRESH_SERVER_IDS, + `At most ${MAX_MCP_REFRESH_SERVER_IDS} MCP servers can be refreshed at once` + ) + ), }) export const mcpEventsQuerySchema = z.object({ diff --git a/apps/sim/lib/api/contracts/workflow-mcp-servers.ts b/apps/sim/lib/api/contracts/workflow-mcp-servers.ts index 3c51d875862..b1c18344489 100644 --- a/apps/sim/lib/api/contracts/workflow-mcp-servers.ts +++ b/apps/sim/lib/api/contracts/workflow-mcp-servers.ts @@ -1,5 +1,6 @@ import { z } from 'zod' import { defineRouteContract } from '@/lib/api/contracts/types' +import { MAX_MCP_TOOLS_PER_SERVER } from '@/lib/mcp/constants' const dateStringSchema = z.preprocess( (value) => (value instanceof Date ? value.toISOString() : value), @@ -67,7 +68,16 @@ export const createWorkflowMcpServerBodySchema = z name: z.string().min(1), description: z.string().optional(), isPublic: z.boolean().optional(), - workflowIds: z.array(z.string()).optional(), + workflowIds: z + .array(z.string()) + .max( + MAX_MCP_TOOLS_PER_SERVER, + `Workflow MCP servers can include at most ${MAX_MCP_TOOLS_PER_SERVER} tools` + ) + .refine((workflowIds) => new Set(workflowIds).size === workflowIds.length, { + message: 'workflowIds must be unique', + }) + .optional(), }) .passthrough() diff --git a/apps/sim/lib/core/utils/stream-limits.ts b/apps/sim/lib/core/utils/stream-limits.ts index 06bd7c0f650..d48fbb7083a 100644 --- a/apps/sim/lib/core/utils/stream-limits.ts +++ b/apps/sim/lib/core/utils/stream-limits.ts @@ -99,8 +99,17 @@ export async function readStreamToBufferWithLimit( const reader = stream.getReader() const chunks: Buffer[] = [] let totalBytes = 0 + const abortFromSignal = () => { + void reader.cancel(options.signal?.reason).catch(() => {}) + } try { + if (options.signal?.aborted) { + await reader.cancel(options.signal.reason).catch(() => {}) + throw toError(options.signal.reason ?? new Error('Aborted')) + } + options.signal?.addEventListener('abort', abortFromSignal, { once: true }) + while (true) { if (options.signal?.aborted) { await reader.cancel(options.signal.reason).catch(() => {}) @@ -108,6 +117,9 @@ export async function readStreamToBufferWithLimit( } const { done, value } = await reader.read() + if (options.signal?.aborted) { + throw toError(options.signal.reason ?? new Error('Aborted')) + } if (done) break if (!value) continue @@ -125,6 +137,7 @@ export async function readStreamToBufferWithLimit( chunks.push(Buffer.from(value)) } } finally { + options.signal?.removeEventListener('abort', abortFromSignal) reader.releaseLock() } diff --git a/apps/sim/lib/execution/payloads/serializer.test.ts b/apps/sim/lib/execution/payloads/serializer.test.ts index e76f500b4a1..3b77f37981d 100644 --- a/apps/sim/lib/execution/payloads/serializer.test.ts +++ b/apps/sim/lib/execution/payloads/serializer.test.ts @@ -91,6 +91,24 @@ describe('compactExecutionPayload', () => { expect(isLargeValueRef(compacted.metadata)).toBe(true) }) + it('rejects oversized values before preserving or spilling them when requested', async () => { + await expect( + compactExecutionPayload( + { root: Object.fromEntries(Array.from({ length: 100 }, (_, index) => [`k${index}`, 'x'])) }, + { + thresholdBytes: 256, + preserveRoot: true, + rejectLargeValues: true, + rejectLargeValueLabel: 'Workflow execution response', + ...TEST_EXECUTION_CONTEXT, + } + ) + ).rejects.toMatchObject({ + name: 'PayloadSizeLimitError', + label: 'Workflow execution response', + }) + }) + it('does not double-spill existing refs', async () => { const compacted = await compactExecutionPayload( { results: [[{ payload: 'x'.repeat(2048) }]] }, diff --git a/apps/sim/lib/execution/payloads/serializer.ts b/apps/sim/lib/execution/payloads/serializer.ts index 7698a7c2b5f..c3d0079defb 100644 --- a/apps/sim/lib/execution/payloads/serializer.ts +++ b/apps/sim/lib/execution/payloads/serializer.ts @@ -1,3 +1,4 @@ +import { PayloadSizeLimitError } from '@/lib/core/utils/stream-limits' import { isUserFileWithMetadata } from '@/lib/core/utils/user-file' import { createLargeArrayManifest, @@ -14,6 +15,8 @@ export interface CompactExecutionPayloadOptions extends LargeValueStoreContext { thresholdBytes?: number preserveUserFileBase64?: boolean preserveRoot?: boolean + rejectLargeValues?: boolean + rejectLargeValueLabel?: string } interface CompactState { @@ -44,6 +47,24 @@ function canPersistDurably(options: CompactExecutionPayloadOptions): boolean { return Boolean(options.workspaceId && options.workflowId && options.executionId) } +function largeValueLimitError( + options: CompactExecutionPayloadOptions, + observedBytes: number +): PayloadSizeLimitError { + return new PayloadSizeLimitError({ + label: options.rejectLargeValueLabel ?? 'Large execution value', + maxBytes: options.thresholdBytes ?? LARGE_VALUE_THRESHOLD_BYTES, + observedBytes, + }) +} + +function assertRejectSize(observedBytes: number, options: CompactExecutionPayloadOptions): void { + if (!options.rejectLargeValues) return + if (observedBytes > (options.thresholdBytes ?? LARGE_VALUE_THRESHOLD_BYTES)) { + throw largeValueLimitError(options, observedBytes) + } +} + async function compactValue( value: unknown, options: CompactExecutionPayloadOptions, @@ -53,6 +74,9 @@ async function compactValue( if (!value || typeof value !== 'object') { const measured = getJsonAndSize(value) if (measured && measured.size > (options.thresholdBytes ?? LARGE_VALUE_THRESHOLD_BYTES)) { + if (options.rejectLargeValues) { + throw largeValueLimitError(options, measured.size) + } return options.preserveRoot && depth === 0 ? value : storeLargeValue(value, measured.json, measured.size, options) @@ -67,6 +91,9 @@ async function compactValue( if (isLargeArrayManifest(value)) { const measured = getJsonAndSize(value) if (measured && measured.size > (options.thresholdBytes ?? LARGE_VALUE_THRESHOLD_BYTES)) { + if (options.rejectLargeValues) { + throw largeValueLimitError(options, measured.size) + } return storeLargeValue(value, measured.json, measured.size, options) } return value @@ -81,21 +108,14 @@ async function compactValue( } state.seen.add(value) - const compacted = Array.isArray(value) - ? await Promise.all(value.map((item) => compactValue(item, options, state, depth + 1))) - : Object.fromEntries( - await Promise.all( - Object.entries(value).map(async ([key, entryValue]) => [ - key, - key === 'finalBlockLogs' && Array.isArray(entryValue) - ? await compactBlockLogs(entryValue as BlockLog[], options) - : await compactValue(entryValue, options, state, depth + 1), - ]) - ) - ) + const compacted = await compactEntries(value, options, state, depth) const measured = getJsonAndSize(compacted) if (measured && measured.size > (options.thresholdBytes ?? LARGE_VALUE_THRESHOLD_BYTES)) { + if (options.rejectLargeValues) { + throw largeValueLimitError(options, measured.size) + } + if (Array.isArray(compacted) && (canPersistDurably(options) || options.requireDurable)) { return createLargeArrayManifest(compacted, { ...options, requireDurable: true }) } @@ -110,6 +130,76 @@ async function compactValue( return compacted } +async function compactEntries( + value: object, + options: CompactExecutionPayloadOptions, + state: CompactState, + depth: number +): Promise { + if (options.rejectLargeValues) { + return compactEntriesWithEarlyReject(value, options, state, depth) + } + + if (Array.isArray(value)) { + return Promise.all(value.map((item) => compactValue(item, options, state, depth + 1))) + } + + return Object.fromEntries( + await Promise.all( + Object.entries(value).map(async ([key, entryValue]) => [ + key, + key === 'finalBlockLogs' && Array.isArray(entryValue) + ? await compactBlockLogs(entryValue as BlockLog[], options) + : await compactValue(entryValue, options, state, depth + 1), + ]) + ) + ) +} + +async function compactEntriesWithEarlyReject( + value: object, + options: CompactExecutionPayloadOptions, + state: CompactState, + depth: number +): Promise { + if (Array.isArray(value)) { + const compacted: unknown[] = [] + let estimatedBytes = 2 + for (const item of value) { + const compactedItem = await compactValue(item, options, state, depth + 1) + compacted.push(compactedItem) + const measured = getJsonAndSize(compactedItem) + estimatedBytes += (compacted.length > 1 ? 1 : 0) + (measured?.size ?? 4) + assertRejectSize(estimatedBytes, options) + } + return compacted + } + + const compacted: Record = {} + let estimatedBytes = 2 + let serializedPropertyCount = 0 + for (const [key, entryValue] of Object.entries(value)) { + const compactedEntry = + key === 'finalBlockLogs' && Array.isArray(entryValue) + ? await compactBlockLogs(entryValue as BlockLog[], options) + : await compactValue(entryValue, options, state, depth + 1) + compacted[key] = compactedEntry + + const measured = getJsonAndSize(compactedEntry) + if (measured) { + const keyJson = JSON.stringify(key) + estimatedBytes += + (serializedPropertyCount > 0 ? 1 : 0) + + Buffer.byteLength(keyJson, 'utf8') + + 1 + + measured.size + serializedPropertyCount += 1 + assertRejectSize(estimatedBytes, options) + } + } + return compacted +} + async function forceStoreValue( value: unknown, options: CompactExecutionPayloadOptions diff --git a/apps/sim/lib/mcp/client.ts b/apps/sim/lib/mcp/client.ts index ca2b26724fa..bef88182c9e 100644 --- a/apps/sim/lib/mcp/client.ts +++ b/apps/sim/lib/mcp/client.ts @@ -32,6 +32,10 @@ import { MCP_CLIENT_CONSTANTS } from '@/lib/mcp/utils' const logger = createLogger('McpClient') +interface McpClientConnectOptions { + isCancelled?: () => boolean +} + export class McpClient { private client: Client private transport: StreamableHTTPClientTransport @@ -85,11 +89,17 @@ export class McpClient { * If an `onToolsChanged` callback was provided, registers a notification handler * for `notifications/tools/list_changed` after connecting. */ - async connect(): Promise { + async connect(options: McpClientConnectOptions = {}): Promise { logger.info(`Connecting to MCP server: ${this.config.name} (${this.config.transport})`) try { await this.client.connect(this.transport) + if (options.isCancelled?.()) { + await this.client.close().catch((error) => { + logger.warn(`Error closing cancelled connection to ${this.config.name}:`, error) + }) + throw new McpConnectionError('Connection attempt cancelled', this.config.name) + } this.isConnected = true this.connectionStatus.connected = true diff --git a/apps/sim/lib/mcp/connection-manager.test.ts b/apps/sim/lib/mcp/connection-manager.test.ts index ee880b968ba..83e32d50088 100644 --- a/apps/sim/lib/mcp/connection-manager.test.ts +++ b/apps/sim/lib/mcp/connection-manager.test.ts @@ -79,6 +79,7 @@ describe('McpConnectionManager', () => { afterEach(() => { manager?.dispose() manager = null + vi.useRealTimers() }) function createFreshManager(): McpConnectionManager { @@ -206,6 +207,36 @@ describe('McpConnectionManager', () => { expect(r2.supportsListChanged).toBe(true) expect(instances).toHaveLength(2) }) + + it('marks timed-out connect attempts as cancelled for late completions', async () => { + vi.useFakeTimers() + const deferred = createDeferred() + const instances: MockMcpClient[] = [] + + MockMcpClientConstructor.mockImplementation(() => { + const instance: MockMcpClient = { + connect: vi.fn().mockImplementation(() => deferred.promise), + disconnect: vi.fn().mockResolvedValue(undefined), + hasListChangedCapability: vi.fn().mockReturnValue(true), + onClose: vi.fn(), + } + instances.push(instance) + return instance + }) + + const mgr = createFreshManager() + const resultPromise = mgr.connect(serverConfig('server-timeout'), 'user-1', 'ws-1') + + await vi.advanceTimersByTimeAsync(15_000) + const result = await resultPromise + const connectOptions = instances[0].connect.mock.calls[0][0] + + expect(result.supportsListChanged).toBe(false) + expect(connectOptions.isCancelled()).toBe(true) + expect(instances[0].disconnect).toHaveBeenCalled() + + deferred.resolve() + }) }) describe('dispose', () => { @@ -225,4 +256,94 @@ describe('McpConnectionManager', () => { expect(result.supportsListChanged).toBe(false) }) }) + + describe('intentional disconnect cleanup', () => { + it('does not reconnect when disconnectServer closes a managed client', async () => { + vi.useFakeTimers() + let closeHandler: (() => void) | undefined + const instances: MockMcpClient[] = [] + + MockMcpClientConstructor.mockImplementation(() => { + const instance: MockMcpClient = { + connect: vi.fn().mockResolvedValue(undefined), + disconnect: vi.fn().mockImplementation(async () => { + closeHandler?.() + }), + hasListChangedCapability: vi.fn().mockReturnValue(true), + onClose: vi.fn().mockImplementation((handler: () => void) => { + closeHandler = handler + }), + } + instances.push(instance) + return instance + }) + + const mgr = createFreshManager() + await mgr.connect(serverConfig('server-5'), 'user-1', 'ws-1') + + await mgr.disconnectServer('server-5') + await vi.advanceTimersByTimeAsync(2_000) + + expect(instances).toHaveLength(1) + expect(mgr.hasConnection('server-5')).toBe(false) + }) + + it('does not reconnect when close fires after disconnect resolves', async () => { + vi.useFakeTimers() + let closeHandler: (() => void) | undefined + const instances: MockMcpClient[] = [] + + MockMcpClientConstructor.mockImplementation(() => { + const instance: MockMcpClient = { + connect: vi.fn().mockResolvedValue(undefined), + disconnect: vi.fn().mockResolvedValue(undefined), + hasListChangedCapability: vi.fn().mockReturnValue(true), + onClose: vi.fn().mockImplementation((handler: () => void) => { + closeHandler = handler + }), + } + instances.push(instance) + return instance + }) + + const mgr = createFreshManager() + await mgr.connect(serverConfig('server-7'), 'user-1', 'ws-1') + + await mgr.disconnectServer('server-7') + closeHandler?.() + await vi.advanceTimersByTimeAsync(2_000) + + expect(instances).toHaveLength(1) + expect(mgr.hasConnection('server-7')).toBe(false) + }) + + it('does not reconnect idle connections after cleanup disconnects them', async () => { + vi.useFakeTimers() + const closeHandlers: Array<() => void> = [] + const instances: MockMcpClient[] = [] + + MockMcpClientConstructor.mockImplementation(() => { + const instance: MockMcpClient = { + connect: vi.fn().mockResolvedValue(undefined), + disconnect: vi.fn().mockImplementation(async () => { + closeHandlers.at(-1)?.() + }), + hasListChangedCapability: vi.fn().mockReturnValue(true), + onClose: vi.fn().mockImplementation((handler: () => void) => { + closeHandlers.push(handler) + }), + } + instances.push(instance) + return instance + }) + + const mgr = createFreshManager() + await mgr.connect(serverConfig('server-6'), 'user-1', 'ws-1') + + await vi.advanceTimersByTimeAsync(35 * 60 * 1000) + + expect(instances).toHaveLength(1) + expect(mgr.hasConnection('server-6')).toBe(false) + }) + }) }) diff --git a/apps/sim/lib/mcp/connection-manager.ts b/apps/sim/lib/mcp/connection-manager.ts index 6b80f3c06a6..178a3f54564 100644 --- a/apps/sim/lib/mcp/connection-manager.ts +++ b/apps/sim/lib/mcp/connection-manager.ts @@ -11,6 +11,7 @@ */ import { createLogger } from '@sim/logger' +import { backoffWithJitter } from '@sim/utils/retry' import { isTest } from '@/lib/core/config/feature-flags' import { McpClient } from '@/lib/mcp/client' import { getOrCreateOauthRow, loadPreregisteredClient, SimMcpOauthProvider } from '@/lib/mcp/oauth' @@ -28,11 +29,37 @@ const logger = createLogger('McpConnectionManager') const MAX_CONNECTIONS = 50 const MAX_RECONNECT_ATTEMPTS = 10 const BASE_RECONNECT_DELAY_MS = 1000 +const CONNECT_TIMEOUT_MS = 15_000 const IDLE_TIMEOUT_MS = 30 * 60 * 1000 // 30 minutes const IDLE_CHECK_INTERVAL_MS = 5 * 60 * 1000 // 5 minutes type ToolsChangedListener = (event: ToolsChangedEvent) => void +async function withConnectTimeout(client: McpClient, serverName: string): Promise { + let timeoutId: ReturnType | undefined + let timedOut = false + const connectPromise = client.connect({ isCancelled: () => timedOut }) + try { + await Promise.race([ + connectPromise, + new Promise((_, reject) => { + timeoutId = setTimeout(() => { + timedOut = true + reject(new Error(`Timed out connecting to MCP server ${serverName}`)) + }, CONNECT_TIMEOUT_MS) + }), + ]) + } catch (error) { + if (timedOut) { + void connectPromise.catch(() => {}) + } + await client.disconnect().catch(() => {}) + throw error + } finally { + if (timeoutId) clearTimeout(timeoutId) + } +} + /** * Cache key for managed connections. * MCP servers are workspace-owned, so OAuth/header/no-auth connections are @@ -140,7 +167,7 @@ export class McpConnectionManager { }) try { - await client.connect() + await withConnectTimeout(client, config.name) } catch (error) { logger.error(`[${config.name}] Failed to connect for persistent monitoring:`, error) return { supportsListChanged: false } @@ -191,15 +218,17 @@ export class McpConnectionManager { const client = this.connections.get(key) if (client) { + this.connections.delete(key) + this.states.delete(key) try { await client.disconnect() } catch (error) { logger.warn(`Error disconnecting managed client ${key}:`, error) } - this.connections.delete(key) + } else { + this.states.delete(key) } - this.states.delete(key) logger.info(`Managed connection removed: ${key}`) } @@ -331,8 +360,11 @@ export class McpConnectionManager { return } - const delay = Math.min(BASE_RECONNECT_DELAY_MS * 2 ** state.reconnectAttempts, 60_000) state.reconnectAttempts++ + const delay = backoffWithJitter(state.reconnectAttempts, null, { + baseMs: BASE_RECONNECT_DELAY_MS, + maxMs: 60_000, + }) logger.info( `[${config.name}] Reconnecting in ${delay}ms (attempt ${state.reconnectAttempts}/${MAX_RECONNECT_ATTEMPTS})` diff --git a/apps/sim/lib/mcp/constants.ts b/apps/sim/lib/mcp/constants.ts new file mode 100644 index 00000000000..2b583d569ca --- /dev/null +++ b/apps/sim/lib/mcp/constants.ts @@ -0,0 +1,11 @@ +export const MAX_MCP_TOOLS_PER_SERVER = 100 +export const MAX_MCP_SERVERS_PER_WORKFLOW = 100 +export const MCP_TOOL_BRIDGE_HEADER = 'X-Sim-MCP-Tool-Call' +export const MCP_TOOL_BRIDGE_ACTOR_HEADER = 'X-Sim-MCP-Tool-Actor' +export const MAX_MCP_PARAMETER_SCHEMA_BYTES = 2 * 1024 * 1024 +export const MAX_MCP_TOOL_DESCRIPTION_BYTES = 64 * 1024 +export const MAX_MCP_TOOL_NAME_BYTES = 256 +export const MAX_MCP_TOOLS_LIST_RESPONSE_BYTES = 10 * 1024 * 1024 +export const MAX_MCP_WORKFLOW_RESPONSE_BYTES = 10 * 1024 * 1024 +export const MAX_MCP_SERVER_PARAMETER_SCHEMAS_BYTES = MAX_MCP_PARAMETER_SCHEMA_BYTES +export const MAX_MCP_SERVER_TOOLS_METADATA_BYTES = MAX_MCP_TOOLS_LIST_RESPONSE_BYTES diff --git a/apps/sim/lib/mcp/middleware.ts b/apps/sim/lib/mcp/middleware.ts index 6e4aa816221..90367b3cd75 100644 --- a/apps/sim/lib/mcp/middleware.ts +++ b/apps/sim/lib/mcp/middleware.ts @@ -3,10 +3,17 @@ import { toError } from '@sim/utils/errors' import type { NextRequest, NextResponse } from 'next/server' import { checkSessionOrInternalAuth } from '@/lib/auth/hybrid' import { generateRequestId } from '@/lib/core/utils/request' +import { + assertContentLengthWithinLimit, + isPayloadSizeLimitError, + readStreamToBufferWithLimit, +} from '@/lib/core/utils/stream-limits' import { createMcpErrorResponse } from '@/lib/mcp/utils' import { getUserEntityPermissions } from '@/lib/workspaces/permissions/utils' const logger = createLogger('McpAuthMiddleware') +const MAX_MCP_MANAGEMENT_BODY_BYTES = 10 * 1024 * 1024 +const parsedBodies = new WeakMap() export type McpPermissionLevel = 'read' | 'write' | 'admin' @@ -36,6 +43,68 @@ interface AuthFailure { type AuthValidationResult = AuthResult | AuthFailure +class McpBodyReadError extends Error { + constructor( + readonly kind: 'aborted' | 'payload_too_large' | 'invalid_json', + readonly cause: unknown + ) { + super(toError(cause).message) + this.name = 'McpBodyReadError' + } +} + +export async function readMcpJsonBodyWithLimit(request: NextRequest): Promise { + const cached = parsedBodies.get(request) + if (cached !== undefined) return cached + + try { + assertContentLengthWithinLimit( + request.headers, + MAX_MCP_MANAGEMENT_BODY_BYTES, + 'MCP management request body' + ) + const buffer = await readStreamToBufferWithLimit(request.body, { + maxBytes: MAX_MCP_MANAGEMENT_BODY_BYTES, + label: 'MCP management request body', + signal: request.signal, + }) + const body = buffer.byteLength > 0 ? JSON.parse(buffer.toString('utf-8')) : {} + parsedBodies.set(request, body) + return body + } catch (error) { + if (request.signal.aborted) { + throw new McpBodyReadError('aborted', error) + } + if (isPayloadSizeLimitError(error)) { + throw new McpBodyReadError('payload_too_large', error) + } + if (error instanceof SyntaxError) { + throw new McpBodyReadError('invalid_json', error) + } + throw error + } +} + +export function mcpBodyReadErrorResponse( + error: unknown, + request?: NextRequest +): NextResponse | null { + if (!(error instanceof McpBodyReadError)) { + return null + } + if (error.kind === 'aborted' || request?.signal.aborted) { + return createMcpErrorResponse(error.cause, 'Client cancelled request', 499) + } + if (error.kind === 'payload_too_large') { + return createMcpErrorResponse( + error.cause, + 'MCP management request body exceeds maximum size', + 413 + ) + } + return createMcpErrorResponse(error.cause, 'Invalid request body', 400) +} + /** * Validates MCP authentication and authorization */ @@ -68,11 +137,17 @@ async function validateMcpAuth( try { const contentType = request.headers.get('content-type') if (contentType?.includes('application/json')) { - const body = await request.json() - workspaceId = body.workspaceId - ;(request as any)._parsedBody = body + const body = await readMcpJsonBodyWithLimit(request) + const bodyWorkspaceId = + body && typeof body === 'object' && 'workspaceId' in body + ? (body as { workspaceId?: unknown }).workspaceId + : undefined + workspaceId = typeof bodyWorkspaceId === 'string' ? bodyWorkspaceId : null } - } catch {} + } catch (error) { + const errorResponse = mcpBodyReadErrorResponse(error, request) + if (errorResponse) return { success: false, errorResponse } + } } if (!workspaceId) { @@ -190,6 +265,8 @@ export function withMcpAuth>( try { return await handler(request, (authResult as AuthResult).context, routeContext) } catch (error) { + const bodyErrorResponse = mcpBodyReadErrorResponse(error, request) + if (bodyErrorResponse) return bodyErrorResponse logger.error( `[${(authResult as AuthResult).context.requestId}] Error in MCP route handler:`, error @@ -199,10 +276,3 @@ export function withMcpAuth>( } } } - -/** - * Utility to get parsed request body - */ -export function getParsedBody(request: NextRequest): any { - return (request as any)._parsedBody -} diff --git a/apps/sim/lib/mcp/orchestration/workflow-mcp-lifecycle.test.ts b/apps/sim/lib/mcp/orchestration/workflow-mcp-lifecycle.test.ts new file mode 100644 index 00000000000..d2bbe589b79 --- /dev/null +++ b/apps/sim/lib/mcp/orchestration/workflow-mcp-lifecycle.test.ts @@ -0,0 +1,203 @@ +/** + * @vitest-environment node + */ +import { dbChainMock, dbChainMockFns, resetDbChainMock, schemaMock } from '@sim/testing' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +vi.mock('@sim/audit', () => ({ + AuditAction: { + MCP_SERVER_UPDATED: 'mcp_server_updated', + MCP_TOOL_UPDATED: 'mcp_tool_updated', + }, + AuditResourceType: { + MCP_SERVER: 'mcp_server', + MCP_TOOL: 'mcp_tool', + }, + recordAudit: vi.fn(), +})) +vi.mock('@sim/db', () => ({ + ...dbChainMock, + workflow: schemaMock.workflow, + workflowMcpServer: schemaMock.workflowMcpServer, + workflowMcpTool: schemaMock.workflowMcpTool, +})) +vi.mock('@sim/db/schema', () => schemaMock) +vi.mock('drizzle-orm', () => ({ + and: vi.fn(), + asc: vi.fn(), + eq: vi.fn(), + inArray: vi.fn(), + isNull: vi.fn(), + ne: vi.fn(), + sql: Object.assign(vi.fn(), { raw: vi.fn((value: string) => value) }), +})) +vi.mock('@/lib/mcp/pubsub', () => ({ mcpPubSub: undefined })) +vi.mock('@/lib/workflows/triggers/trigger-utils.server', () => ({ + hasValidStartBlock: vi.fn(), +})) +vi.mock('@/lib/mcp/workflow-mcp-sync', () => ({ + generateParameterSchemaForWorkflow: vi.fn().mockResolvedValue({ type: 'object', properties: {} }), +})) + +import { MAX_MCP_PARAMETER_SCHEMA_BYTES, MAX_MCP_TOOLS_PER_SERVER } from '@/lib/mcp/constants' +import { + performCreateWorkflowMcpServer, + performUpdateWorkflowMcpTool, +} from '@/lib/mcp/orchestration/workflow-mcp-lifecycle' +import { hasValidStartBlock } from '@/lib/workflows/triggers/trigger-utils.server' + +describe('workflow MCP lifecycle orchestration', () => { + beforeEach(() => { + vi.clearAllMocks() + resetDbChainMock() + }) + + it('rejects over-limit workflow server creation before inserting a server row', async () => { + const result = await performCreateWorkflowMcpServer({ + workspaceId: 'workspace-1', + userId: 'user-1', + name: 'Too Many Tools', + workflowIds: Array.from( + { length: MAX_MCP_TOOLS_PER_SERVER + 1 }, + (_, index) => `wf-${index}` + ), + }) + + expect(result).toMatchObject({ + success: false, + errorCode: 'validation', + }) + expect(dbChainMockFns.insert).not.toHaveBeenCalled() + }) + + it('rejects duplicate workflow IDs before inserting a server row', async () => { + const result = await performCreateWorkflowMcpServer({ + workspaceId: 'workspace-1', + userId: 'user-1', + name: 'Duplicate Tools', + workflowIds: ['wf-1', 'wf-1'], + }) + + expect(result).toMatchObject({ + success: false, + errorCode: 'validation', + }) + expect(dbChainMockFns.insert).not.toHaveBeenCalled() + }) + + it('rechecks deployed workflow state inside the create transaction', async () => { + dbChainMockFns.where.mockResolvedValueOnce([ + { + id: 'wf-1', + name: 'Workflow', + description: null, + isDeployed: true, + workspaceId: 'workspace-1', + deployedAt: new Date('2026-01-01T00:00:00Z'), + updatedAt: new Date('2026-01-01T00:00:00Z'), + }, + ]) + vi.mocked(hasValidStartBlock).mockResolvedValueOnce(true) + dbChainMockFns.for.mockResolvedValueOnce([ + { + id: 'wf-1', + name: 'Workflow', + description: null, + isDeployed: false, + workspaceId: 'workspace-1', + deployedAt: new Date('2026-01-01T00:00:00Z'), + updatedAt: new Date('2026-01-01T00:00:00Z'), + }, + ]) + + const result = await performCreateWorkflowMcpServer({ + workspaceId: 'workspace-1', + userId: 'user-1', + name: 'Server', + workflowIds: ['wf-1'], + }) + + expect(result).toMatchObject({ + success: false, + errorCode: 'validation', + }) + expect(dbChainMockFns.transaction).toHaveBeenCalled() + expect(dbChainMockFns.for).toHaveBeenCalledTimes(1) + expect(dbChainMockFns.insert).not.toHaveBeenCalled() + }) + + it('rejects workflow MCP server fan-out above the per-workflow limit', async () => { + vi.mocked(hasValidStartBlock).mockResolvedValueOnce(true) + dbChainMockFns.where.mockResolvedValueOnce([ + { + id: 'wf-1', + name: 'Workflow', + description: null, + isDeployed: true, + workspaceId: 'workspace-1', + deployedAt: new Date('2026-01-01T00:00:00Z'), + updatedAt: new Date('2026-01-01T00:00:00Z'), + }, + ]) + dbChainMockFns.for.mockResolvedValueOnce([ + { + id: 'wf-1', + name: 'Workflow', + description: null, + isDeployed: true, + workspaceId: 'workspace-1', + deployedAt: new Date('2026-01-01T00:00:00Z'), + updatedAt: new Date('2026-01-01T00:00:00Z'), + }, + ]) + dbChainMockFns.groupBy.mockResolvedValueOnce([{ workflowId: 'wf-1', serverCount: 100 }]) + + const result = await performCreateWorkflowMcpServer({ + workspaceId: 'workspace-1', + userId: 'user-1', + name: 'Server', + workflowIds: ['wf-1'], + }) + + expect(result).toMatchObject({ + success: false, + errorCode: 'validation', + }) + expect(dbChainMockFns.insert).not.toHaveBeenCalled() + }) + + it('allows updating tool metadata when an unchanged stored schema exceeds the new cap', async () => { + dbChainMockFns.limit.mockResolvedValueOnce([{ id: 'server-1' }]).mockResolvedValueOnce([ + { + id: 'tool-1', + toolName: 'tool_a', + toolDescription: null, + parameterSchemaBytes: MAX_MCP_PARAMETER_SCHEMA_BYTES + 1, + }, + ]) + dbChainMockFns.returning.mockResolvedValueOnce([ + { + id: 'tool-1', + serverId: 'server-1', + toolName: 'tool_a', + toolDescription: 'Updated description', + }, + ]) + + const result = await performUpdateWorkflowMcpTool({ + workspaceId: 'workspace-1', + userId: 'user-1', + serverId: 'server-1', + toolId: 'tool-1', + toolDescription: 'Updated description', + }) + + expect(result).toMatchObject({ + success: true, + tool: { + toolDescription: 'Updated description', + }, + }) + expect(dbChainMockFns.update).toHaveBeenCalled() + }) +}) diff --git a/apps/sim/lib/mcp/orchestration/workflow-mcp-lifecycle.ts b/apps/sim/lib/mcp/orchestration/workflow-mcp-lifecycle.ts index 4d04a98189d..917ffa72522 100644 --- a/apps/sim/lib/mcp/orchestration/workflow-mcp-lifecycle.ts +++ b/apps/sim/lib/mcp/orchestration/workflow-mcp-lifecycle.ts @@ -2,15 +2,54 @@ import { AuditAction, AuditResourceType, recordAudit } from '@sim/audit' import { db, workflow, workflowMcpServer, workflowMcpTool } from '@sim/db' import { createLogger } from '@sim/logger' import { generateId } from '@sim/utils/id' -import { and, eq, inArray, isNull } from 'drizzle-orm' +import { and, asc, eq, inArray, isNull, ne, sql } from 'drizzle-orm' +import type { DbOrTx } from '@/lib/db/types' +import { + MAX_MCP_SERVER_PARAMETER_SCHEMAS_BYTES, + MAX_MCP_SERVER_TOOLS_METADATA_BYTES, + MAX_MCP_SERVERS_PER_WORKFLOW, + MAX_MCP_TOOLS_PER_SERVER, +} from '@/lib/mcp/constants' import { mcpPubSub } from '@/lib/mcp/pubsub' +import { + acquireWorkflowMcpServerLock, + isWorkflowMcpServerLockTimeout, + setWorkflowMcpTransactionLockTimeout, +} from '@/lib/mcp/server-locks' +import { + addMcpToolMetadataUsage, + addMcpToolMetadataUsageRow, + createMcpToolMetadataUsageRow, + getMcpServerToolMetadataUsageRows, + getMcpToolDescriptionForStorage, + getMcpToolMetadataSizes, + getMcpToolMetadataUsageFromRows, + type McpToolMetadataUsage, + validateMcpServerToolMetadataBudget, + validateMcpToolMetadataForStorage, +} from '@/lib/mcp/tool-limits' import { generateParameterSchemaForWorkflow } from '@/lib/mcp/workflow-mcp-sync' import { sanitizeToolName } from '@/lib/mcp/workflow-tool-schema' import { hasValidStartBlock } from '@/lib/workflows/triggers/trigger-utils.server' const logger = createLogger('WorkflowMcpOrchestration') -export type WorkflowMcpOrchestrationErrorCode = 'not_found' | 'validation' | 'internal' +export type WorkflowMcpOrchestrationErrorCode = + | 'not_found' + | 'validation' + | 'forbidden' + | 'conflict' + | 'internal' + +class WorkflowMcpExpectedError extends Error { + constructor( + message: string, + readonly errorCode: WorkflowMcpOrchestrationErrorCode + ) { + super(message) + this.name = 'WorkflowMcpExpectedError' + } +} interface ActorMetadata { actorName?: string | null @@ -77,7 +116,7 @@ export interface PerformCreateWorkflowMcpToolParams extends ActorMetadata { export interface PerformCreateWorkflowMcpToolResult { success: boolean error?: string - errorCode?: WorkflowMcpOrchestrationErrorCode | 'conflict' + errorCode?: WorkflowMcpOrchestrationErrorCode tool?: typeof workflowMcpTool.$inferSelect } @@ -112,80 +151,338 @@ export interface PerformDeleteWorkflowMcpToolResult { tool?: typeof workflowMcpTool.$inferSelect } +interface PreparedWorkflowMcpTool { + workflowId: string + toolName: string + toolDescription: string | null + parameterSchema: unknown +} + +interface WorkflowMcpToolWorkflowRecord { + id: string + name: string + description: string | null +} + +interface WorkflowMcpServerCreateWorkflowRecord extends WorkflowMcpToolWorkflowRecord { + isDeployed: boolean + workspaceId: string | null + deployedAt: Date | null + updatedAt: Date +} + +async function validateServerToolMetadataBudget( + serverId: string, + proposedTools: Array<{ + toolName: string + toolDescription: string | null + parameterSchema: unknown + }>, + tx: DbOrTx, + excludeToolId?: string +): Promise { + let usage = getMcpToolMetadataUsageFromRows( + await getMcpServerToolMetadataUsageRows(tx, serverId, excludeToolId) + ) + for (const tool of proposedTools) { + usage = addMcpToolMetadataUsage(usage, tool) + } + return validateMcpServerToolMetadataBudget(usage) +} + +function validateServerToolMetadataBudgetForUpdate( + currentUsage: McpToolMetadataUsage, + proposedUsage: McpToolMetadataUsage +): string | null { + if ( + proposedUsage.schemaBytes > MAX_MCP_SERVER_PARAMETER_SCHEMAS_BYTES && + proposedUsage.schemaBytes > currentUsage.schemaBytes + ) { + return `MCP server tool schemas exceed maximum size of ${MAX_MCP_SERVER_PARAMETER_SCHEMAS_BYTES} bytes` + } + if ( + proposedUsage.metadataBytes > MAX_MCP_SERVER_TOOLS_METADATA_BYTES && + proposedUsage.metadataBytes > currentUsage.metadataBytes + ) { + return `MCP server tool metadata exceeds maximum size of ${MAX_MCP_SERVER_TOOLS_METADATA_BYTES} bytes` + } + return null +} + +async function prepareWorkflowMcpTool(params: { + workflowRecord: WorkflowMcpToolWorkflowRecord + toolName?: string + toolDescription?: string | null + parameterSchema?: Record +}): Promise { + const { workflowRecord } = params + const toolName = sanitizeToolName(params.toolName?.trim() || workflowRecord.name) + const toolDescription = + params.toolDescription !== undefined + ? params.toolDescription?.trim() || `Execute ${workflowRecord.name} workflow` + : getMcpToolDescriptionForStorage(workflowRecord.description, workflowRecord.name) + const parameterSchema = + params.parameterSchema && Object.keys(params.parameterSchema).length > 0 + ? params.parameterSchema + : await generateParameterSchemaForWorkflow(workflowRecord.id) + const metadataLimitError = validateMcpToolMetadataForStorage({ + toolName, + toolDescription, + parameterSchema, + }) + if (metadataLimitError) { + throw new WorkflowMcpExpectedError(metadataLimitError, 'validation') + } + + return { + workflowId: workflowRecord.id, + toolName, + toolDescription, + parameterSchema, + } +} + +function sameNullableDate(left: Date | null, right: Date | null): boolean { + if (left === null || right === null) return left === right + return left.getTime() === right.getTime() +} + +function validateWorkflowForMcpServerCreate( + workflowRecord: WorkflowMcpServerCreateWorkflowRecord, + workspaceId: string +): void { + if (workflowRecord.workspaceId !== workspaceId) { + throw new WorkflowMcpExpectedError( + `Workflow is outside this workspace: ${workflowRecord.id}`, + 'forbidden' + ) + } + if (!workflowRecord.isDeployed) { + throw new WorkflowMcpExpectedError( + `Workflow must be deployed before adding as an MCP tool: ${workflowRecord.id}`, + 'validation' + ) + } +} + +function assertWorkflowMcpServerCreateSnapshotCurrent( + preparedWorkflow: WorkflowMcpServerCreateWorkflowRecord, + lockedWorkflow: WorkflowMcpServerCreateWorkflowRecord +): void { + if ( + preparedWorkflow.name !== lockedWorkflow.name || + preparedWorkflow.description !== lockedWorkflow.description || + !sameNullableDate(preparedWorkflow.deployedAt, lockedWorkflow.deployedAt) || + preparedWorkflow.updatedAt.getTime() !== lockedWorkflow.updatedAt.getTime() + ) { + throw new WorkflowMcpExpectedError( + `Workflow changed while creating MCP server, retry shortly: ${preparedWorkflow.id}`, + 'conflict' + ) + } +} + +async function validateWorkflowMcpServerMembershipBudget( + tx: DbOrTx, + workflowIds: string[] +): Promise { + if (workflowIds.length === 0) return null + + const rows = await tx + .select({ + workflowId: workflowMcpTool.workflowId, + serverCount: sql`count(distinct ${workflowMcpTool.serverId})`, + }) + .from(workflowMcpTool) + .where( + and(inArray(workflowMcpTool.workflowId, workflowIds), isNull(workflowMcpTool.archivedAt)) + ) + .groupBy(workflowMcpTool.workflowId) + + for (const row of rows) { + if ((Number(row.serverCount) || 0) >= MAX_MCP_SERVERS_PER_WORKFLOW) { + return `Workflow can be exposed on at most ${MAX_MCP_SERVERS_PER_WORKFLOW} MCP servers: ${row.workflowId}` + } + } + + return null +} + export async function performCreateWorkflowMcpServer( params: PerformCreateWorkflowMcpServerParams ): Promise { try { const name = params.name.trim() - const serverId = generateId() - const [server] = await db - .insert(workflowMcpServer) - .values({ - id: serverId, - workspaceId: params.workspaceId, - createdBy: params.userId, - name, - description: params.description?.trim() || null, - isPublic: params.isPublic ?? false, - createdAt: new Date(), - updatedAt: new Date(), - }) - .returning() - - const addedTools: Array<{ workflowId: string; toolName: string }> = [] const workflowIds = params.workflowIds || [] + if (workflowIds.length > MAX_MCP_TOOLS_PER_SERVER) { + return { + success: false, + error: `Workflow MCP servers can include at most ${MAX_MCP_TOOLS_PER_SERVER} tools`, + errorCode: 'validation', + } + } + if (new Set(workflowIds).size !== workflowIds.length) { + return { + success: false, + error: 'Workflow MCP server workflowIds must be unique', + errorCode: 'validation', + } + } + + const preparedTools: PreparedWorkflowMcpTool[] = [] + const preparedToolNames = new Set() + const preparedWorkflows = new Map() + let totalUsage = { schemaBytes: 0, metadataBytes: 0 } if (workflowIds.length > 0) { - const workflows = await db + const workflowRecords = await db .select({ id: workflow.id, name: workflow.name, description: workflow.description, isDeployed: workflow.isDeployed, workspaceId: workflow.workspaceId, + deployedAt: workflow.deployedAt, + updatedAt: workflow.updatedAt, }) .from(workflow) .where(and(inArray(workflow.id, workflowIds), isNull(workflow.archivedAt))) - for (const workflowRecord of workflows) { - if (workflowRecord.workspaceId !== params.workspaceId) { - logger.warn('Skipping workflow MCP tool outside workspace', { - workflowId: workflowRecord.id, - workspaceId: params.workspaceId, - }) - continue - } - if (!workflowRecord.isDeployed) { - logger.warn('Skipping undeployed workflow MCP tool', { workflowId: workflowRecord.id }) - continue + const workflowsById = new Map( + workflowRecords.map((workflowRecord) => [workflowRecord.id, workflowRecord]) + ) + + for (const workflowId of workflowIds) { + const workflowRecord = workflowsById.get(workflowId) + if (!workflowRecord) { + return { + success: false, + error: `Workflow not found or archived: ${workflowId}`, + errorCode: 'validation', + } } + + validateWorkflowForMcpServerCreate(workflowRecord, params.workspaceId) + const hasStartBlock = await hasValidStartBlock(workflowRecord.id) if (!hasStartBlock) { - logger.warn('Skipping workflow MCP tool without start block', { - workflowId: workflowRecord.id, - }) - continue + return { + success: false, + error: `Workflow must have a valid start block before adding as an MCP tool: ${workflowRecord.id}`, + errorCode: 'validation', + } } - const toolName = sanitizeToolName(workflowRecord.name) - const parameterSchema = await generateParameterSchemaForWorkflow(workflowRecord.id) - await db.insert(workflowMcpTool).values({ - id: generateId(), - serverId, - workflowId: workflowRecord.id, + const preparedTool = await prepareWorkflowMcpTool({ workflowRecord }) + const { toolName, toolDescription, parameterSchema } = preparedTool + if (preparedToolNames.has(toolName)) { + return { + success: false, + error: `Duplicate MCP tool name after sanitization: ${toolName}`, + errorCode: 'validation', + } + } + preparedToolNames.add(toolName) + totalUsage = addMcpToolMetadataUsage(totalUsage, { toolName, - toolDescription: workflowRecord.description || `Execute ${workflowRecord.name} workflow`, + toolDescription, parameterSchema, + }) + const budgetError = validateMcpServerToolMetadataBudget(totalUsage) + if (budgetError) { + return { success: false, error: budgetError, errorCode: 'validation' } + } + + preparedTools.push(preparedTool) + preparedWorkflows.set(workflowRecord.id, workflowRecord) + } + } + + const { server, addedTools, serverId } = await db.transaction(async (tx) => { + await setWorkflowMcpTransactionLockTimeout(tx) + + if (workflowIds.length > 0) { + const lockedWorkflows = await tx + .select({ + id: workflow.id, + name: workflow.name, + description: workflow.description, + isDeployed: workflow.isDeployed, + workspaceId: workflow.workspaceId, + deployedAt: workflow.deployedAt, + updatedAt: workflow.updatedAt, + }) + .from(workflow) + .where(and(inArray(workflow.id, workflowIds), isNull(workflow.archivedAt))) + .orderBy(asc(workflow.id)) + .for('update') + + const lockedWorkflowsById = new Map( + lockedWorkflows.map((workflowRecord) => [workflowRecord.id, workflowRecord]) + ) + + for (const workflowId of workflowIds) { + const lockedWorkflow = lockedWorkflowsById.get(workflowId) + if (!lockedWorkflow) { + throw new WorkflowMcpExpectedError( + `Workflow not found or archived: ${workflowId}`, + 'validation' + ) + } + + validateWorkflowForMcpServerCreate(lockedWorkflow, params.workspaceId) + const preparedWorkflow = preparedWorkflows.get(workflowId) + if (!preparedWorkflow) { + throw new WorkflowMcpExpectedError( + `Workflow not found or archived: ${workflowId}`, + 'validation' + ) + } + assertWorkflowMcpServerCreateSnapshotCurrent(preparedWorkflow, lockedWorkflow) + } + } + + const membershipBudgetError = await validateWorkflowMcpServerMembershipBudget(tx, workflowIds) + if (membershipBudgetError) { + throw new WorkflowMcpExpectedError(membershipBudgetError, 'validation') + } + + const newServerId = generateId() + const [createdServer] = await tx + .insert(workflowMcpServer) + .values({ + id: newServerId, + workspaceId: params.workspaceId, + createdBy: params.userId, + name, + description: params.description?.trim() || null, + isPublic: params.isPublic ?? false, createdAt: new Date(), updatedAt: new Date(), }) + .returning() - addedTools.push({ workflowId: workflowRecord.id, toolName }) - } + const insertedTools: Array<{ workflowId: string; toolName: string }> = [] + for (const preparedTool of preparedTools) { + await tx.insert(workflowMcpTool).values({ + id: generateId(), + serverId: newServerId, + workflowId: preparedTool.workflowId, + toolName: preparedTool.toolName, + toolDescription: preparedTool.toolDescription, + parameterSchema: preparedTool.parameterSchema, + createdAt: new Date(), + updatedAt: new Date(), + }) - if (addedTools.length > 0) { - mcpPubSub?.publishWorkflowToolsChanged({ serverId, workspaceId: params.workspaceId }) + insertedTools.push({ workflowId: preparedTool.workflowId, toolName: preparedTool.toolName }) } + + return { server: createdServer, addedTools: insertedTools, serverId: newServerId } + }) + + if (addedTools.length > 0) { + mcpPubSub?.publishWorkflowToolsChanged({ serverId, workspaceId: params.workspaceId }) } recordAudit({ @@ -209,6 +506,16 @@ export async function performCreateWorkflowMcpServer( return { success: true, server, addedTools } } catch (error) { + if (error instanceof WorkflowMcpExpectedError) { + return { success: false, error: error.message, errorCode: error.errorCode } + } + if (isWorkflowMcpServerLockTimeout(error)) { + return { + success: false, + error: 'Workflow MCP server is busy, retry shortly', + errorCode: 'conflict', + } + } logger.error('Failed to create workflow MCP server', { error }) return { success: false, error: 'Failed to create workflow MCP server', errorCode: 'internal' } } @@ -270,15 +577,21 @@ export async function performDeleteWorkflowMcpServer( params: PerformDeleteWorkflowMcpServerParams ): Promise { try { - const [server] = await db - .delete(workflowMcpServer) - .where( - and( - eq(workflowMcpServer.id, params.serverId), - eq(workflowMcpServer.workspaceId, params.workspaceId) + const server = await db.transaction(async (tx) => { + await acquireWorkflowMcpServerLock(tx, params.serverId) + + const [deletedServer] = await tx + .delete(workflowMcpServer) + .where( + and( + eq(workflowMcpServer.id, params.serverId), + eq(workflowMcpServer.workspaceId, params.workspaceId) + ) ) - ) - .returning() + .returning() + + return deletedServer + }) if (!server) { return { success: false, error: 'Server not found', errorCode: 'not_found' } @@ -304,6 +617,13 @@ export async function performDeleteWorkflowMcpServer( return { success: true, server } } catch (error) { + if (isWorkflowMcpServerLockTimeout(error)) { + return { + success: false, + error: 'Workflow MCP server is busy, retry shortly', + errorCode: 'conflict', + } + } logger.error('Failed to delete workflow MCP server', { error }) return { success: false, error: 'Failed to delete workflow MCP server', errorCode: 'internal' } } @@ -366,50 +686,136 @@ export async function performCreateWorkflowMcpTool( } } - const [existingTool] = await db - .select({ id: workflowMcpTool.id }) - .from(workflowMcpTool) - .where( - and( - eq(workflowMcpTool.serverId, params.serverId), - eq(workflowMcpTool.workflowId, params.workflowId), - isNull(workflowMcpTool.archivedAt) + const preparedTool = await prepareWorkflowMcpTool({ + workflowRecord, + toolName: params.toolName, + toolDescription: params.toolDescription, + parameterSchema: params.parameterSchema, + }) + const { toolName, toolDescription, parameterSchema } = preparedTool + + const toolId = generateId() + const tool = await db.transaction(async (tx) => { + await setWorkflowMcpTransactionLockTimeout(tx) + + const [lockedWorkflow] = await tx + .select({ + id: workflow.id, + isDeployed: workflow.isDeployed, + workspaceId: workflow.workspaceId, + }) + .from(workflow) + .where(and(eq(workflow.id, params.workflowId), isNull(workflow.archivedAt))) + .for('update') + .limit(1) + + if (!lockedWorkflow) { + throw new WorkflowMcpExpectedError('Workflow not found', 'not_found') + } + if (lockedWorkflow.workspaceId !== params.workspaceId) { + throw new WorkflowMcpExpectedError( + 'Workflow does not belong to this workspace', + 'validation' ) - ) - .limit(1) + } + if (!lockedWorkflow.isDeployed) { + throw new WorkflowMcpExpectedError( + 'Workflow must be deployed before adding as a tool', + 'validation' + ) + } - if (existingTool) { - return { - success: false, - error: 'This workflow is already added as a tool to this server', - errorCode: 'conflict', + await acquireWorkflowMcpServerLock(tx, params.serverId) + + const existingTools = await tx + .select({ id: workflowMcpTool.id }) + .from(workflowMcpTool) + .where( + and(eq(workflowMcpTool.serverId, params.serverId), isNull(workflowMcpTool.archivedAt)) + ) + .limit(MAX_MCP_TOOLS_PER_SERVER) + + if (existingTools.length >= MAX_MCP_TOOLS_PER_SERVER) { + throw new WorkflowMcpExpectedError( + `Workflow MCP servers can include at most ${MAX_MCP_TOOLS_PER_SERVER} tools`, + 'validation' + ) } - } - const toolName = sanitizeToolName(params.toolName?.trim() || workflowRecord.name) - const toolDescription = - params.toolDescription?.trim() || - workflowRecord.description || - `Execute ${workflowRecord.name} workflow` - const parameterSchema = - params.parameterSchema && Object.keys(params.parameterSchema).length > 0 - ? params.parameterSchema - : await generateParameterSchemaForWorkflow(params.workflowId) + const [existingTool] = await tx + .select({ id: workflowMcpTool.id }) + .from(workflowMcpTool) + .where( + and( + eq(workflowMcpTool.serverId, params.serverId), + eq(workflowMcpTool.workflowId, params.workflowId), + isNull(workflowMcpTool.archivedAt) + ) + ) + .limit(1) - const toolId = generateId() - const [tool] = await db - .insert(workflowMcpTool) - .values({ - id: toolId, - serverId: params.serverId, - workflowId: params.workflowId, - toolName, - toolDescription, - parameterSchema, - createdAt: new Date(), - updatedAt: new Date(), - }) - .returning() + if (existingTool) { + throw new WorkflowMcpExpectedError( + 'This workflow is already added as a tool to this server', + 'conflict' + ) + } + + const [nameCollision] = await tx + .select({ id: workflowMcpTool.id }) + .from(workflowMcpTool) + .where( + and( + eq(workflowMcpTool.serverId, params.serverId), + eq(workflowMcpTool.toolName, toolName), + isNull(workflowMcpTool.archivedAt) + ) + ) + .limit(1) + + if (nameCollision) { + throw new WorkflowMcpExpectedError( + `MCP tool name already exists on this server: ${toolName}`, + 'conflict' + ) + } + + const membershipBudgetError = await validateWorkflowMcpServerMembershipBudget(tx, [ + params.workflowId, + ]) + if (membershipBudgetError) { + throw new WorkflowMcpExpectedError(membershipBudgetError, 'validation') + } + + const budgetError = await validateServerToolMetadataBudget( + params.serverId, + [{ toolName, toolDescription, parameterSchema }], + tx + ) + if (budgetError) { + throw new WorkflowMcpExpectedError(budgetError, 'validation') + } + + const [createdTool] = await tx + .insert(workflowMcpTool) + .values({ + id: toolId, + serverId: params.serverId, + workflowId: params.workflowId, + toolName, + toolDescription, + parameterSchema, + createdAt: new Date(), + updatedAt: new Date(), + }) + .returning() + + return createdTool + }) + + if (!tool) { + return { success: false, error: 'Failed to add tool', errorCode: 'internal' } + } mcpPubSub?.publishWorkflowToolsChanged({ serverId: params.serverId, @@ -436,6 +842,16 @@ export async function performCreateWorkflowMcpTool( return { success: true, tool } } catch (error) { + if (error instanceof WorkflowMcpExpectedError) { + return { success: false, error: error.message, errorCode: error.errorCode } + } + if (isWorkflowMcpServerLockTimeout(error)) { + return { + success: false, + error: 'Workflow MCP server is busy, retry shortly', + errorCode: 'conflict', + } + } logger.error('Failed to create workflow MCP tool', { error }) return { success: false, error: 'Failed to add tool', errorCode: 'internal' } } @@ -465,20 +881,120 @@ export async function performUpdateWorkflowMcpTool( updateData.toolDescription = params.toolDescription?.trim() || null } if (params.parameterSchema !== undefined) updateData.parameterSchema = params.parameterSchema - const updatedFields = Object.keys(updateData).filter((key) => key !== 'updatedAt') - const [tool] = await db - .update(workflowMcpTool) - .set(updateData) - .where( - and( - eq(workflowMcpTool.id, params.toolId), - eq(workflowMcpTool.serverId, params.serverId), - isNull(workflowMcpTool.archivedAt) + const tool = await db.transaction(async (tx) => { + await acquireWorkflowMcpServerLock(tx, params.serverId) + + const [currentTool] = await tx + .select({ + id: workflowMcpTool.id, + toolName: workflowMcpTool.toolName, + toolDescription: workflowMcpTool.toolDescription, + parameterSchemaBytes: sql`octet_length(${workflowMcpTool.parameterSchema}::text)`, + }) + .from(workflowMcpTool) + .where( + and( + eq(workflowMcpTool.id, params.toolId), + eq(workflowMcpTool.serverId, params.serverId), + isNull(workflowMcpTool.archivedAt) + ) ) + .limit(1) + + if (!currentTool) { + throw new WorkflowMcpExpectedError('Tool not found', 'not_found') + } + + const effectiveToolName = updateData.toolName ?? currentTool.toolName + const effectiveToolDescription = + updateData.toolDescription !== undefined + ? updateData.toolDescription + : currentTool.toolDescription + const effectiveParameterSchema = + updateData.parameterSchema !== undefined ? updateData.parameterSchema : undefined + const metadataLimitError = validateMcpToolMetadataForStorage({ + toolName: effectiveToolName, + toolDescription: effectiveToolDescription, + ...(effectiveParameterSchema !== undefined && { + parameterSchema: effectiveParameterSchema, + }), + }) + if (metadataLimitError) { + throw new WorkflowMcpExpectedError(metadataLimitError, 'validation') + } + + if (params.toolName !== undefined && effectiveToolName !== currentTool.toolName) { + const [nameCollision] = await tx + .select({ id: workflowMcpTool.id }) + .from(workflowMcpTool) + .where( + and( + eq(workflowMcpTool.serverId, params.serverId), + eq(workflowMcpTool.toolName, effectiveToolName), + ne(workflowMcpTool.id, params.toolId), + isNull(workflowMcpTool.archivedAt) + ) + ) + .limit(1) + + if (nameCollision) { + throw new WorkflowMcpExpectedError( + `MCP tool name already exists on this server: ${effectiveToolName}`, + 'conflict' + ) + } + } + + const baseUsage = getMcpToolMetadataUsageFromRows( + await getMcpServerToolMetadataUsageRows(tx, params.serverId, params.toolId) ) - .returning() + const currentUsage = addMcpToolMetadataUsageRow(baseUsage, { + id: currentTool.id, + ...getMcpToolMetadataSizes({ + toolName: currentTool.toolName, + toolDescription: currentTool.toolDescription, + }), + parameterSchemaBytes: Number(currentTool.parameterSchemaBytes) || 0, + }) + const proposedUsage = addMcpToolMetadataUsageRow( + baseUsage, + effectiveParameterSchema !== undefined + ? createMcpToolMetadataUsageRow({ + id: currentTool.id, + toolName: effectiveToolName, + toolDescription: effectiveToolDescription, + parameterSchema: effectiveParameterSchema, + }) + : { + id: currentTool.id, + ...getMcpToolMetadataSizes({ + toolName: effectiveToolName, + toolDescription: effectiveToolDescription, + }), + parameterSchemaBytes: Number(currentTool.parameterSchemaBytes) || 0, + } + ) + const budgetError = validateServerToolMetadataBudgetForUpdate(currentUsage, proposedUsage) + if (budgetError) { + throw new WorkflowMcpExpectedError(budgetError, 'validation') + } + + const [updatedTool] = await tx + .update(workflowMcpTool) + .set(updateData) + .where( + and( + eq(workflowMcpTool.id, params.toolId), + eq(workflowMcpTool.serverId, params.serverId), + isNull(workflowMcpTool.archivedAt) + ) + ) + .returning() + + return updatedTool + }) if (!tool) return { success: false, error: 'Tool not found', errorCode: 'not_found' } @@ -506,6 +1022,16 @@ export async function performUpdateWorkflowMcpTool( return { success: true, tool } } catch (error) { + if (error instanceof WorkflowMcpExpectedError) { + return { success: false, error: error.message, errorCode: error.errorCode } + } + if (isWorkflowMcpServerLockTimeout(error)) { + return { + success: false, + error: 'Workflow MCP server is busy, retry shortly', + errorCode: 'conflict', + } + } logger.error('Failed to update workflow MCP tool', { error }) return { success: false, error: 'Failed to update tool', errorCode: 'internal' } } @@ -529,12 +1055,18 @@ export async function performDeleteWorkflowMcpTool( if (!server) return { success: false, error: 'Server not found', errorCode: 'not_found' } - const [tool] = await db - .delete(workflowMcpTool) - .where( - and(eq(workflowMcpTool.id, params.toolId), eq(workflowMcpTool.serverId, params.serverId)) - ) - .returning() + const tool = await db.transaction(async (tx) => { + await acquireWorkflowMcpServerLock(tx, params.serverId) + + const [deletedTool] = await tx + .delete(workflowMcpTool) + .where( + and(eq(workflowMcpTool.id, params.toolId), eq(workflowMcpTool.serverId, params.serverId)) + ) + .returning() + + return deletedTool + }) if (!tool) return { success: false, error: 'Tool not found', errorCode: 'not_found' } @@ -557,6 +1089,13 @@ export async function performDeleteWorkflowMcpTool( return { success: true, tool } } catch (error) { + if (isWorkflowMcpServerLockTimeout(error)) { + return { + success: false, + error: 'Workflow MCP server is busy, retry shortly', + errorCode: 'conflict', + } + } logger.error('Failed to delete workflow MCP tool', { error }) return { success: false, error: 'Failed to remove tool', errorCode: 'internal' } } diff --git a/apps/sim/lib/mcp/server-locks.test.ts b/apps/sim/lib/mcp/server-locks.test.ts new file mode 100644 index 00000000000..8b9572a2fe8 --- /dev/null +++ b/apps/sim/lib/mcp/server-locks.test.ts @@ -0,0 +1,15 @@ +/** + * @vitest-environment node + */ +import { describe, expect, it } from 'vitest' +import { isWorkflowMcpServerLockTimeout } from '@/lib/mcp/server-locks' + +describe('MCP server locks', () => { + it('detects Postgres lock timeout errors', () => { + const error = Object.assign(new Error('canceling statement due to lock timeout'), { + code: '55P03', + }) + + expect(isWorkflowMcpServerLockTimeout(error)).toBe(true) + }) +}) diff --git a/apps/sim/lib/mcp/server-locks.ts b/apps/sim/lib/mcp/server-locks.ts new file mode 100644 index 00000000000..02811294699 --- /dev/null +++ b/apps/sim/lib/mcp/server-locks.ts @@ -0,0 +1,21 @@ +import { getPostgresErrorCode } from '@sim/utils/errors' +import { sql } from 'drizzle-orm' +import type { DbOrTx } from '@/lib/db/types' + +const MCP_SERVER_LOCK_TIMEOUT_MS = 3_000 +const LOCK_NOT_AVAILABLE_SQLSTATE = '55P03' + +export async function setWorkflowMcpTransactionLockTimeout(tx: DbOrTx): Promise { + await tx.execute( + sql`select set_config('lock_timeout', ${`${MCP_SERVER_LOCK_TIMEOUT_MS}ms`}, true)` + ) +} + +export async function acquireWorkflowMcpServerLock(tx: DbOrTx, serverId: string): Promise { + await setWorkflowMcpTransactionLockTimeout(tx) + await tx.execute(sql`select pg_advisory_xact_lock(hashtextextended(${serverId}, 0))`) +} + +export function isWorkflowMcpServerLockTimeout(error: unknown): boolean { + return getPostgresErrorCode(error) === LOCK_NOT_AVAILABLE_SQLSTATE +} diff --git a/apps/sim/lib/mcp/tool-limits.ts b/apps/sim/lib/mcp/tool-limits.ts new file mode 100644 index 00000000000..3f95d69d580 --- /dev/null +++ b/apps/sim/lib/mcp/tool-limits.ts @@ -0,0 +1,196 @@ +import { workflowMcpTool } from '@sim/db' +import { and, eq, isNull, ne, sql } from 'drizzle-orm' +import type { DbOrTx } from '@/lib/db/types' +import { + MAX_MCP_PARAMETER_SCHEMA_BYTES, + MAX_MCP_SERVER_PARAMETER_SCHEMAS_BYTES, + MAX_MCP_SERVER_TOOLS_METADATA_BYTES, + MAX_MCP_TOOL_DESCRIPTION_BYTES, + MAX_MCP_TOOL_NAME_BYTES, +} from '@/lib/mcp/constants' + +function utf8Size(value: string): number { + return Buffer.byteLength(value, 'utf-8') +} + +function jsonSize(value: unknown): number | null { + try { + const json = JSON.stringify(value) + return typeof json === 'string' ? utf8Size(json) : null + } catch { + return null + } +} + +export interface McpToolMetadataSizes { + toolNameBytes: number + toolDescriptionBytes: number + parameterSchemaBytes: number +} + +export interface McpToolMetadataUsage { + schemaBytes: number + metadataBytes: number +} + +export interface McpToolMetadataUsageRow extends McpToolMetadataSizes { + id: string +} + +export function getMcpToolMetadataSizes(metadata: { + toolName?: string | null + toolDescription?: string | null + parameterSchema?: unknown +}): McpToolMetadataSizes { + return { + toolNameBytes: metadata.toolName ? utf8Size(metadata.toolName) : 0, + toolDescriptionBytes: metadata.toolDescription ? utf8Size(metadata.toolDescription) : 0, + parameterSchemaBytes: + metadata.parameterSchema !== undefined + ? (jsonSize(metadata.parameterSchema) ?? MAX_MCP_PARAMETER_SCHEMA_BYTES + 1) + : 0, + } +} + +export function addMcpToolMetadataUsage( + usage: McpToolMetadataUsage, + tool: { + toolName?: string | null + toolDescription?: string | null + parameterSchema?: unknown + } +): McpToolMetadataUsage { + const sizes = getMcpToolMetadataSizes(tool) + return { + schemaBytes: usage.schemaBytes + sizes.parameterSchemaBytes, + metadataBytes: + usage.metadataBytes + + sizes.toolNameBytes + + sizes.toolDescriptionBytes + + sizes.parameterSchemaBytes, + } +} + +export function addMcpToolMetadataUsageRow( + usage: McpToolMetadataUsage, + row: McpToolMetadataUsageRow +): McpToolMetadataUsage { + return { + schemaBytes: usage.schemaBytes + row.parameterSchemaBytes, + metadataBytes: + usage.metadataBytes + row.toolNameBytes + row.toolDescriptionBytes + row.parameterSchemaBytes, + } +} + +export function subtractMcpToolMetadataUsageRow( + usage: McpToolMetadataUsage, + row?: McpToolMetadataUsageRow +): McpToolMetadataUsage { + if (!row) return usage + return { + schemaBytes: usage.schemaBytes - row.parameterSchemaBytes, + metadataBytes: + usage.metadataBytes - row.toolNameBytes - row.toolDescriptionBytes - row.parameterSchemaBytes, + } +} + +export function getMcpToolMetadataUsageFromRows( + rows: McpToolMetadataUsageRow[] +): McpToolMetadataUsage { + return rows.reduce(addMcpToolMetadataUsageRow, { schemaBytes: 0, metadataBytes: 0 }) +} + +export function createMcpToolMetadataUsageRow(tool: { + id: string + toolName: string + toolDescription: string | null + parameterSchema: unknown +}): McpToolMetadataUsageRow { + return { id: tool.id, ...getMcpToolMetadataSizes(tool) } +} + +export function validateMcpServerToolMetadataBudget(usage: McpToolMetadataUsage): string | null { + if (usage.schemaBytes > MAX_MCP_SERVER_PARAMETER_SCHEMAS_BYTES) { + return `MCP server tool schemas exceed maximum size of ${MAX_MCP_SERVER_PARAMETER_SCHEMAS_BYTES} bytes` + } + if (usage.metadataBytes > MAX_MCP_SERVER_TOOLS_METADATA_BYTES) { + return `MCP server tool metadata exceeds maximum size of ${MAX_MCP_SERVER_TOOLS_METADATA_BYTES} bytes` + } + return null +} + +export function exceedsMcpServerToolMetadataBudget( + usage: McpToolMetadataUsage, + tool: { toolName: string; toolDescription: string | null; parameterSchema: unknown } +): boolean { + return validateMcpServerToolMetadataBudget(addMcpToolMetadataUsage(usage, tool)) !== null +} + +export async function getMcpServerToolMetadataUsageRows( + tx: DbOrTx, + serverId: string, + excludeToolId?: string +): Promise { + const rows = await tx + .select({ + id: workflowMcpTool.id, + toolNameBytes: sql`octet_length(${workflowMcpTool.toolName})`, + toolDescriptionBytes: sql`coalesce(octet_length(${workflowMcpTool.toolDescription}), 0)`, + parameterSchemaBytes: sql`octet_length(${workflowMcpTool.parameterSchema}::text)`, + }) + .from(workflowMcpTool) + .where( + and( + eq(workflowMcpTool.serverId, serverId), + isNull(workflowMcpTool.archivedAt), + excludeToolId ? ne(workflowMcpTool.id, excludeToolId) : undefined + ) + ) + + return rows.map((row) => ({ + id: row.id, + toolNameBytes: Number(row.toolNameBytes) || 0, + toolDescriptionBytes: Number(row.toolDescriptionBytes) || 0, + parameterSchemaBytes: Number(row.parameterSchemaBytes) || 0, + })) +} + +export function getMcpToolDescriptionForStorage( + description: string | null | undefined, + workflowName: string +): string { + const trimmed = description?.trim() + if (trimmed && utf8Size(trimmed) <= MAX_MCP_TOOL_DESCRIPTION_BYTES) { + return trimmed + } + return `Execute ${workflowName} workflow` +} + +export function validateMcpToolMetadataForStorage(metadata: { + toolName?: string | null + toolDescription?: string | null + parameterSchema?: unknown +}): string | null { + if (metadata.toolName && utf8Size(metadata.toolName) > MAX_MCP_TOOL_NAME_BYTES) { + return `Tool name exceeds maximum size of ${MAX_MCP_TOOL_NAME_BYTES} bytes` + } + + if ( + metadata.toolDescription && + utf8Size(metadata.toolDescription) > MAX_MCP_TOOL_DESCRIPTION_BYTES + ) { + return `Tool description exceeds maximum size of ${MAX_MCP_TOOL_DESCRIPTION_BYTES} bytes` + } + + if (metadata.parameterSchema !== undefined) { + const parameterSchemaBytes = jsonSize(metadata.parameterSchema) + if (parameterSchemaBytes === null) { + return 'Tool parameter schema must be JSON serializable' + } + if (parameterSchemaBytes > MAX_MCP_PARAMETER_SCHEMA_BYTES) { + return `Tool parameter schema exceeds maximum size of ${MAX_MCP_PARAMETER_SCHEMA_BYTES} bytes` + } + } + + return null +} diff --git a/apps/sim/lib/mcp/workflow-mcp-sync.ts b/apps/sim/lib/mcp/workflow-mcp-sync.ts index 06b6e28e6a4..97c7032dac3 100644 --- a/apps/sim/lib/mcp/workflow-mcp-sync.ts +++ b/apps/sim/lib/mcp/workflow-mcp-sync.ts @@ -1,7 +1,20 @@ import { db, workflowMcpServer, workflowMcpTool } from '@sim/db' import { createLogger } from '@sim/logger' -import { and, eq, inArray, isNull } from 'drizzle-orm' +import { and, asc, eq, gt, inArray, isNull } from 'drizzle-orm' import type { DbOrTx } from '@/lib/db/types' +import { MAX_MCP_SERVERS_PER_WORKFLOW } from '@/lib/mcp/constants' +import { acquireWorkflowMcpServerLock } from '@/lib/mcp/server-locks' +import { + addMcpToolMetadataUsageRow, + createMcpToolMetadataUsageRow, + exceedsMcpServerToolMetadataBudget, + getMcpServerToolMetadataUsageRows, + getMcpToolMetadataUsageFromRows, + type McpToolMetadataUsage, + type McpToolMetadataUsageRow, + subtractMcpToolMetadataUsageRow, + validateMcpToolMetadataForStorage, +} from '@/lib/mcp/tool-limits' import { loadDeployedWorkflowState } from '@/lib/workflows/persistence/utils' import { hasValidStartBlockInState } from '@/lib/workflows/triggers/trigger-utils' import type { WorkflowState } from '@/stores/workflows/workflow/types' @@ -11,6 +24,82 @@ import { extractInputFormatFromBlocks, generateToolInputSchema } from './workflo const logger = createLogger('WorkflowMcpSync') const EMPTY_SCHEMA: Record = Object.freeze({ type: 'object', properties: {} }) +const MCP_SYNC_TOOLS_PAGE_SIZE = 100 + +class WorkflowMcpServerFanoutError extends Error { + constructor(workflowId: string) { + super( + `Workflow ${workflowId} is exposed on more than ${MAX_MCP_SERVERS_PER_WORKFLOW} MCP servers` + ) + this.name = 'WorkflowMcpServerFanoutError' + } +} + +interface WorkflowMcpToolSyncRow { + id: string + serverId: string + toolName: string + toolDescription: string | null +} + +interface ServerMetadataUsageState { + usageByToolId: Map + serverUsage: McpToolMetadataUsage +} + +async function listWorkflowMcpToolSyncPage( + tx: DbOrTx, + workflowId: string, + afterToolId?: string, + serverIds?: string[] +): Promise { + return tx + .select({ + id: workflowMcpTool.id, + serverId: workflowMcpTool.serverId, + toolName: workflowMcpTool.toolName, + toolDescription: workflowMcpTool.toolDescription, + }) + .from(workflowMcpTool) + .where( + and( + eq(workflowMcpTool.workflowId, workflowId), + isNull(workflowMcpTool.archivedAt), + serverIds && serverIds.length > 0 + ? inArray(workflowMcpTool.serverId, serverIds) + : undefined, + afterToolId ? gt(workflowMcpTool.id, afterToolId) : undefined + ) + ) + .orderBy(asc(workflowMcpTool.id)) + .limit(MCP_SYNC_TOOLS_PAGE_SIZE + 1) +} + +async function collectWorkflowMcpToolServerIds( + tx: DbOrTx, + workflowId: string +): Promise> { + const serverIds = new Set() + let afterToolId: string | undefined + + while (true) { + const page = await listWorkflowMcpToolSyncPage(tx, workflowId, afterToolId) + if (page.length === 0) break + + const pageTools = page.slice(0, MCP_SYNC_TOOLS_PAGE_SIZE) + for (const tool of pageTools) { + serverIds.add(tool.serverId) + if (serverIds.size > MAX_MCP_SERVERS_PER_WORKFLOW) { + throw new WorkflowMcpServerFanoutError(workflowId) + } + } + + if (page.length <= MCP_SYNC_TOOLS_PAGE_SIZE) break + afterToolId = pageTools.at(-1)?.id + } + + return [...serverIds].sort().map((serverId) => ({ serverId })) +} /** * Generate MCP tool parameter schema from workflow blocks. @@ -25,19 +114,14 @@ export function generateSchemaFromBlocks(blocks: Record): Recor /** * Load a workflow's active deployed state and generate its MCP parameter schema. - * Returns a proper JSON Schema derived from the start block's input format, - * or a fallback empty schema if the workflow has no inputs or no active deployment. + * Workflows with no inputs or no active deployment use an empty object schema. */ export async function generateParameterSchemaForWorkflow( workflowId: string ): Promise> { - try { - const deployed = await loadDeployedWorkflowState(workflowId) - if (!deployed?.blocks) return EMPTY_SCHEMA - return generateSchemaFromBlocks(deployed.blocks as Record) - } catch { - return EMPTY_SCHEMA - } + const deployed = await loadDeployedWorkflowState(workflowId) + if (!deployed?.blocks) return EMPTY_SCHEMA + return generateSchemaFromBlocks(deployed.blocks as Record) } interface SyncOptions { @@ -65,58 +149,153 @@ interface SyncOptions { export async function syncMcpToolsForWorkflow( options: SyncOptions ): Promise> { + if (!options.tx) { + const tools = await db.transaction((tx) => + syncMcpToolsForWorkflow({ ...options, tx, notify: false }) + ) + if (options.notify ?? true) notifyMcpToolServers(tools) + return tools + } + const { workflowId, requestId, state, context = 'sync', - tx = db, + tx, notify = true, throwOnError = false, } = options try { - const tools = await tx - .select({ id: workflowMcpTool.id, serverId: workflowMcpTool.serverId }) - .from(workflowMcpTool) - .where(and(eq(workflowMcpTool.workflowId, workflowId), isNull(workflowMcpTool.archivedAt))) - - if (tools.length === 0) { - return [] - } - let workflowState: { blocks?: Record } | null = state ?? null if (!workflowState) { workflowState = await loadDeployedWorkflowState(workflowId) } if (!hasValidStartBlockInState(workflowState as WorkflowState | null)) { - await tx.delete(workflowMcpTool).where(eq(workflowMcpTool.workflowId, workflowId)) - logger.info( - `[${requestId}] Removed ${tools.length} MCP tool(s) - workflow has no start block (${context}): ${workflowId}` - ) - if (notify) notifyMcpToolServers(tools) - return tools + const affectedTools = await removeMcpToolsForWorkflow(workflowId, requestId, tx, false, true) + if (notify) notifyMcpToolServers(affectedTools) + return affectedTools } - const parameterSchema = workflowState?.blocks + const generatedParameterSchema = workflowState?.blocks ? generateSchemaFromBlocks(workflowState.blocks) : EMPTY_SCHEMA + const schemaLimitError = validateMcpToolMetadataForStorage({ + parameterSchema: generatedParameterSchema, + }) + if (schemaLimitError) { + throw new Error(schemaLimitError) + } + const parameterSchema = generatedParameterSchema + + const affectedServerIds = new Set() + const lockedServers = await collectWorkflowMcpToolServerIds(tx, workflowId) + if (lockedServers.length === 0) return [] + + for (const { serverId } of lockedServers) { + await acquireWorkflowMcpServerLock(tx, serverId) + affectedServerIds.add(serverId) + } + const lockedServerIds = [...affectedServerIds] - await tx - .update(workflowMcpTool) - .set({ - parameterSchema, - updatedAt: new Date(), + const usageStateByServer = new Map() + for (const { serverId } of lockedServers) { + const rows = await getMcpServerToolMetadataUsageRows(tx, serverId) + usageStateByServer.set(serverId, { + usageByToolId: new Map(rows.map((row) => [row.id, row])), + serverUsage: getMcpToolMetadataUsageFromRows(rows), }) - .where(and(eq(workflowMcpTool.workflowId, workflowId), isNull(workflowMcpTool.archivedAt))) + } + + let syncedToolCount = 0 + let afterToolId: string | undefined + + while (true) { + const page = await listWorkflowMcpToolSyncPage(tx, workflowId, afterToolId, lockedServerIds) + if (page.length === 0) break + + const pageTools = page.slice(0, MCP_SYNC_TOOLS_PAGE_SIZE) + const toolsByServer = new Map() + for (const tool of pageTools) { + affectedServerIds.add(tool.serverId) + const serverTools = toolsByServer.get(tool.serverId) ?? [] + serverTools.push(tool) + toolsByServer.set(tool.serverId, serverTools) + } + + for (const [serverId, serverTools] of [...toolsByServer].sort(([left], [right]) => + left.localeCompare(right) + )) { + const usageState = usageStateByServer.get(serverId) + if (!usageState) { + throw new Error(`Missing locked MCP server usage state for server ${serverId}`) + } + const schemaToolIds: string[] = [] + const emptySchemaToolIds: string[] = [] + + for (const tool of serverTools) { + const existingUsage = subtractMcpToolMetadataUsageRow( + usageState.serverUsage, + usageState.usageByToolId.get(tool.id) + ) + const shouldUseEmptySchema = exceedsMcpServerToolMetadataBudget(existingUsage, { + toolName: tool.toolName, + toolDescription: tool.toolDescription, + parameterSchema, + }) + const schemaForTool = shouldUseEmptySchema ? EMPTY_SCHEMA : parameterSchema + + if (shouldUseEmptySchema) { + emptySchemaToolIds.push(tool.id) + } else { + schemaToolIds.push(tool.id) + } + + const updatedUsageRow = createMcpToolMetadataUsageRow({ + id: tool.id, + toolName: tool.toolName, + toolDescription: tool.toolDescription, + parameterSchema: schemaForTool, + }) + usageState.usageByToolId.set(tool.id, updatedUsageRow) + usageState.serverUsage = addMcpToolMetadataUsageRow(existingUsage, updatedUsageRow) + } + + if (schemaToolIds.length > 0) { + await tx + .update(workflowMcpTool) + .set({ + parameterSchema, + updatedAt: new Date(), + }) + .where(inArray(workflowMcpTool.id, schemaToolIds)) + } + + if (emptySchemaToolIds.length > 0) { + await tx + .update(workflowMcpTool) + .set({ + parameterSchema: EMPTY_SCHEMA, + updatedAt: new Date(), + }) + .where(inArray(workflowMcpTool.id, emptySchemaToolIds)) + } + } + + syncedToolCount += pageTools.length + if (page.length <= MCP_SYNC_TOOLS_PAGE_SIZE) break + afterToolId = pageTools.at(-1)?.id + } logger.info( - `[${requestId}] Synced ${tools.length} MCP tool(s) for workflow (${context}): ${workflowId}` + `[${requestId}] Synced ${syncedToolCount} MCP tool(s) for workflow (${context}): ${workflowId}` ) - if (notify) notifyMcpToolServers(tools) - return tools + const affectedTools = [...affectedServerIds].map((serverId) => ({ serverId })) + if (notify) notifyMcpToolServers(affectedTools) + return affectedTools } catch (error) { logger.error(`[${requestId}] Error syncing MCP tools (${context}):`, error) if (throwOnError) throw error @@ -131,18 +310,27 @@ export async function syncMcpToolsForWorkflow( export async function removeMcpToolsForWorkflow( workflowId: string, requestId: string, - tx: DbOrTx = db, + tx?: DbOrTx, notify = true, throwOnError = false ): Promise> { + if (!tx) { + const tools = await db.transaction((transaction) => + removeMcpToolsForWorkflow(workflowId, requestId, transaction, false, throwOnError) + ) + if (notify) notifyMcpToolServers(tools) + return tools + } + try { - const tools = await tx - .select({ id: workflowMcpTool.id, serverId: workflowMcpTool.serverId }) - .from(workflowMcpTool) - .where(and(eq(workflowMcpTool.workflowId, workflowId), isNull(workflowMcpTool.archivedAt))) + const tools = await collectWorkflowMcpToolServerIds(tx, workflowId) if (tools.length === 0) return [] + for (const { serverId } of tools) { + await acquireWorkflowMcpServerLock(tx, serverId) + } + await tx.delete(workflowMcpTool).where(eq(workflowMcpTool.workflowId, workflowId)) logger.info(`[${requestId}] Removed MCP tools for workflow: ${workflowId}`) diff --git a/apps/sim/lib/workflows/deployment-outbox.ts b/apps/sim/lib/workflows/deployment-outbox.ts index fc7711b5892..b33003f7a47 100644 --- a/apps/sim/lib/workflows/deployment-outbox.ts +++ b/apps/sim/lib/workflows/deployment-outbox.ts @@ -10,6 +10,7 @@ import { } from '@/lib/core/outbox/service' import { generateRequestId } from '@/lib/core/utils/request' import { getBaseUrl } from '@/lib/core/utils/urls' +import { setWorkflowMcpTransactionLockTimeout } from '@/lib/mcp/server-locks' import { notifyMcpToolServers, removeMcpToolsForWorkflow, @@ -450,12 +451,14 @@ async function removeMcpToolsIfStillUndeployed( requestId: string ): Promise { const tools = await db.transaction(async (tx) => { + await setWorkflowMcpTransactionLockTimeout(tx) + const [workflowRecord] = await tx .select({ id: workflowTable.id, isDeployed: workflowTable.isDeployed }) .from(workflowTable) .where(eq(workflowTable.id, workflowId)) - .limit(1) .for('update') + .limit(1) if (!workflowRecord || workflowRecord.isDeployed) return [] return removeMcpToolsForWorkflow(workflowId, requestId, tx, false, true) @@ -497,12 +500,14 @@ async function syncMcpToolsIfStillActive(params: { state: { blocks?: Record } }): Promise { const tools = await db.transaction(async (tx) => { + await setWorkflowMcpTransactionLockTimeout(tx) + const [workflowRecord] = await tx .select({ id: workflowTable.id }) .from(workflowTable) .where(eq(workflowTable.id, params.workflowId)) - .limit(1) .for('update') + .limit(1) if (!workflowRecord) return [] diff --git a/packages/testing/src/mocks/database.mock.ts b/packages/testing/src/mocks/database.mock.ts index 6abcb8ac341..dead81be44c 100644 --- a/packages/testing/src/mocks/database.mock.ts +++ b/packages/testing/src/mocks/database.mock.ts @@ -97,19 +97,27 @@ export function createMockSqlOperators() { * ``` */ const limit = vi.fn(() => Promise.resolve([] as unknown[])) -const orderBy = vi.fn(() => Promise.resolve([] as unknown[])) const returning = vi.fn(() => Promise.resolve([] as unknown[])) -const groupBy = vi.fn(() => Promise.resolve([] as unknown[])) const execute = vi.fn(() => Promise.resolve([] as unknown[])) -const forBuilder = () => { +const terminalBuilder = () => { const thenable: any = Promise.resolve([] as unknown[]) thenable.limit = limit thenable.orderBy = orderBy thenable.returning = returning thenable.groupBy = groupBy + thenable.for = forClause return thenable } + +const orderBy = vi.fn(terminalBuilder) +const having = vi.fn(terminalBuilder) +const groupBy = vi.fn(() => { + const builder = terminalBuilder() + builder.having = having + return builder +}) +const forBuilder = terminalBuilder const forClause = vi.fn(forBuilder) const onConflictDoUpdate = vi.fn(() => ({ returning }) as unknown as Promise) @@ -162,6 +170,7 @@ export const dbChainMockFns = { innerJoin, leftJoin, groupBy, + having, execute, for: forClause, insert, @@ -199,9 +208,14 @@ export function resetDbChainMock(): void { set.mockImplementation(() => ({ where })) del.mockImplementation(() => ({ where })) limit.mockImplementation(() => Promise.resolve([] as unknown[])) - orderBy.mockImplementation(() => Promise.resolve([] as unknown[])) + orderBy.mockImplementation(terminalBuilder) returning.mockImplementation(() => Promise.resolve([] as unknown[])) - groupBy.mockImplementation(() => Promise.resolve([] as unknown[])) + having.mockImplementation(terminalBuilder) + groupBy.mockImplementation(() => { + const builder = terminalBuilder() + builder.having = having + return builder + }) execute.mockImplementation(() => Promise.resolve([] as unknown[])) forClause.mockImplementation(forBuilder) transaction.mockImplementation(async (cb: (tx: typeof dbChainMock.db) => unknown) =>