diff --git a/packages/db/prisma/migrations/20260529214711_add_mcp_connectors_tables/migration.sql b/packages/db/prisma/migrations/20260529214711_add_mcp_connectors_tables/migration.sql new file mode 100644 index 000000000..fb82482c0 --- /dev/null +++ b/packages/db/prisma/migrations/20260529214711_add_mcp_connectors_tables/migration.sql @@ -0,0 +1,66 @@ +-- CreateEnum +CREATE TYPE "McpServerClientInfoSource" AS ENUM ('DYNAMIC', 'STATIC'); + +-- CreateTable +CREATE TABLE "McpServer" ( + "id" TEXT NOT NULL, + "name" TEXT NOT NULL, + "sanitizedName" TEXT NOT NULL, + "serverUrl" TEXT NOT NULL, + "clientInfo" TEXT, + "clientInfoSource" "McpServerClientInfoSource" NOT NULL DEFAULT 'DYNAMIC', + "orgId" INTEGER NOT NULL, + "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updatedAt" TIMESTAMP(3) NOT NULL, + + CONSTRAINT "McpServer_pkey" PRIMARY KEY ("id") +); + +-- CreateTable +CREATE TABLE "McpServerToolCallCount" ( + "mcpServerId" TEXT NOT NULL, + "toolName" TEXT NOT NULL, + "count" INTEGER NOT NULL DEFAULT 0, + "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updatedAt" TIMESTAMP(3) NOT NULL, + + CONSTRAINT "McpServerToolCallCount_pkey" PRIMARY KEY ("mcpServerId","toolName") +); + +-- CreateTable +CREATE TABLE "UserMcpServer" ( + "userId" TEXT NOT NULL, + "serverId" TEXT NOT NULL, + "tokens" TEXT, + "tokensExpiresAt" TIMESTAMP(3), + "codeVerifier" TEXT, + "state" TEXT, + "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updatedAt" TIMESTAMP(3) NOT NULL, + + CONSTRAINT "UserMcpServer_pkey" PRIMARY KEY ("userId","serverId") +); + +-- CreateIndex +CREATE UNIQUE INDEX "McpServer_serverUrl_orgId_key" ON "McpServer"("serverUrl", "orgId"); + +-- CreateIndex +CREATE UNIQUE INDEX "McpServer_orgId_sanitizedName_key" ON "McpServer"("orgId", "sanitizedName"); + +-- CreateIndex +CREATE INDEX "UserMcpServer_serverId_idx" ON "UserMcpServer"("serverId"); + +-- CreateIndex +CREATE INDEX "UserMcpServer_state_idx" ON "UserMcpServer"("state"); + +-- AddForeignKey +ALTER TABLE "McpServer" ADD CONSTRAINT "McpServer_orgId_fkey" FOREIGN KEY ("orgId") REFERENCES "Org"("id") ON DELETE CASCADE ON UPDATE CASCADE; + +-- AddForeignKey +ALTER TABLE "McpServerToolCallCount" ADD CONSTRAINT "McpServerToolCallCount_mcpServerId_fkey" FOREIGN KEY ("mcpServerId") REFERENCES "McpServer"("id") ON DELETE CASCADE ON UPDATE CASCADE; + +-- AddForeignKey +ALTER TABLE "UserMcpServer" ADD CONSTRAINT "UserMcpServer_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE; + +-- AddForeignKey +ALTER TABLE "UserMcpServer" ADD CONSTRAINT "UserMcpServer_serverId_fkey" FOREIGN KEY ("serverId") REFERENCES "McpServer"("id") ON DELETE CASCADE ON UPDATE CASCADE; diff --git a/packages/db/prisma/schema.prisma b/packages/db/prisma/schema.prisma index be72df31d..e0371e56c 100644 --- a/packages/db/prisma/schema.prisma +++ b/packages/db/prisma/schema.prisma @@ -294,6 +294,8 @@ model Org { chats Chat[] repoVisits RepoVisit[] + mcpServers McpServer[] + license License? } @@ -340,6 +342,11 @@ enum OrgRole { MEMBER } +enum McpServerClientInfoSource { + DYNAMIC + STATIC +} + model UserToOrg { joinedAt DateTime @default(now()) @@ -422,6 +429,8 @@ model User { /// claim baked into the JWT cookie at mint time. sessionVersion Int @default(0) + userMcpServers UserMcpServer[] + createdAt DateTime @default(now()) updatedAt DateTime @updatedAt @@ -656,3 +665,73 @@ model ChangelogEntry { @@index([publishedAt]) } + +/// An external MCP server endpoint, unique per org. +/// Stores the dynamic client registration (client_id/client_secret) once per org. +model McpServer { + id String @id @default(cuid()) + name String /// Org-approved display name (e.g., "Linear") + sanitizedName String /// Stable tool-name prefix (e.g., "linear") + serverUrl String /// MCP server endpoint (e.g., "https://mcp.linear.app/mcp") + + /// Dynamic client registration result (RFC 7591) or admin-provided static OAuth client credentials. + /// Encrypted JSON of OAuthClientInformation: { client_id, client_secret, client_id_issued_at, client_secret_expires_at } + /// Null for DYNAMIC rows until first user in the org triggers registration. + clientInfo String? + clientInfoSource McpServerClientInfoSource @default(DYNAMIC) + + org Org @relation(fields: [orgId], references: [id], onDelete: Cascade) + orgId Int + + userMcpServers UserMcpServer[] + toolCallCounts McpServerToolCallCount[] + + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + + @@unique([serverUrl, orgId]) + @@unique([orgId, sanitizedName]) +} + +/// Lifetime tool call counters for an MCP server. +model McpServerToolCallCount { + mcpServer McpServer @relation(fields: [mcpServerId], references: [id], onDelete: Cascade) + mcpServerId String + toolName String + count Int @default(0) + + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + + @@id([mcpServerId, toolName]) +} + +/// A user's personal connection to an MCP server. +/// Stores per-user OAuth tokens and ephemeral auth-flow state. +model UserMcpServer { + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + userId String + + server McpServer @relation(fields: [serverId], references: [id], onDelete: Cascade) + serverId String + + /// OAuth tokens (access_token, refresh_token, etc.) — encrypted JSON of OAuthTokens. + tokens String? + + /// Absolute expiry time of the access token, computed at issuance from expires_in. + /// Null when no tokens are stored or the provider did not include expires_in. + tokensExpiresAt DateTime? + + /// PKCE code verifier — ephemeral, only used between redirect and callback. + codeVerifier String? + + /// OAuth state parameter — ephemeral, for CSRF protection during auth flow. + state String? + + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + + @@id([userId, serverId]) + @@index([serverId]) + @@index([state]) +} diff --git a/packages/shared/src/entitlements.test.ts b/packages/shared/src/entitlements.test.ts index 906ef516e..35fe4a65d 100644 --- a/packages/shared/src/entitlements.test.ts +++ b/packages/shared/src/entitlements.test.ts @@ -65,6 +65,16 @@ const makeLicense = (overrides: Partial = {}): License => ({ cancelAt: null, trialEnd: null, hasPaymentMethod: null, + yearlyTermStartedAt: null, + yearlyTermEndsAt: null, + yearlyTotalQuartersInTerm: null, + yearlyCurrentQuarterNumber: null, + yearlyCurrentQuarterStartedAt: null, + yearlyCurrentQuarterEndsAt: null, + yearlyCommittedSeats: null, + yearlyOverageSeats: null, + yearlyBillableOverageSeats: null, + yearlyPeakSeats: null, lastSyncAt: new Date(), lastSyncErrorCode: null, createdAt: new Date(), diff --git a/packages/shared/src/env.server.test.ts b/packages/shared/src/env.server.test.ts index bb7c7acc3..7f9bf0bca 100644 --- a/packages/shared/src/env.server.test.ts +++ b/packages/shared/src/env.server.test.ts @@ -54,3 +54,31 @@ describe('PERMISSION_SYNC_ENABLED', () => { expect(env.PERMISSION_SYNC_ENABLED).toBe('false'); }); }); + +describe('SOURCEBOT_MCP_TOOL_CALL_TIMEOUT_MS', () => { + beforeEach(() => { + vi.resetModules(); + delete process.env.SOURCEBOT_MCP_TOOL_CALL_TIMEOUT_MS; + }); + + afterEach(() => { + delete process.env.SOURCEBOT_MCP_TOOL_CALL_TIMEOUT_MS; + }); + + test('defaults to 60000 when not set', async () => { + const { env } = await import('./env.server.js'); + expect(env.SOURCEBOT_MCP_TOOL_CALL_TIMEOUT_MS).toBe(60000); + }); + + test('accepts positive integers', async () => { + process.env.SOURCEBOT_MCP_TOOL_CALL_TIMEOUT_MS = '5000'; + const { env } = await import('./env.server.js'); + expect(env.SOURCEBOT_MCP_TOOL_CALL_TIMEOUT_MS).toBe(5000); + }); + + test.each(['0', '-1', '1.5', '2147483648', String(Number.MAX_SAFE_INTEGER + 1)])('rejects %s', async (timeoutMs) => { + process.env.SOURCEBOT_MCP_TOOL_CALL_TIMEOUT_MS = timeoutMs; + + await expect(import('./env.server.js')).rejects.toThrow(); + }); +}); diff --git a/packages/shared/src/env.server.ts b/packages/shared/src/env.server.ts index d9ee5cae3..cbb596918 100644 --- a/packages/shared/src/env.server.ts +++ b/packages/shared/src/env.server.ts @@ -14,6 +14,7 @@ const booleanSchema = z.enum(["true", "false"]); // coerce helps us convert them to numbers. // @see: https://zod.dev/?id=coercion-for-primitives const numberSchema = z.coerce.number(); +const maxTimerDelayMs = 2_147_483_647; const ajv = new Ajv({ validateFormats: false, @@ -282,6 +283,7 @@ const options = { */ SOURCEBOT_CHAT_MODEL_TEMPERATURE: numberSchema.optional(), SOURCEBOT_CHAT_MAX_STEP_COUNT: numberSchema.default(100), + SOURCEBOT_MCP_TOOL_CALL_TIMEOUT_MS: numberSchema.int().positive().max(maxTimerDelayMs).default(60000), DEBUG_WRITE_CHAT_MESSAGES_TO_FILE: booleanSchema.default('false'), DEBUG_ENABLE_REACT_SCAN: booleanSchema.default('false'), diff --git a/packages/web/package.json b/packages/web/package.json index dfed8625f..cd06162c7 100644 --- a/packages/web/package.json +++ b/packages/web/package.json @@ -20,6 +20,7 @@ "@ai-sdk/deepseek": "^2.0.29", "@ai-sdk/google": "^3.0.64", "@ai-sdk/google-vertex": "^4.0.111", + "@ai-sdk/mcp": "^2.0.0-beta.11", "@ai-sdk/mistral": "^3.0.30", "@ai-sdk/openai": "^3.0.53", "@ai-sdk/openai-compatible": "^2.0.41", @@ -196,7 +197,7 @@ "use-stick-to-bottom": "^1.1.3", "usehooks-ts": "^3.1.0", "vscode-icons-js": "^11.6.1", - "zod": "^3.25.74", + "zod": "^3.25.76", "zod-to-json-schema": "^3.24.5" }, "devDependencies": { diff --git a/packages/web/src/app/(app)/@sidebar/components/settingsSidebar/nav.tsx b/packages/web/src/app/(app)/@sidebar/components/settingsSidebar/nav.tsx index 8eb817e55..11960732a 100644 --- a/packages/web/src/app/(app)/@sidebar/components/settingsSidebar/nav.tsx +++ b/packages/web/src/app/(app)/@sidebar/components/settingsSidebar/nav.tsx @@ -12,22 +12,24 @@ import { import { useEntitlements } from "@/features/entitlements/useEntitlements"; import { Entitlement } from "@sourcebot/shared"; import { + BotIcon, ChartAreaIcon, KeyRoundIcon, LinkIcon, type LucideIcon, PlugIcon, ScrollTextIcon, + ServerIcon, Settings2Icon, ShieldIcon, UserIcon, UsersIcon, } from "lucide-react"; +import { IconType } from "react-icons/lib"; import { VscMcp } from "react-icons/vsc"; import Link from "next/link"; import { usePathname } from "next/navigation"; import { UpgradeBadge } from "../upgradeBadge"; -import { IconType } from "react-icons/lib"; const iconMap = { "link": LinkIcon, @@ -37,9 +39,11 @@ const iconMap = { "plug": PlugIcon, "chart-area": ChartAreaIcon, "scroll-text": ScrollTextIcon, + "server": ServerIcon, "settings": Settings2Icon, "user": UserIcon, "mcp": VscMcp, + "bot": BotIcon, } satisfies Record; export type NavIconName = keyof typeof iconMap; diff --git a/packages/web/src/app/(app)/@sidebar/components/sidebarBase.tsx b/packages/web/src/app/(app)/@sidebar/components/sidebarBase.tsx index 19db523f8..6083c12e4 100644 --- a/packages/web/src/app/(app)/@sidebar/components/sidebarBase.tsx +++ b/packages/web/src/app/(app)/@sidebar/components/sidebarBase.tsx @@ -339,4 +339,4 @@ const AppearanceDropdownMenuGroup = () => { ) -} \ No newline at end of file +} diff --git a/packages/web/src/app/(app)/askgh/[owner]/[repo]/components/landingPage.tsx b/packages/web/src/app/(app)/askgh/[owner]/[repo]/components/landingPage.tsx index 5a8a92abc..fe50bc5ca 100644 --- a/packages/web/src/app/(app)/askgh/[owner]/[repo]/components/landingPage.tsx +++ b/packages/web/src/app/(app)/askgh/[owner]/[repo]/components/landingPage.tsx @@ -68,7 +68,7 @@ export const LandingPage = ({
{ - createNewChatThread(children, selectedSearchScopes); + createNewChatThread(children, selectedSearchScopes, []); }} className="min-h-[50px]" isRedirecting={isLoading} diff --git a/packages/web/src/app/(app)/chat/[id]/components/chatThreadPanel.test.tsx b/packages/web/src/app/(app)/chat/[id]/components/chatThreadPanel.test.tsx new file mode 100644 index 000000000..cc3391dc0 --- /dev/null +++ b/packages/web/src/app/(app)/chat/[id]/components/chatThreadPanel.test.tsx @@ -0,0 +1,89 @@ +import { cleanup, render, waitFor } from '@testing-library/react'; +import { afterEach, beforeEach, describe, expect, test, vi } from 'vitest'; +import { SET_CHAT_STATE_SESSION_STORAGE_KEY } from '@/features/chat/constants'; +import { ChatThreadPanel } from './chatThreadPanel'; + +const { chatThreadProps } = vi.hoisted(() => ({ + chatThreadProps: [] as Array<{ disabledMcpServerIds?: unknown }>, +})); + +vi.mock('next/navigation', () => ({ + useParams: () => ({ id: 'chat-1' }), +})); + +vi.mock('@/features/chat/components/chatThread', () => ({ + ChatThread: (props: { disabledMcpServerIds?: unknown }) => { + chatThreadProps.push(props); + return
; + }, +})); + +function createMockStorage(): Storage { + const store = new Map(); + + return { + get length() { + return store.size; + }, + clear: () => store.clear(), + getItem: (key: string) => store.get(key) ?? null, + key: (index: number) => Array.from(store.keys())[index] ?? null, + removeItem: (key: string) => { + store.delete(key); + }, + setItem: (key: string, value: string) => { + store.set(key, value); + }, + }; +} + +function installMockStorage(key: 'localStorage' | 'sessionStorage') { + const storage = createMockStorage(); + Object.defineProperty(window, key, { + configurable: true, + value: storage, + }); + Object.defineProperty(globalThis, key, { + configurable: true, + value: storage, + }); +} + +describe('ChatThreadPanel', () => { + beforeEach(() => { + installMockStorage('localStorage'); + installMockStorage('sessionStorage'); + chatThreadProps.length = 0; + sessionStorage.clear(); + }); + + afterEach(() => { + cleanup(); + sessionStorage.clear(); + }); + + test('defaults restored disabled MCP server ids to an empty array when missing from session storage', async () => { + sessionStorage.setItem(SET_CHAT_STATE_SESSION_STORAGE_KEY, JSON.stringify({ + inputMessage: { + role: 'user', + parts: [{ type: 'text', text: 'hello' }], + }, + selectedSearchScopes: [], + })); + + render( + + ); + + await waitFor(() => expect(chatThreadProps.length).toBeGreaterThan(1)); + + expect(chatThreadProps.at(-1)?.disabledMcpServerIds).toEqual([]); + }); +}); diff --git a/packages/web/src/app/(app)/chat/[id]/components/chatThreadPanel.tsx b/packages/web/src/app/(app)/chat/[id]/components/chatThreadPanel.tsx index cd1d16b2f..33808b486 100644 --- a/packages/web/src/app/(app)/chat/[id]/components/chatThreadPanel.tsx +++ b/packages/web/src/app/(app)/chat/[id]/components/chatThreadPanel.tsx @@ -42,11 +42,13 @@ export const ChatThreadPanel = ({ localStorage.removeItem(SELECTED_SEARCH_SCOPES_LOCAL_STORAGE_KEY); }, []); - // Use the last user's last message to determine what repos and contexts we should select by default. + // Use the last user message to determine what repos, contexts, and MCP state we should select by default. const lastUserMessage = messages.findLast((message) => message.role === "user"); const defaultSelectedSearchScopes = lastUserMessage?.metadata?.selectedSearchScopes ?? []; + const defaultDisabledMcpServerIds = lastUserMessage?.metadata?.disabledMcpServerIds ?? []; const [selectedSearchScopes, setSelectedSearchScopes] = useState(defaultSelectedSearchScopes); - + const [disabledMcpServerIds, setDisabledMcpServerIds] = useState(defaultDisabledMcpServerIds); + useEffect(() => { if (!chatState) { return; @@ -55,6 +57,7 @@ export const ChatThreadPanel = ({ try { setInputMessage(chatState.inputMessage); setSelectedSearchScopes(chatState.selectedSearchScopes); + setDisabledMcpServerIds(chatState.disabledMcpServerIds); } catch { console.error('Invalid chat state in session storage'); } finally { @@ -74,6 +77,8 @@ export const ChatThreadPanel = ({ searchContexts={searchContexts} selectedSearchScopes={selectedSearchScopes} onSelectedSearchScopesChange={setSelectedSearchScopes} + disabledMcpServerIds={disabledMcpServerIds} + onDisabledMcpServerIdsChange={setDisabledMcpServerIds} isOwner={isOwner} isAuthenticated={isAuthenticated} isLoginWallEnabled={isLoginWallEnabled} @@ -81,4 +86,4 @@ export const ChatThreadPanel = ({ />
) -} \ No newline at end of file +} diff --git a/packages/web/src/app/(app)/chat/components/landingPageChatBox.tsx b/packages/web/src/app/(app)/chat/components/landingPageChatBox.tsx index d33d2e5b4..55b2d56d5 100644 --- a/packages/web/src/app/(app)/chat/components/landingPageChatBox.tsx +++ b/packages/web/src/app/(app)/chat/components/landingPageChatBox.tsx @@ -8,7 +8,7 @@ import { useCreateNewChatThread } from "@/features/chat/useCreateNewChatThread"; import { RepositoryQuery, SearchContextQuery } from "@/lib/types"; import { useState } from "react"; import { useLocalStorage } from "usehooks-ts"; -import { SELECTED_SEARCH_SCOPES_LOCAL_STORAGE_KEY } from "@/features/chat/constants"; +import { DISABLED_MCP_SERVER_IDS_LOCAL_STORAGE_KEY, SELECTED_SEARCH_SCOPES_LOCAL_STORAGE_KEY } from "@/features/chat/constants"; import { SearchModeSelector } from "../../components/searchModeSelector"; import { NotConfiguredErrorBanner } from "@/features/chat/components/notConfiguredErrorBanner"; @@ -29,6 +29,7 @@ export const LandingPageChatBox = ({ }: LandingPageChatBox) => { const { createNewChatThread, isLoading } = useCreateNewChatThread(); const [selectedSearchScopes, setSelectedSearchScopes] = useLocalStorage(SELECTED_SEARCH_SCOPES_LOCAL_STORAGE_KEY, [], { initializeWithValue: false }); + const [disabledMcpServerIds, setDisabledMcpServerIds] = useLocalStorage(DISABLED_MCP_SERVER_IDS_LOCAL_STORAGE_KEY, [], { initializeWithValue: false }); const [isContextSelectorOpen, setIsContextSelectorOpen] = useState(false); const isChatBoxDisabled = languageModels.length === 0; @@ -37,7 +38,7 @@ export const LandingPageChatBox = ({
{ - createNewChatThread(children); + createNewChatThread(children, selectedSearchScopes, disabledMcpServerIds); }} className="min-h-[50px]" isRedirecting={isLoading} @@ -59,6 +60,8 @@ export const LandingPageChatBox = ({ onSelectedSearchScopesChange={setSelectedSearchScopes} isContextSelectorOpen={isContextSelectorOpen} onContextSelectorOpenChanged={setIsContextSelectorOpen} + disabledMcpServerIds={disabledMcpServerIds} + onDisabledMcpServerIdsChange={setDisabledMcpServerIds} /> { + if (didHandleStatusRef.current) { + return; + } + + const status = searchParams.get('status'); + if (status !== 'connected' && status !== 'error') { + return; + } + + didHandleStatusRef.current = true; + const server = searchParams.get('server'); + const message = searchParams.get('message'); + + if (status === 'connected') { + toast({ description: `Successfully connected${server ? ` to ${server}` : ''}.` }); + } else { + toast({ + title: "Connection failed", + description: message ?? 'Failed to connect connector.', + variant: "destructive", + }); + } + + const nextSearchParams = new URLSearchParams(searchParams.toString()); + nextSearchParams.delete('status'); + nextSearchParams.delete('server'); + nextSearchParams.delete('message'); + + const query = nextSearchParams.toString(); + router.replace(`${pathname}${query ? `?${query}` : ''}`, { scroll: false }); + }, [pathname, router, searchParams, toast]); + + return null; +} diff --git a/packages/web/src/app/(app)/chat/layout.tsx b/packages/web/src/app/(app)/chat/layout.tsx index 6f2094209..b4bdcdda5 100644 --- a/packages/web/src/app/(app)/chat/layout.tsx +++ b/packages/web/src/app/(app)/chat/layout.tsx @@ -1,6 +1,8 @@ import { AGENTIC_SEARCH_TUTORIAL_DISMISSED_COOKIE_NAME } from '@/lib/constants'; import { NavigationGuardProvider } from 'next-navigation-guard'; import { cookies } from 'next/headers'; +import { Suspense } from 'react'; +import { McpOAuthStatusToast } from './components/mcpOAuthStatusToast'; import { TutorialDialog } from './components/tutorialDialog'; interface LayoutProps { @@ -14,8 +16,11 @@ export default async function Layout({ children }: LayoutProps) { // @note: we use a navigation guard here since we don't support resuming streams yet. // @see: https://ai-sdk.dev/docs/ai-sdk-ui/chatbot-message-persistence#resuming-ongoing-streams + + + {children} ) -} \ No newline at end of file +} diff --git a/packages/web/src/app/(app)/settings/accountAskAgent/accountAskAgentPage.test.tsx b/packages/web/src/app/(app)/settings/accountAskAgent/accountAskAgentPage.test.tsx new file mode 100644 index 000000000..2da1a2fa9 --- /dev/null +++ b/packages/web/src/app/(app)/settings/accountAskAgent/accountAskAgentPage.test.tsx @@ -0,0 +1,50 @@ +import { afterEach, describe, expect, test, vi } from 'vitest'; +import { cleanup, render, screen } from '@testing-library/react'; + +vi.mock('@/app/api/(client)/client', () => ({ + getMcpServersWithStatus: vi.fn(), + getMcpServerTools: vi.fn(), +})); +vi.mock('@/ee/features/chat/mcp/actions', () => ({ + disconnectMcpServer: vi.fn(), +})); + +const { AccountAskAgentEmptyState, AccountAskAgentOAuthUnavailableState } = await import('./accountAskAgentPage'); + +afterEach(() => { + cleanup(); +}); + +describe('AccountAskAgentEmptyState', () => { + test('points owners to workspace Ask Agent settings', () => { + render(); + + expect(screen.getByText('No connectors configured yet')).toBeTruthy(); + expect(screen.getByText('Open Workspace Ask Agent to approve connectors for your workspace.')).toBeTruthy(); + expect(screen.getByRole('link', { name: /Open Workspace Ask Agent/ }).getAttribute('href')).toBe('/settings/workspaceAskAgent'); + }); + + test('tells members to contact an admin', () => { + render(); + + expect(screen.getByText('No connectors available')).toBeTruthy(); + expect(screen.getByText(/Contact your workspace admin/)).toBeTruthy(); + expect(screen.queryByRole('link', { name: /Open Workspace Ask Agent/ })).toBeNull(); + }); +}); + +describe('AccountAskAgentOAuthUnavailableState', () => { + test('points owners to workspace cleanup settings', () => { + render(); + + expect(screen.getByText('Connector OAuth is unavailable')).toBeTruthy(); + expect(screen.getByRole('link', { name: /Open Workspace Ask Agent/ }).getAttribute('href')).toBe('/settings/workspaceAskAgent'); + }); + + test('hides workspace cleanup link from members', () => { + render(); + + expect(screen.getByText('Connector setup is unavailable on this Sourcebot instance.')).toBeTruthy(); + expect(screen.queryByRole('link', { name: /Open Workspace Ask Agent/ })).toBeNull(); + }); +}); diff --git a/packages/web/src/app/(app)/settings/accountAskAgent/accountAskAgentPage.tsx b/packages/web/src/app/(app)/settings/accountAskAgent/accountAskAgentPage.tsx new file mode 100644 index 000000000..4f69fcf1d --- /dev/null +++ b/packages/web/src/app/(app)/settings/accountAskAgent/accountAskAgentPage.tsx @@ -0,0 +1,549 @@ +'use client'; + +import { useEffect, useMemo, useRef, useState } from "react"; +import Link from "next/link"; +import { useQuery, useQueryClient } from "@tanstack/react-query"; +import { CableIcon, ExternalLink, MoreHorizontal, SearchIcon, Settings2Icon, Unplug } from "lucide-react"; +import { getMcpServersWithStatus } from "@/app/api/(client)/client"; +import { useToast } from "@/components/hooks/use-toast"; +import { + AlertDialog, AlertDialogAction, AlertDialogCancel, AlertDialogContent, + AlertDialogDescription, AlertDialogFooter, AlertDialogHeader, AlertDialogTitle, +} from "@/components/ui/alert-dialog"; +import { Button } from "@/components/ui/button"; +import { Card, CardContent } from "@/components/ui/card"; +import { + DropdownMenu, DropdownMenuContent, DropdownMenuItem, DropdownMenuTrigger, +} from "@/components/ui/dropdown-menu"; +import { Input } from "@/components/ui/input"; +import { Separator } from "@/components/ui/separator"; +import { Skeleton } from "@/components/ui/skeleton"; +import { ConnectMcpButton } from "@/ee/features/chat/mcp/components/connectMcpButton"; +import { ConnectorCard } from "@/ee/features/chat/mcp/components/connectorCard"; +import { ConnectorRowInfo } from "@/ee/features/chat/mcp/components/connectorRowInfo"; +import { ConnectorToolTrigger } from "@/ee/features/chat/mcp/components/connectorToolDisclosure"; +import { useConnectMcp } from "@/ee/features/chat/mcp/hooks/useConnectMcp"; +import { useMcpToolMetadata } from "@/ee/features/chat/mcp/hooks/useMcpToolMetadata"; +import { disconnectMcpServer } from "@/ee/features/chat/mcp/actions"; +import { invalidateMcpConfigurationQueries, mcpQueryKeys } from "@/ee/features/chat/mcp/queryKeys"; +import { pluralize } from "@/ee/features/chat/mcp/utils"; +import { cn, isServiceError } from "@/lib/utils"; +import type { McpServerWithStatus } from "@/app/api/(server)/ee/askmcp/servers/route"; +import type { ServerToolsEntry } from "@/ee/features/chat/mcp/types"; + +type FilterTab = "all" | "connected"; + +function clearCallbackParams() { + const url = new URL(window.location.href); + url.searchParams.delete('status'); + url.searchParams.delete('server'); + url.searchParams.delete('message'); + window.history.replaceState({}, '', url.toString()); +} + +interface AccountAskAgentPageProps { + callbackStatus?: string; + callbackServer?: string; + callbackMessage?: string; + canManageConnectors: boolean; + isOAuthAvailable: boolean; +} + +export function AccountAskAgentEmptyState({ canManageConnectors }: { canManageConnectors: boolean }) { + return ( + + +
+ +
+

+ {canManageConnectors ? "No connectors configured yet" : "No connectors available"} +

+

+ {canManageConnectors + ? "Open Workspace Ask Agent to approve connectors for your workspace." + : "No connectors have been approved for this workspace yet. Contact your workspace admin."} +

+ {canManageConnectors && ( + + )} +
+
+ ); +} + +export function AccountAskAgentOAuthUnavailableState({ canManageConnectors }: { canManageConnectors: boolean }) { + return ( + + +
+ +
+

Connector OAuth is unavailable

+

+ {canManageConnectors + ? "Open Workspace Ask Agent to remove existing connector approvals and stored credentials." + : "Connector setup is unavailable on this Sourcebot instance."} +

+ {canManageConnectors && ( + + )} +
+
+ ); +} + +interface AccountConnectedConnectorCardProps { + server: McpServerWithStatus; + toolEntry?: ServerToolsEntry; + isToolsLoading: boolean; + isToolsError: boolean; + onRetryTools: () => void; + onReconnect: (serverId: string) => void; + onDisconnect: (server: McpServerWithStatus) => void; + disconnectingServerId: string | null; +} + +function AccountConnectedConnectorCard({ + server, + toolEntry, + isToolsLoading, + isToolsError, + onRetryTools, + onReconnect, + onDisconnect, + disconnectingServerId, +}: AccountConnectedConnectorCardProps) { + return ( + + {server.isConnected && ( + + + Connected + + )} + {server.isAuthExpired && ( + + + Authorization expired + + )} + + } + actionButtons={ + + + + + + onReconnect(server.id)}> + + Reconnect + + onDisconnect(server)} + > + + {disconnectingServerId === server.id ? "Disconnecting..." : "Disconnect"} + + + + } + /> + ); +} + +function AccountSuggestedConnectorCard({ server }: { server: McpServerWithStatus }) { + return ( + + + +
+ +
+
+ +
+
+ ); +} + +export function AccountAskAgentPage({ + callbackStatus, + callbackServer, + callbackMessage, + canManageConnectors, + isOAuthAvailable, +}: AccountAskAgentPageProps) { + const { toast } = useToast(); + const queryClient = useQueryClient(); + const didHandleCallbackRef = useRef(false); + const [searchQuery, setSearchQuery] = useState(""); + const [activeTab, setActiveTab] = useState("all"); + const [disconnectingServerId, setDisconnectingServerId] = useState(null); + const [confirmDisconnectServer, setConfirmDisconnectServer] = useState<{ id: string; name: string } | null>(null); + const { connect: reconnectMcp } = useConnectMcp(); + + useEffect(() => { + if (didHandleCallbackRef.current) { + return; + } + if (callbackStatus === 'connected') { + didHandleCallbackRef.current = true; + toast({ description: `Successfully connected${callbackServer ? ` to ${callbackServer}` : ''}.` }); + clearCallbackParams(); + } else if (callbackStatus === 'error') { + didHandleCallbackRef.current = true; + toast({ title: "Connection failed", description: callbackMessage ?? 'Failed to connect connector.', variant: "destructive" }); + clearCallbackParams(); + } + }, [callbackStatus, callbackServer, callbackMessage, toast]); + + const { data: servers = [], isLoading, isError } = useQuery({ + queryKey: mcpQueryKeys.serversWithStatus, + queryFn: async () => { + const result = await getMcpServersWithStatus(); + if (isServiceError(result)) { + throw new Error("Failed to load connectors"); + } + return result; + }, + enabled: isOAuthAvailable, + }); + + const connectedServers = useMemo( + () => servers.filter((s) => s.isConnected || s.isAuthExpired), + [servers], + ); + + const suggestedServers = useMemo( + () => servers.filter((s) => !s.isConnected && !s.isAuthExpired), + [servers], + ); + + const filteredConnected = useMemo(() => { + const list = connectedServers; + if (!searchQuery.trim()) { + return list; + } + const q = searchQuery.toLowerCase(); + return list.filter( + (s) => (s.name?.toLowerCase().includes(q)) || s.serverUrl.toLowerCase().includes(q), + ); + }, [connectedServers, searchQuery]); + + const filteredSuggested = useMemo(() => { + const list = suggestedServers; + if (!searchQuery.trim()) { + return list; + } + const q = searchQuery.toLowerCase(); + return list.filter( + (s) => (s.name?.toLowerCase().includes(q)) || s.serverUrl.toLowerCase().includes(q), + ); + }, [suggestedServers, searchQuery]); + + const visibleConnected = filteredConnected; + const visibleSuggested = activeTab === "all" ? filteredSuggested : []; + const activeConnectedServerCount = useMemo( + () => servers.filter((s) => s.isConnected).length, + [servers], + ); + const { + isToolsLoading, + isToolsError, + refetchTools, + toolsByServerId, + } = useMcpToolMetadata(isOAuthAvailable, activeConnectedServerCount); + + const handleDisconnect = async (serverId: string) => { + setDisconnectingServerId(serverId); + setConfirmDisconnectServer(null); + try { + const result = await disconnectMcpServer(serverId, 'account_settings'); + if (isServiceError(result)) { + toast({ title: "Error", description: `Failed to disconnect: ${result.message}`, variant: "destructive" }); + return; + } + toast({ description: "Connector disconnected." }); + await invalidateMcpConfigurationQueries(queryClient); + } catch { + toast({ title: "Error", description: "Failed to disconnect connector.", variant: "destructive" }); + } finally { + setDisconnectingServerId(null); + } + }; + + if (!isOAuthAvailable) { + return ( +
+
+

Ask Agent

+

+ Manage your personal Ask Agent setup. +

+
+ +
+
+

Connectors

+

+ Manage workspace-approved connectors for use with Ask Agent. +

+
+ +
+
+ ); + } + + if (isError) { + return
Error loading connectors
; + } + + if (!isLoading && servers.length === 0) { + return ( +
+
+

Ask Agent

+

+ Manage your personal Ask Agent setup. +

+
+ +
+
+

Connectors

+

+ Manage workspace-approved connectors for use with Ask Agent. +

+
+ +
+
+ ); + } + + return ( +
+
+

Ask Agent

+

+ Manage your personal Ask Agent setup. +

+
+ + + +
+
+

Connectors

+

+ Manage workspace-approved connectors for use with Ask Agent. +

+
+ + {/* Search + filter bar */} +
+
+ + setSearchQuery(e.target.value)} + className="pl-9" + /> +
+
+ + +
+
+
+ + {isLoading ? ( +
+ {Array.from({ length: 3 }).map((_, index) => ( + + + +
+ + +
+ +
+
+ ))} +
+ ) : ( + <> + {/* Connected section */} +
+
+

+ Connected +

+

+ {connectedServers.length} {pluralize(connectedServers.length, "connector")} +

+
+ + {visibleConnected.length === 0 ? ( + + +

+ {searchQuery.trim() + ? "No connected connectors match your search." + : "No connectors connected yet."} +

+
+
+ ) : ( + visibleConnected.map((server) => ( + { void refetchTools(); }} + onReconnect={reconnectMcp} + onDisconnect={(serverToDisconnect) => setConfirmDisconnectServer({ + id: serverToDisconnect.id, + name: serverToDisconnect.name || serverToDisconnect.serverUrl, + })} + disconnectingServerId={disconnectingServerId} + /> + )) + )} +
+ + {/* Suggested section */} + {activeTab === "all" && ( +
+
+

+ Suggested +

+

+ workspace-approved +

+
+ + {visibleSuggested.length === 0 ? ( + + +

+ {searchQuery.trim() + ? "No suggested connectors match your search." + : "All connectors are connected."} +

+
+
+ ) : ( + visibleSuggested.map((server) => ( + + )) + )} +
+ )} + + )} + + {/* Disconnect confirmation dialog */} + { + if (!open) { + setConfirmDisconnectServer(null); + } + }} + > + + + Disconnect Connector + + Are you sure you want to disconnect from {confirmDisconnectServer?.name}? Your stored credentials for this connector will be removed. + + + + Cancel + { + if (confirmDisconnectServer) { + handleDisconnect(confirmDisconnectServer.id); + } + }} + className="bg-destructive text-destructive-foreground hover:bg-destructive/90" + > + Disconnect + + + + +
+ ); +} diff --git a/packages/web/src/app/(app)/settings/accountAskAgent/page.tsx b/packages/web/src/app/(app)/settings/accountAskAgent/page.tsx new file mode 100644 index 000000000..078e67288 --- /dev/null +++ b/packages/web/src/app/(app)/settings/accountAskAgent/page.tsx @@ -0,0 +1,27 @@ +import { AccountAskAgentPage } from "./accountAskAgentPage"; +import { hasEntitlement } from "@/lib/entitlements"; +import { authenticatedPage } from "@/middleware/authenticatedPage"; +import { OrgRole } from "@sourcebot/db"; + +interface PageProps extends Record { + searchParams: Promise<{ + status?: string; + server?: string; + message?: string; + }>; +} + +export default authenticatedPage(async ({ role }, { searchParams }) => { + const { status, server, message } = await searchParams; + const isOAuthAvailable = await hasEntitlement('oauth'); + + return ( + + ); +}); diff --git a/packages/web/src/app/(app)/settings/layout.tsx b/packages/web/src/app/(app)/settings/layout.tsx index a98d22942..604601027 100644 --- a/packages/web/src/app/(app)/settings/layout.tsx +++ b/packages/web/src/app/(app)/settings/layout.tsx @@ -44,7 +44,7 @@ export default async function SettingsLayout( } export const getSidebarNavGroups = async () => - withAuth(async ({ role }) => { + withAuth(async ({ org, role, prisma }) => { let numJoinRequests: number | undefined; if (role === OrgRole.OWNER) { const requests = await getOrgAccountRequests(); @@ -58,6 +58,12 @@ export const getSidebarNavGroups = async () => if (isServiceError(connectionStats)) { throw new ServiceErrorException(connectionStats); } + const hasOAuthEntitlement = await hasEntitlement("oauth"); + const hasApprovedConnectors = role === OrgRole.OWNER && !hasOAuthEntitlement + ? await prisma.mcpServer.count({ + where: { orgId: org.id }, + }) > 0 + : false; const groups: NavGroup[] = [ { @@ -86,7 +92,14 @@ export const getSidebarNavGroups = async () => title: "MCP Server", href: `/settings/mcp`, icon: 'mcp' as const, - } + }, + ...(hasOAuthEntitlement ? [ + { + title: "Ask Agent", + href: `/settings/accountAskAgent`, + icon: "bot" as const, + } + ] : []), ], }, ]; @@ -119,6 +132,13 @@ export const getSidebarNavGroups = async () => icon: "chart-area" as const, requiredEntitlement: 'analytics' }, + ...(hasOAuthEntitlement || hasApprovedConnectors ? [ + { + title: "Ask Agent", + href: `/settings/workspaceAskAgent`, + icon: "bot" as const, + } + ] : []), { title: "License", href: `/settings/license`, @@ -129,4 +149,4 @@ export const getSidebarNavGroups = async () => } return groups.filter(g => g.items.length > 0); - }); \ No newline at end of file + }); diff --git a/packages/web/src/app/(app)/settings/workspaceAskAgent/page.test.tsx b/packages/web/src/app/(app)/settings/workspaceAskAgent/page.test.tsx new file mode 100644 index 000000000..8d5be7a56 --- /dev/null +++ b/packages/web/src/app/(app)/settings/workspaceAskAgent/page.test.tsx @@ -0,0 +1,67 @@ +import { afterEach, beforeEach, describe, expect, test, vi } from 'vitest'; +import { cleanup, render, screen } from '@testing-library/react'; +import type React from 'react'; + +const mocks = vi.hoisted(() => ({ + authContext: { + org: { id: 1 }, + prisma: { + mcpServer: { + count: vi.fn(), + }, + }, + }, + hasEntitlement: vi.fn(), +})); + +vi.mock('@/lib/entitlements', () => ({ + hasEntitlement: mocks.hasEntitlement, +})); +vi.mock('@/middleware/authenticatedPage', () => ({ + authenticatedPage: vi.fn((page: (auth: typeof mocks.authContext, props: { searchParams: Promise> }) => Promise) => + (props: { searchParams: Promise> }) => page(mocks.authContext, props)), +})); +vi.mock('./workspaceAskAgentPage', () => ({ + WorkspaceAskAgentPage: () =>
Workspace Ask Agent client
, +})); + +const { default: Page } = await import('./page'); + +beforeEach(() => { + vi.clearAllMocks(); + mocks.hasEntitlement.mockResolvedValue(true); + mocks.authContext.prisma.mcpServer.count.mockResolvedValue(0); +}); + +afterEach(() => { + cleanup(); +}); + +describe('Ask Agent settings page', () => { + test('renders the client configuration page when OAuth is available', async () => { + render(await Page({ searchParams: Promise.resolve({}) })); + + expect(screen.getByText('Workspace Ask Agent client')).toBeTruthy(); + }); + + test('renders the client configuration page when OAuth is unavailable but servers exist for cleanup', async () => { + mocks.hasEntitlement.mockResolvedValue(false); + mocks.authContext.prisma.mcpServer.count.mockResolvedValue(1); + + render(await Page({ searchParams: Promise.resolve({}) })); + + expect(screen.getByText('Workspace Ask Agent client')).toBeTruthy(); + expect(mocks.authContext.prisma.mcpServer.count).toHaveBeenCalledWith({ + where: { orgId: 1 }, + }); + }); + + test('renders an unavailable message when OAuth is not available and no cleanup is needed', async () => { + mocks.hasEntitlement.mockResolvedValue(false); + + render(await Page({ searchParams: Promise.resolve({}) })); + + expect(screen.getByText('Ask Agent Connectors Are Unavailable')).toBeTruthy(); + expect(screen.queryByText('Workspace Ask Agent client')).toBeNull(); + }); +}); diff --git a/packages/web/src/app/(app)/settings/workspaceAskAgent/page.tsx b/packages/web/src/app/(app)/settings/workspaceAskAgent/page.tsx new file mode 100644 index 000000000..1b4eeef14 --- /dev/null +++ b/packages/web/src/app/(app)/settings/workspaceAskAgent/page.tsx @@ -0,0 +1,35 @@ +import { hasEntitlement } from "@/lib/entitlements"; +import { authenticatedPage } from "@/middleware/authenticatedPage"; +import { OrgRole } from "@sourcebot/db"; +import { WorkspaceAskAgentPage } from "./workspaceAskAgentPage"; +import { WorkspaceAskAgentUnavailableMessage } from "./workspaceAskAgentUnavailableMessage"; + +interface PageProps extends Record { + searchParams: Promise<{ + status?: string; + server?: string; + message?: string; + }>; +} + +export default authenticatedPage(async ({ org, prisma }, { searchParams }) => { + if (!(await hasEntitlement("oauth"))) { + const serverCount = await prisma.mcpServer.count({ + where: { orgId: org.id }, + }); + + if (serverCount === 0) { + return ; + } + } + + const { status, server, message } = await searchParams; + + return ( + + ); +}, { minRole: OrgRole.OWNER, redirectTo: '/settings' }); diff --git a/packages/web/src/app/(app)/settings/workspaceAskAgent/prefabConnectorPopover.tsx b/packages/web/src/app/(app)/settings/workspaceAskAgent/prefabConnectorPopover.tsx new file mode 100644 index 000000000..479b8c09e --- /dev/null +++ b/packages/web/src/app/(app)/settings/workspaceAskAgent/prefabConnectorPopover.tsx @@ -0,0 +1,129 @@ +'use client'; + +import { useMemo, useState } from "react"; +import { + Command, + CommandGroup, + CommandInput, + CommandItem, + CommandList, + CommandSeparator, +} from "@/components/ui/command"; +import { Button } from "@/components/ui/button"; +import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; +import { getDisplayServerUrl } from "@/ee/features/chat/mcp/components/connectorRowInfo"; +import { McpFavicon } from "@/ee/features/chat/mcp/components/mcpFavicon"; +import { + getAvailablePrefabMcpServers, + type PrefabMcpServer, +} from "@/ee/features/chat/mcp/prefabMcpServers"; +import { getMcpFaviconUrl } from "@/ee/features/chat/mcp/utils"; +import { PlusIcon } from "lucide-react"; + +interface PrefabConnectorPopoverProps { + configuredServerUrls: string[]; + disabled?: boolean; + onSelectCustomUrl: () => void; + onSelectPrefabServer: (server: PrefabMcpServer) => void; + children?: React.ReactNode; +} + +export function PrefabConnectorPopover({ + configuredServerUrls, + disabled, + onSelectCustomUrl, + onSelectPrefabServer, + children, +}: PrefabConnectorPopoverProps) { + const [isOpen, setIsOpen] = useState(false); + const [search, setSearch] = useState(""); + + const availablePrefabServers = useMemo(() => ( + getAvailablePrefabMcpServers(configuredServerUrls) + ), [configuredServerUrls]); + + const filteredPrefabServers = useMemo(() => { + const normalizedSearch = search.trim().toLowerCase(); + + if (!normalizedSearch) { + return availablePrefabServers; + } + + return availablePrefabServers.filter((server) => server.name.toLowerCase().includes(normalizedSearch)); + }, [availablePrefabServers, search]); + + const handleOpenChange = (open: boolean) => { + setIsOpen(open); + + if (!open) { + setSearch(""); + } + }; + + const handleSelectPrefabServer = (server: PrefabMcpServer) => { + handleOpenChange(false); + onSelectPrefabServer(server); + }; + + const handleSelectCustomUrl = () => { + handleOpenChange(false); + onSelectCustomUrl(); + }; + + return ( + + + {children ?? ( + + )} + + + + + + + {filteredPrefabServers.map((server) => ( + handleSelectPrefabServer(server)} + className="cursor-pointer" + > +
+ +
+
+

{server.name}

+

{getDisplayServerUrl(server.serverUrl)}

+
+
+ ))} + {search.trim() && filteredPrefabServers.length === 0 && ( +
+ No connectors found. +
+ )} +
+ + + + + Custom URL... + + +
+
+
+
+ ); +} diff --git a/packages/web/src/app/(app)/settings/workspaceAskAgent/workspaceAskAgentPage.tsx b/packages/web/src/app/(app)/settings/workspaceAskAgent/workspaceAskAgentPage.tsx new file mode 100644 index 000000000..18fbc4411 --- /dev/null +++ b/packages/web/src/app/(app)/settings/workspaceAskAgent/workspaceAskAgentPage.tsx @@ -0,0 +1,619 @@ +'use client'; + +import { useEffect, useMemo, useRef, useState } from "react"; +import { getMcpConfiguration, getMcpServersWithStatus } from "@/app/api/(client)/client"; +import { useToast } from "@/components/hooks/use-toast"; +import { + AlertDialog, AlertDialogAction, AlertDialogCancel, AlertDialogContent, + AlertDialogDescription, AlertDialogFooter, AlertDialogHeader, AlertDialogTitle, +} from "@/components/ui/alert-dialog"; +import { Button } from "@/components/ui/button"; +import { Card, CardContent } from "@/components/ui/card"; +import { + Dialog, DialogContent, DialogDescription, DialogFooter, DialogHeader, DialogTitle, +} from "@/components/ui/dialog"; +import { + DropdownMenu, DropdownMenuContent, DropdownMenuItem, DropdownMenuTrigger, +} from "@/components/ui/dropdown-menu"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { Separator } from "@/components/ui/separator"; +import { Skeleton } from "@/components/ui/skeleton"; +import { checkMcpServerDynamicClientRegistration, createMcpServer, createStaticOAuthMcpServer, deleteMcpServer } from "@/ee/features/chat/mcp/actions"; +import { ConnectMcpButton } from "@/ee/features/chat/mcp/components/connectMcpButton"; +import { ConnectorCard } from "@/ee/features/chat/mcp/components/connectorCard"; +import { useMcpToolMetadata } from "@/ee/features/chat/mcp/hooks/useMcpToolMetadata"; +import { invalidateMcpConfigurationQueries, mcpQueryKeys } from "@/ee/features/chat/mcp/queryKeys"; +import { pluralize } from "@/ee/features/chat/mcp/utils"; +import { cn, isServiceError } from "@/lib/utils"; +import { useQuery, useQueryClient } from "@tanstack/react-query"; +import { AlertTriangleIcon, CableIcon, CopyIcon, Loader2, MoreHorizontalIcon, PlusIcon, Trash2Icon } from "lucide-react"; +import { PrefabConnectorPopover } from "./prefabConnectorPopover"; +import type { PrefabMcpServer } from "@/ee/features/chat/mcp/prefabMcpServers"; +import type { McpConfigurationServer, ServerToolsEntry } from "@/ee/features/chat/mcp/types"; + +function clearCallbackParams() { + const url = new URL(window.location.href); + url.searchParams.delete('status'); + url.searchParams.delete('server'); + url.searchParams.delete('message'); + window.history.replaceState({}, '', url.toString()); +} + +interface WorkspaceAskAgentPageProps { + callbackStatus?: string; + callbackServer?: string; + callbackMessage?: string; +} + +type WorkspaceConnectorStatus = { + isConnected: boolean; + isAuthExpired: boolean; +}; + +interface WorkspaceConnectorCardProps { + server: McpConfigurationServer; + status?: WorkspaceConnectorStatus; + isOAuthAvailable: boolean; + isStatusLoading: boolean; + isStatusError: boolean; + toolEntry?: ServerToolsEntry; + isToolsLoading: boolean; + isToolsError: boolean; + onRetryTools: () => void; + onCopyUrl: (serverUrl: string) => void; + onDelete: (server: McpConfigurationServer) => void; +} + +function WorkspaceConnectorCard({ + server, + status, + isOAuthAvailable, + isStatusLoading, + isStatusError, + toolEntry, + isToolsLoading, + isToolsError, + onRetryTools, + onCopyUrl, + onDelete, +}: WorkspaceConnectorCardProps) { + const isConnected = status?.isConnected === true; + const isAuthExpired = status?.isAuthExpired === true; + const isStatusUnavailable = isOAuthAvailable !== true || isStatusLoading || isStatusError || !status; + const showConnectButton = isOAuthAvailable && !isStatusLoading && !isStatusError && !!status && !isConnected; + const serverLabel = server.name || server.serverUrl; + + return ( + 0 ? "text-green-600 dark:text-green-400" : "text-muted-foreground", + )}> + 0 ? "bg-green-500/80" : "bg-muted-foreground", + )} /> + {server.savedConnectionCount > 0 + ? `${server.savedConnectionCount} ${pluralize(server.savedConnectionCount, "member")} connected` + : "No members connected"} + + } + actionButtons={ + <> + {showConnectButton && ( + + )} + + + + + + onCopyUrl(server.serverUrl)}> + + Copy URL + + onDelete(server)} + aria-label={`Remove ${serverLabel}`} + > + + Remove + + + + + } + /> + ); +} + +export function WorkspaceAskAgentPage({ callbackStatus, callbackServer, callbackMessage }: WorkspaceAskAgentPageProps) { + const { toast } = useToast(); + const queryClient = useQueryClient(); + const didHandleCallbackRef = useRef(false); + + useEffect(() => { + if (didHandleCallbackRef.current) { + return; + } + if (callbackStatus === 'connected') { + didHandleCallbackRef.current = true; + toast({ description: `Successfully connected${callbackServer ? ` to ${callbackServer}` : ''}.` }); + clearCallbackParams(); + } else if (callbackStatus === 'error') { + didHandleCallbackRef.current = true; + toast({ title: "Connection failed", description: callbackMessage ?? 'Failed to connect connector.', variant: "destructive" }); + clearCallbackParams(); + } + }, [callbackStatus, callbackServer, callbackMessage, toast]); + const [isCreateDialogOpen, setIsCreateDialogOpen] = useState(false); + const [newServerName, setNewServerName] = useState(""); + const [newServerUrl, setNewServerUrl] = useState(""); + const [isClientCredentialsDialogOpen, setIsClientCredentialsDialogOpen] = useState(false); + const [pendingClientCredentialsServer, setPendingClientCredentialsServer] = useState<{ name: string; serverUrl: string } | null>(null); + const [clientId, setClientId] = useState(""); + const [clientSecret, setClientSecret] = useState(""); + const [isCreating, setIsCreating] = useState(false); + const [deletingServerId, setDeletingServerId] = useState(null); + const [serverToDelete, setServerToDelete] = useState(null); + + const { data, isLoading, isError } = useQuery({ + queryKey: mcpQueryKeys.configuration, + queryFn: async () => { + const result = await getMcpConfiguration(); + if (isServiceError(result)) { + throw new Error(result.message); + } + return result; + }, + }); + + const { data: serversWithStatus, isLoading: isServersWithStatusLoading, isError: isServersWithStatusError } = useQuery({ + queryKey: mcpQueryKeys.serversWithStatus, + queryFn: async () => { + const result = await getMcpServersWithStatus(); + if (isServiceError(result)) { + throw new Error("Failed to load connector status"); + } + if (!Array.isArray(result)) { + throw new Error("Unexpected response from connector status endpoint"); + } + return result; + }, + enabled: data?.isOAuthAvailable !== false, + }); + + const myStatusByServerId = useMemo(() => { + const map = new Map(); + for (const s of serversWithStatus ?? []) { + map.set(s.id, { isConnected: s.isConnected, isAuthExpired: s.isAuthExpired }); + } + return map; + }, [serversWithStatus]); + + const servers = data?.servers ?? []; + const canCreateConnectors = data?.isOAuthAvailable === true; + const isOAuthUnavailable = data?.isOAuthAvailable === false; + const connectedServerCount = useMemo( + () => serversWithStatus?.filter((server) => server.isConnected).length ?? 0, + [serversWithStatus], + ); + const { + isToolsLoading, + isToolsError, + refetchTools, + toolsByServerId, + } = useMcpToolMetadata(data?.isOAuthAvailable === true, connectedServerCount); + + const handleCreateDialogOpenChange = (open: boolean) => { + setIsCreateDialogOpen(open); + + if (!open) { + setNewServerName(""); + setNewServerUrl(""); + } + }; + + const handleCloseCreateDialog = () => { + handleCreateDialogOpenChange(false); + }; + + const handleCloseClientCredentialsDialog = () => { + setIsClientCredentialsDialogOpen(false); + setPendingClientCredentialsServer(null); + setClientId(""); + setClientSecret(""); + }; + + const handleOpenCustomUrlDialog = () => { + setNewServerName(""); + setNewServerUrl(""); + setIsCreateDialogOpen(true); + }; + + const handleCreateStaticOAuthServer = async () => { + if (!pendingClientCredentialsServer) { + toast({ title: "Error", description: "Missing connector details", variant: "destructive" }); + return; + } + + if (process.env.NODE_ENV === "production" && window.location.protocol !== "https:") { + toast({ + title: "HTTPS required", + description: "Static OAuth client credentials can only be submitted over HTTPS in production.", + variant: "destructive", + }); + return; + } + + setIsCreating(true); + try { + const result = await createStaticOAuthMcpServer({ + name: pendingClientCredentialsServer.name, + serverUrl: pendingClientCredentialsServer.serverUrl, + clientId, + clientSecret, + }); + if (isServiceError(result)) { + toast({ title: "Error", description: `Failed to add connector: ${result.message}`, variant: "destructive" }); + return; + } + + await invalidateMcpConfigurationQueries(queryClient); + handleCloseClientCredentialsDialog(); + } catch { + toast({ title: "Error", description: "Failed to add connector.", variant: "destructive" }); + } finally { + setIsCreating(false); + } + }; + + const handleCreateServer = async ( + name: string, + serverUrl: string, + onSuccess?: () => void, + options: { checkDynamicClientRegistration?: boolean } = {}, + ) => { + const displayName = name.trim(); + const normalizedServerUrl = serverUrl.trim(); + + if (!displayName || !normalizedServerUrl) { + toast({ title: "Error", description: "Name and connector URL are required", variant: "destructive" }); + return; + } + + setIsCreating(true); + try { + if (options.checkDynamicClientRegistration) { + const dcrSupport = await checkMcpServerDynamicClientRegistration(normalizedServerUrl); + if (isServiceError(dcrSupport)) { + toast({ title: "Error", description: `Failed to check connector: ${dcrSupport.message}`, variant: "destructive" }); + return; + } + + if (dcrSupport.isKnown && !dcrSupport.supportsDcr) { + setPendingClientCredentialsServer({ name: displayName, serverUrl: normalizedServerUrl }); + setIsCreateDialogOpen(false); + setIsClientCredentialsDialogOpen(true); + return; + } + } + + const result = await createMcpServer(displayName, normalizedServerUrl); + if (isServiceError(result)) { + toast({ title: "Error", description: `Failed to add connector: ${result.message}`, variant: "destructive" }); + return; + } + + await invalidateMcpConfigurationQueries(queryClient); + onSuccess?.(); + } catch (error) { + toast({ title: "Error", description: `Failed to add connector: ${error}`, variant: "destructive" }); + } finally { + setIsCreating(false); + } + }; + + const handleCreate = async () => { + await handleCreateServer(newServerName, newServerUrl, handleCloseCreateDialog, { + checkDynamicClientRegistration: true, + }); + }; + + const handleCreatePrefabServer = async (server: PrefabMcpServer) => { + await handleCreateServer(server.name, server.serverUrl, undefined, { + checkDynamicClientRegistration: true, + }); + }; + + const handleDelete = async (serverId: string) => { + setDeletingServerId(serverId); + try { + const result = await deleteMcpServer(serverId); + if (isServiceError(result)) { + toast({ title: "Error", description: `Failed to remove connector: ${result.message}`, variant: "destructive" }); + return; + } + + await invalidateMcpConfigurationQueries(queryClient); + setServerToDelete(null); + } catch (error) { + toast({ title: "Error", description: `Failed to remove connector: ${error}`, variant: "destructive" }); + } finally { + setDeletingServerId(null); + } + }; + + const handleCopyUrl = (serverUrl: string) => { + navigator.clipboard.writeText(serverUrl); + toast({ title: "Copied", description: "Connector URL copied to clipboard." }); + }; + + if (isError) { + return
Error loading Ask Agent settings
; + } + + const prefabPopoverProps = { + configuredServerUrls: servers.map((s) => s.serverUrl), + disabled: isCreating, + onSelectCustomUrl: handleOpenCustomUrlDialog, + onSelectPrefabServer: handleCreatePrefabServer, + }; + + return ( +
+ {/* Page header */} +
+

Ask Agent

+

+ Configure what external tools Ask Agent can use across this workspace. +

+
+ + + + {/* OAuth unavailable warning */} + {!isLoading && isOAuthUnavailable && ( +
+ +
+

Connector OAuth is unavailable

+

+ You can remove existing approved connectors and stored credentials, but cannot add new connectors. +

+
+
+ )} + + {/* Connectors section */} +
+
+

Connectors

+

+ Connectors are MCP servers that let Ask Agent use approved external tools alongside your indexed code. +

+
+ + {/* Allowed connectors subsection */} +
+
+
+

Allowed connectors

+

+ {isOAuthUnavailable + ? "Remove existing connector approvals and their stored credentials." + : "Approve connector URLs that workspace members can connect to."} +

+
+ {canCreateConnectors && ( + + + + )} +
+ + {/* Connector list */} +
+ {isLoading ? ( + Array.from({ length: 3 }).map((_, i) => ( + + + +
+ + +
+ +
+
+ )) + ) : servers.length === 0 ? ( + + +
+ +
+

No connectors configured yet

+

+ {isOAuthUnavailable + ? "Connector OAuth is unavailable on this Sourcebot instance." + : "Add a workspace-approved connector so members can use it with Ask Agent."} +

+
+
+ ) : ( + servers.map((server) => ( + { void refetchTools(); }} + onCopyUrl={handleCopyUrl} + onDelete={setServerToDelete} + /> + )) + )} +
+
+
+ + {/* Delete confirmation */} + { if (!open) { setServerToDelete(null); } }}> + + + Remove Connector + + Are you sure you want to remove {serverToDelete?.name || serverToDelete?.serverUrl}? Workspace members will lose access and stored credentials for this connector. + + + + Cancel + { if (serverToDelete) { handleDelete(serverToDelete.id); } }} + disabled={deletingServerId !== null} + className="bg-destructive text-destructive-foreground hover:bg-destructive/90" + > + {deletingServerId ? "Removing..." : "Remove"} + + + + + + {/* Add connector dialog */} + + + + Add Connector + + Add a workspace-approved connector that members can use with Ask Agent. + + +
+
+ + setNewServerName(event.target.value)} + placeholder="e.g. Linear" + /> +
+
+ + setNewServerUrl(event.target.value)} + placeholder="https://mcp.linear.app/mcp" + /> +
+
+ + + + +
+
+ + {/* OAuth client credentials dialog */} + { + if (!open) { + handleCloseClientCredentialsDialog(); + return; + } + + setIsClientCredentialsDialogOpen(true); + }}> + + + OAuth Client Credentials Required + + This connector does not advertise dynamic client registration. Provide OAuth client credentials from a pre-registered app before members can connect to it. + + +
+ {pendingClientCredentialsServer && ( +
+

{pendingClientCredentialsServer.name}

+

{pendingClientCredentialsServer.serverUrl}

+
+ )} +
+ + setClientId(event.target.value)} + placeholder="OAuth client ID" + /> +
+
+ + setClientSecret(event.target.value)} + placeholder="OAuth client secret" + /> +
+
+ + + + +
+
+
+ ); +} diff --git a/packages/web/src/app/(app)/settings/workspaceAskAgent/workspaceAskAgentUnavailableMessage.tsx b/packages/web/src/app/(app)/settings/workspaceAskAgent/workspaceAskAgentUnavailableMessage.tsx new file mode 100644 index 000000000..21ac97209 --- /dev/null +++ b/packages/web/src/app/(app)/settings/workspaceAskAgent/workspaceAskAgentUnavailableMessage.tsx @@ -0,0 +1,29 @@ +import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; +import { CableIcon } from "lucide-react"; + +export function WorkspaceAskAgentUnavailableMessage() { + return ( +
+ + +
+
+ +
+
+ + Ask Agent Connectors Are Unavailable + + + OAuth-backed connectors are not supported on this Sourcebot instance. + +
+ +

+ Use Sourcebot API keys for agent access on this deployment. +

+
+
+
+ ); +} diff --git a/packages/web/src/app/api/(client)/client.ts b/packages/web/src/app/api/(client)/client.ts index 0fd119050..ecc95818a 100644 --- a/packages/web/src/app/api/(client)/client.ts +++ b/packages/web/src/app/api/(client)/client.ts @@ -31,6 +31,9 @@ import type { SearchChatShareableMembersResponse, } from "../(server)/ee/chat/[chatId]/searchMembers/route"; import { OffersResponse } from "@/ee/features/lighthouse/types"; +import { ConnectMcpResponse } from "../(server)/ee/askmcp/connect/types"; +import type { GetMcpServersResponse } from "../(server)/ee/askmcp/servers/route"; +import type { GetMcpConfigurationResponse, GetMcpToolsResponse } from "@/ee/features/chat/mcp/types"; export const search = async (body: SearchRequest): Promise => { const result = await fetch("/api/search", { @@ -240,4 +243,55 @@ export const getOffers = async (): Promise => { }).then(response => response.json()); return result as OffersResponse | ServiceError; -} \ No newline at end of file +} + +export const connectMcpToAsk = async (body: { serverId: string; returnTo?: string }): Promise => { + const result = await fetch('/api/ee/askmcp/connect', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'X-Sourcebot-Client-Source': 'sourcebot-web-client', + }, + body: JSON.stringify(body), + }).then(response => response.json()); + + if (isServiceError(result)) { + return result; + } + + return result as ConnectMcpResponse; +} + +export const getMcpServersWithStatus = async (): Promise => { + const result = await fetch('/api/ee/askmcp/servers', { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + 'X-Sourcebot-Client-Source': 'sourcebot-web-client', + }, + }).then(response => response.json()); + + return result as GetMcpServersResponse | ServiceError; +} + +export const getMcpConfiguration = async (): Promise => { + const result = await fetch('/api/ee/askmcp/configuration', { + method: 'GET', + headers: { + 'X-Sourcebot-Client-Source': 'sourcebot-web-client', + }, + }).then(response => response.json()); + + return result as GetMcpConfigurationResponse | ServiceError; +} + +export const getMcpServerTools = async (): Promise => { + const result = await fetch('/api/ee/askmcp/tools', { + method: 'GET', + headers: { + 'X-Sourcebot-Client-Source': 'sourcebot-web-client', + }, + }).then(response => response.json()); + + return result as GetMcpToolsResponse | ServiceError; +} diff --git a/packages/web/src/app/api/(server)/chat/route.ts b/packages/web/src/app/api/(server)/chat/route.ts index 4c0b12819..77379457d 100644 --- a/packages/web/src/app/api/(server)/chat/route.ts +++ b/packages/web/src/app/api/(server)/chat/route.ts @@ -1,4 +1,5 @@ import { sew } from "@/middleware/sew"; +import { getAskMcpAvailabilityAnalytics, getAskMcpTurnCompletedAnalytics } from "@/features/chat/askMcpAnalytics.server"; import { createMessageStream } from "@/features/chat/agent"; import { additionalChatRequestParamsSchema } from "@/features/chat/types"; import { getLanguageModelKey } from "@/features/chat/utils"; @@ -33,7 +34,7 @@ export const POST = apiHandler(async (req: NextRequest) => { return serviceErrorResponse(requestBodySchemaValidationError(parsed.error)); } - const { messages, id, selectedSearchScopes, languageModel: _languageModel } = parsed.data; + const { messages, id, selectedSearchScopes, disabledMcpServerIds, languageModel: _languageModel } = parsed.data; // @note: a bit of type massaging is required here since the // zod schema does not enum on `model` or `provider`. // @see: chat/types.ts @@ -92,12 +93,20 @@ export const POST = apiHandler(async (req: NextRequest) => { }))).flat(); const source = req.headers.get('X-Sourcebot-Client-Source') ?? undefined; + const askMcpSource = source === 'sourcebot-web-client' ? source : undefined; + const askMcpAvailability = await getAskMcpAvailabilityAnalytics({ + prisma, + userId: user?.id, + orgId: org.id, + disabledMcpServerIds, + }); await captureEvent('ask_message_sent', { chatId: id, messageCount: messages.length, selectedReposCount: expandedRepos.length, source, + ...askMcpAvailability, ...(env.EXPERIMENT_ASK_GH_ENABLED === 'true' ? { selectedRepos: expandedRepos } : {}), }); @@ -108,12 +117,27 @@ export const POST = apiHandler(async (req: NextRequest) => { selectedSearchScopes, }, selectedRepos: expandedRepos, + prisma, + disabledMcpServerIds, model, modelName: languageModelConfig.displayName ?? languageModelConfig.model, modelProviderOptions: providerOptions, modelTemperature: temperature, + userId: user?.id, + orgId: org.id, onFinish: async ({ messages }) => { await updateChatMessages({ chatId: id, messages, prisma }); + const askMcpTurnCompleted = getAskMcpTurnCompletedAnalytics({ + messages, + availability: askMcpAvailability, + }); + if (askMcpTurnCompleted) { + void captureEvent('ask_mcp_turn_completed', { + chatId: id, + source: askMcpSource, + ...askMcpTurnCompleted, + }); + } }, onError: (error: unknown) => { logger.error(error); diff --git a/packages/web/src/app/api/(server)/ee/askmcp/callback/route.test.ts b/packages/web/src/app/api/(server)/ee/askmcp/callback/route.test.ts new file mode 100644 index 000000000..430cef51d --- /dev/null +++ b/packages/web/src/app/api/(server)/ee/askmcp/callback/route.test.ts @@ -0,0 +1,211 @@ +import { beforeEach, describe, expect, test, vi } from 'vitest'; +import { NextRequest } from 'next/server'; +import { McpServerClientInfoSource } from '@sourcebot/db'; + +const mocks = vi.hoisted(() => ({ + auth: vi.fn(), + hasEntitlement: vi.fn(), + logger: { + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + debug: vi.fn(), + }, + mcpAuth: vi.fn(), + unsafePrisma: { + mcpServer: { + updateMany: vi.fn(), + }, + userMcpServer: { + findFirst: vi.fn(), + update: vi.fn(), + updateMany: vi.fn(), + }, + userToOrg: { + findUnique: vi.fn(), + }, + }, +})); + +vi.mock('server-only', () => ({})); +vi.mock('@/lib/posthog', () => ({ + captureEvent: vi.fn(), +})); +vi.mock('@/auth', () => ({ + auth: mocks.auth, +})); +vi.mock('@/lib/entitlements', () => ({ + hasEntitlement: mocks.hasEntitlement, +})); +vi.mock('@/prisma', () => ({ + prisma: mocks.unsafePrisma, + __unsafePrisma: mocks.unsafePrisma, +})); +vi.mock('@sourcebot/shared', () => ({ + env: { + AUTH_URL: 'https://sourcebot.example.com', + }, + createLogger: () => mocks.logger, + encryptOAuthToken: vi.fn((text: string | null | undefined) => text ? `encrypted:${text}` : undefined), + decryptOAuthToken: vi.fn((text: string | null | undefined) => text?.startsWith('encrypted:') ? text.slice('encrypted:'.length) : text), +})); +vi.mock('@ai-sdk/mcp', () => ({ + auth: mocks.mcpAuth, +})); + +const { GET } = await import('./route'); +const { createMcpOAuthState } = await import('@/features/mcp/mcpOAuthReturnTo'); + +function createRequest(state = 'state-1') { + return new NextRequest(`https://sourcebot.example.com/api/ee/askmcp/callback?code=code-1&state=${encodeURIComponent(state)}`, { + method: 'GET', + }); +} + +function createOAuthErrorRequest(state: string) { + return new NextRequest(`https://sourcebot.example.com/api/ee/askmcp/callback?error=access_denied&error_description=Denied&state=${encodeURIComponent(state)}`, { + method: 'GET', + }); +} + +beforeEach(() => { + vi.clearAllMocks(); + mocks.auth.mockResolvedValue({ user: { id: 'user-1' } }); + mocks.hasEntitlement.mockResolvedValue(true); + mocks.unsafePrisma.userMcpServer.findFirst.mockResolvedValue({ + serverId: 'server-1', + server: { + orgId: 1, + name: 'Linear', + serverUrl: 'https://mcp.linear.app/mcp', + sanitizedName: 'linear', + clientInfoSource: McpServerClientInfoSource.DYNAMIC, + }, + }); + mocks.unsafePrisma.userMcpServer.update.mockResolvedValue({ userId: 'user-1', serverId: 'server-1' }); + mocks.unsafePrisma.userToOrg.findUnique.mockResolvedValue({ orgId: 1, userId: 'user-1' }); +}); + +describe('GET /api/ee/askmcp/callback', () => { + test('redirects successful chat-originated auth back to chat', async () => { + const state = createMcpOAuthState('state-1', '/chat'); + mocks.mcpAuth.mockResolvedValue('AUTHORIZED'); + + const response = await GET(createRequest(state)); + const location = response.headers.get('location'); + const url = new URL(location ?? ''); + + expect(url.pathname).toBe('/chat'); + expect(url.searchParams.get('status')).toBe('connected'); + expect(url.searchParams.get('server')).toBe('Linear'); + expect(mocks.unsafePrisma.userMcpServer.findFirst).toHaveBeenCalledWith({ + where: { + state, + userId: 'user-1', + }, + select: { + serverId: true, + server: { + select: { + orgId: true, + name: true, + serverUrl: true, + sanitizedName: true, + clientInfoSource: true, + }, + }, + }, + }); + }); + + test('redirects denied chat-originated auth back to chat', async () => { + const state = createMcpOAuthState('state-1', '/chat'); + + const response = await GET(createOAuthErrorRequest(state)); + const url = new URL(response.headers.get('location') ?? ''); + + expect(url.pathname).toBe('/chat'); + expect(url.searchParams.get('status')).toBe('error'); + expect(url.searchParams.get('message')).toBe('Denied'); + expect(mocks.mcpAuth).not.toHaveBeenCalled(); + }); + + test('redirects with a friendly reconnect error when callback auth cannot complete', async () => { + mocks.mcpAuth.mockImplementation(async (provider) => { + expect('saveClientInformation' in provider).toBe(false); + await provider.invalidateCredentials('all'); + const error = new Error('invalid_client client_secret=client-secret refresh_token=refresh-token'); + Object.assign(error, { + response: { + status: 401, + body: 'client_secret=client-secret refresh_token=refresh-token', + }, + }); + throw error; + }); + + const response = await GET(createRequest()); + const location = response.headers.get('location'); + + expect(location).toBeTruthy(); + expect(location).toContain('/settings/accountAskAgent'); + expect(location).toContain('status=error'); + expect(new URL(location ?? '').searchParams.get('message')).toContain('Please reconnect the connector'); + expect(mocks.unsafePrisma.userMcpServer.findFirst).toHaveBeenCalledWith({ + where: { + state: 'state-1', + userId: 'user-1', + }, + select: { + serverId: true, + server: { + select: { + orgId: true, + name: true, + serverUrl: true, + sanitizedName: true, + clientInfoSource: true, + }, + }, + }, + }); + expect(mocks.unsafePrisma.userMcpServer.update).toHaveBeenCalledWith({ + where: { + userId_serverId: { userId: 'user-1', serverId: 'server-1' }, + }, + data: { + codeVerifier: null, + state: null, + }, + }); + expect(mocks.logger.warn).toHaveBeenCalledWith('Failed to authorize MCP server.', { + serverId: 'server-1', + orgId: 1, + error: { + errorClass: 'Error', + oauthError: 'invalid_client', + statusCode: 401, + }, + }); + expect(JSON.stringify(mocks.logger.warn.mock.calls)).not.toContain('client-secret'); + expect(JSON.stringify(mocks.logger.warn.mock.calls)).not.toContain('refresh-token'); + }); + + test('clears verifier state when callback auth throws before provider cleanup', async () => { + mocks.mcpAuth.mockRejectedValue(new Error('token exchange failed')); + + const response = await GET(createRequest()); + const location = response.headers.get('location'); + + expect(location).toContain('status=error'); + expect(mocks.unsafePrisma.userMcpServer.update).toHaveBeenCalledWith({ + where: { + userId_serverId: { userId: 'user-1', serverId: 'server-1' }, + }, + data: { + codeVerifier: null, + state: null, + }, + }); + }); +}); diff --git a/packages/web/src/app/api/(server)/ee/askmcp/callback/route.ts b/packages/web/src/app/api/(server)/ee/askmcp/callback/route.ts new file mode 100644 index 000000000..c2d15704e --- /dev/null +++ b/packages/web/src/app/api/(server)/ee/askmcp/callback/route.ts @@ -0,0 +1,221 @@ +import { auth as mcpAuth } from '@ai-sdk/mcp'; +import { apiHandler } from '@/lib/apiHandler'; +import { env, createLogger } from '@sourcebot/shared'; +import { hasEntitlement } from '@/lib/entitlements'; +import { OAUTH_NOT_SUPPORTED_ERROR_MESSAGE } from '@/ee/features/oauth/constants'; +import { PrismaOAuthClientProvider } from '@/features/mcp/prismaOAuthClientProvider'; +// Note: We use the raw (unscoped) prisma client here because this route handles OAuth +// redirect callbacks from external providers, so it can't go through withAuth. Session +// identity is verified via NextAuth's auth() instead, and all queries filter by userId. +import { __unsafePrisma as prisma } from '@/prisma'; +import { auth } from '@/auth'; +import { NextRequest, NextResponse } from 'next/server'; +import { getExternalMcpErrorLogFields } from '@/ee/features/chat/mcp/externalMcpError'; +import { getMcpOAuthReturnToFromState } from '@/features/mcp/mcpOAuthReturnTo'; +import { captureEvent } from '@/lib/posthog'; +import { getMcpAuthMode, getMcpConnectorEntryPoint, getMcpConnectorFailureReason } from '@/ee/features/chat/mcp/analytics'; + +const logger = createLogger('mcp-oauth-callback'); +const reconnectMessage = 'This connector authorization could not be completed. Please reconnect the connector.'; +const defaultMcpOAuthReturnTo = '/settings/accountAskAgent'; + +function createMcpOAuthRedirectUrl(returnTo: string | undefined): URL { + return new URL(returnTo ?? defaultMcpOAuthReturnTo, env.AUTH_URL); +} + +function setMcpOAuthStatusParams(url: URL, params: { status: 'connected' | 'error'; server?: string; message?: string }) { + url.searchParams.set('status', params.status); + if (params.server) { + url.searchParams.set('server', params.server); + } + if (params.message) { + url.searchParams.set('message', params.message); + } +} + +function redirectToCallbackError(message: string, returnTo?: string) { + const url = createMcpOAuthRedirectUrl(returnTo); + setMcpOAuthStatusParams(url, { status: 'error', message }); + return NextResponse.redirect(url); +} + +// eslint-disable-next-line authz/require-auth-wrapper -- OAuth redirect callback validates the active session with auth() and filters all queries by userId. +export const GET = apiHandler(async (request: NextRequest) => { + if (!(await hasEntitlement('oauth'))) { + return Response.json( + { error: 'access_denied', error_description: OAUTH_NOT_SUPPORTED_ERROR_MESSAGE }, + { status: 403 } + ); + } + + const session = await auth(); + if (!session?.user?.id) { + return Response.json( + { error: 'unauthorized', error_description: 'You must be logged in.' }, + { status: 401 } + ); + } + + const { searchParams } = request.nextUrl; + const oauthError = searchParams.get('error'); + const code = searchParams.get('code'); + const state = searchParams.get('state'); + const callbackReturnTo = getMcpOAuthReturnToFromState(state); + const entryPoint = getMcpConnectorEntryPoint(callbackReturnTo); + const getUserServer = () => { + if (!state) { + return Promise.resolve(null); + } + + return prisma.userMcpServer.findFirst({ + where: { + state, + userId: session.user.id, + }, + select: { + serverId: true, + server: { + select: { + orgId: true, + name: true, + serverUrl: true, + sanitizedName: true, + clientInfoSource: true, + }, + }, + }, + }); + }; + const createEventProperties = (userServer: NonNullable>>) => ({ + source: 'sourcebot-web-client' as const, + entryPoint, + serverId: userServer.serverId, + serverName: userServer.server.name, + serverUrl: userServer.server.serverUrl, + sanitizedName: userServer.server.sanitizedName, + authMode: getMcpAuthMode(userServer.server.clientInfoSource), + }); + const getEventProperties = async () => { + const userServer = await getUserServer(); + return userServer ? createEventProperties(userServer) : undefined; + }; + + // Handle OAuth errors (e.g., user cancelled the authorization flow). + if (oauthError) { + // Error callbacks often have no authorization code, so fetch the pending connector here + // only to enrich cancellation/denial analytics when the provider returned state. + const eventProperties = await getEventProperties(); + if (eventProperties) { + void captureEvent('ask_mcp_connector_connection_failed', { + ...eventProperties, + failureReason: 'oauth_error', + }); + } + const url = createMcpOAuthRedirectUrl(callbackReturnTo); + const errorDescription = searchParams.get('error_description') ?? 'Authorization was cancelled or denied.'; + setMcpOAuthStatusParams(url, { status: 'error', message: errorDescription }); + return NextResponse.redirect(url); + } + + if (!code || !state) { + void captureEvent('ask_mcp_connector_connection_failed', { + source: 'sourcebot-web-client', + entryPoint, + failureReason: 'invalid_request', + }); + return Response.json( + { error: 'invalid_request', error_description: 'Missing required parameters: code, state.' }, + { status: 400 } + ); + } + + const userServer = await getUserServer(); + if (!userServer) { + void captureEvent('ask_mcp_connector_connection_failed', { + source: 'sourcebot-web-client', + entryPoint, + failureReason: 'invalid_state', + }); + return Response.json( + { error: 'invalid_state', error_description: 'No pending authorization found for this state.' }, + { status: 400 } + ); + } + + const connectorEventProperties = createEventProperties(userServer); + + const orgMembership = await prisma.userToOrg.findUnique({ + where: { + orgId_userId: { + orgId: userServer.server.orgId, + userId: session.user.id, + }, + }, + }); + + if (!orgMembership) { + void captureEvent('ask_mcp_connector_connection_failed', { + ...connectorEventProperties, + failureReason: 'forbidden', + }); + return Response.json( + { error: 'forbidden', error_description: 'You do not have access to this connector.' }, + { status: 403 } + ); + } + + const provider = new PrismaOAuthClientProvider({ + prisma, + serverId: userServer.serverId, + orgId: userServer.server.orgId, + userId: session.user.id, + callbackUrl: `${env.AUTH_URL}/api/ee/askmcp/callback`, + }); + + let result: Awaited>; + + try { + result = await mcpAuth(provider, { + serverUrl: new URL(userServer.server.serverUrl), + authorizationCode: code, + callbackState: state, + }); + } catch (error) { + logger.warn('Failed to authorize MCP server.', { + serverId: userServer.serverId, + orgId: userServer.server.orgId, + error: getExternalMcpErrorLogFields(error), + }); + void captureEvent('ask_mcp_connector_connection_failed', { + ...connectorEventProperties, + failureReason: getMcpConnectorFailureReason(error), + }); + return redirectToCallbackError(reconnectMessage, callbackReturnTo); + } finally { + // Always clear ephemeral PKCE/state regardless of outcome to prevent replay. + try { + await provider.invalidateCredentials('verifier'); + } catch (cleanupError) { + logger.warn(`Failed to clear MCP OAuth verifier for user ${session.user.id}:`, cleanupError); + } + } + + if (result === 'AUTHORIZED') { + const displayName = userServer.server.name || userServer.server.serverUrl; + logger.info(`Successfully authorized MCP server ${displayName} for user ${session.user.id}.`); + void captureEvent('ask_mcp_connector_connection_completed', { + ...connectorEventProperties, + alreadyAuthorized: false, + }); + const url = createMcpOAuthRedirectUrl(callbackReturnTo); + setMcpOAuthStatusParams(url, { status: 'connected', server: displayName }); + return NextResponse.redirect(url); + } + + // If auth() didn't return AUTHORIZED, something went wrong + void captureEvent('ask_mcp_connector_connection_failed', { + ...connectorEventProperties, + failureReason: 'token_exchange_failed', + }); + return redirectToCallbackError('Token exchange failed', callbackReturnTo); +}); diff --git a/packages/web/src/app/api/(server)/ee/askmcp/configuration/route.test.ts b/packages/web/src/app/api/(server)/ee/askmcp/configuration/route.test.ts new file mode 100644 index 000000000..1c868baae --- /dev/null +++ b/packages/web/src/app/api/(server)/ee/askmcp/configuration/route.test.ts @@ -0,0 +1,278 @@ +import { beforeEach, describe, expect, test, vi } from 'vitest'; +import { NextRequest } from 'next/server'; +import { OrgRole } from '@sourcebot/db'; +import { ErrorCode } from '@/lib/errorCodes'; + +const mocks = vi.hoisted(() => ({ + authContext: undefined as unknown, + hasEntitlement: vi.fn(), + withAuth: vi.fn(), + unsafePrisma: { + userMcpServer: { + groupBy: vi.fn(), + }, + mcpServerToolCallCount: { + findMany: vi.fn(), + }, + }, +})); + +vi.mock('@/lib/posthog', () => ({ + captureEvent: vi.fn(), +})); +vi.mock('@/lib/entitlements', () => ({ + hasEntitlement: mocks.hasEntitlement, +})); +vi.mock('@/middleware/withAuth', () => ({ + withAuth: mocks.withAuth, +})); +vi.mock('@/prisma', () => ({ + __unsafePrisma: mocks.unsafePrisma, +})); + +const { GET } = await import('./route'); + +function createRequest() { + return new NextRequest('https://sourcebot.example.com/api/ee/askmcp/configuration', { method: 'GET' }); +} + +function createPrismaMock() { + return { + mcpServer: { + findMany: vi.fn().mockResolvedValue([ + { + id: 'server-1', + name: 'Linear', + sanitizedName: 'linear', + serverUrl: 'https://mcp.linear.app/mcp', + }, + { + id: 'server-2', + name: 'Sentry', + sanitizedName: 'sentry', + serverUrl: 'https://mcp.sentry.dev/mcp', + }, + ]), + }, + }; +} + +beforeEach(() => { + vi.clearAllMocks(); + mocks.hasEntitlement.mockResolvedValue(true); + mocks.withAuth.mockImplementation((callback: (context: unknown) => unknown) => callback(mocks.authContext)); + mocks.unsafePrisma.userMcpServer.groupBy.mockResolvedValue([ + { + serverId: 'server-1', + _count: { _all: 2 }, + }, + ]); + mocks.unsafePrisma.mcpServerToolCallCount.findMany.mockResolvedValue([ + { + mcpServerId: 'server-1', + toolName: 'search_issues', + count: 5, + }, + { + mcpServerId: 'server-1', + toolName: 'get_issue', + count: 3, + }, + { + mcpServerId: 'server-2', + toolName: 'list_projects', + count: 2, + }, + ]); +}); + +describe('GET /api/ee/askmcp/configuration', () => { + test('lists approved servers with current-member saved connection counts', async () => { + const prisma = createPrismaMock(); + mocks.authContext = { + org: { id: 1 }, + role: OrgRole.OWNER, + prisma, + }; + + const response = await GET(createRequest()); + const body = await response.json(); + + expect(prisma.mcpServer.findMany).toHaveBeenCalledWith({ + where: { orgId: 1 }, + orderBy: { createdAt: 'desc' }, + select: { + id: true, + name: true, + sanitizedName: true, + serverUrl: true, + }, + }); + expect(mocks.unsafePrisma.userMcpServer.groupBy).toHaveBeenCalledWith({ + by: ['serverId'], + where: { + serverId: { in: ['server-1', 'server-2'] }, + tokens: { not: null }, + server: { orgId: 1 }, + user: { + orgs: { + some: { orgId: 1 }, + }, + }, + }, + _count: { _all: true }, + }); + expect(mocks.unsafePrisma.mcpServerToolCallCount.findMany).toHaveBeenCalledWith({ + where: { + mcpServerId: { in: ['server-1', 'server-2'] }, + mcpServer: { orgId: 1 }, + count: { gt: 0 }, + }, + orderBy: [ + { mcpServerId: 'asc' }, + { count: 'desc' }, + ], + select: { + mcpServerId: true, + toolName: true, + count: true, + }, + }); + expect(body).toMatchObject({ + allowedMode: 'approved_only', + isOAuthAvailable: true, + servers: [ + { + id: 'server-1', + name: 'Linear', + savedConnectionCount: 2, + toolUsage: { + totalCalls: 8, + usedToolCount: 2, + tools: [ + { + toolName: 'search_issues', + totalCalls: 5, + usageSharePercent: 62.5, + }, + { + toolName: 'get_issue', + totalCalls: 3, + usageSharePercent: 37.5, + }, + ], + }, + }, + { + id: 'server-2', + name: 'Sentry', + savedConnectionCount: 0, + toolUsage: { + totalCalls: 2, + usedToolCount: 1, + tools: [ + { + toolName: 'list_projects', + totalCalls: 2, + usageSharePercent: 100, + }, + ], + }, + }, + ], + }); + }); + + test('rejects non-owners before unsafe connector queries', async () => { + const prisma = createPrismaMock(); + mocks.authContext = { + org: { id: 1 }, + role: OrgRole.MEMBER, + prisma, + }; + + const response = await GET(createRequest()); + const body = await response.json(); + + expect(response.status).toBe(403); + expect(body).toMatchObject({ + errorCode: ErrorCode.INSUFFICIENT_PERMISSIONS, + }); + expect(prisma.mcpServer.findMany).not.toHaveBeenCalled(); + expect(mocks.hasEntitlement).not.toHaveBeenCalled(); + expect(mocks.unsafePrisma.userMcpServer.groupBy).not.toHaveBeenCalled(); + expect(mocks.unsafePrisma.mcpServerToolCallCount.findMany).not.toHaveBeenCalled(); + }); + + test('rejects unauthenticated callers before checking OAuth entitlement', async () => { + mocks.withAuth.mockResolvedValue({ + statusCode: 401, + errorCode: ErrorCode.NOT_AUTHENTICATED, + message: 'Not authenticated', + }); + + const response = await GET(createRequest()); + const body = await response.json(); + + expect(response.status).toBe(401); + expect(body).toMatchObject({ + errorCode: ErrorCode.NOT_AUTHENTICATED, + }); + expect(mocks.hasEntitlement).not.toHaveBeenCalled(); + expect(mocks.unsafePrisma.userMcpServer.groupBy).not.toHaveBeenCalled(); + expect(mocks.unsafePrisma.mcpServerToolCallCount.findMany).not.toHaveBeenCalled(); + }); + + test('allows entitled owners to list cleanup data when OAuth is unsupported', async () => { + const prisma = createPrismaMock(); + mocks.authContext = { + org: { id: 1 }, + role: OrgRole.OWNER, + prisma, + }; + mocks.hasEntitlement.mockResolvedValue(false); + + const response = await GET(createRequest()); + const body = await response.json(); + + expect(response.status).toBe(200); + expect(body).toMatchObject({ + isOAuthAvailable: false, + servers: [ + { + id: 'server-1', + savedConnectionCount: 2, + }, + { + id: 'server-2', + savedConnectionCount: 0, + }, + ], + }); + expect(mocks.withAuth).toHaveBeenCalled(); + expect(prisma.mcpServer.findMany).toHaveBeenCalled(); + expect(mocks.unsafePrisma.userMcpServer.groupBy).toHaveBeenCalled(); + expect(mocks.unsafePrisma.mcpServerToolCallCount.findMany).toHaveBeenCalled(); + }); + + test('skips unsafe connector queries when there are no approved servers', async () => { + const prisma = createPrismaMock(); + prisma.mcpServer.findMany.mockResolvedValue([]); + mocks.authContext = { + org: { id: 1 }, + role: OrgRole.OWNER, + prisma, + }; + + const response = await GET(createRequest()); + const body = await response.json(); + + expect(mocks.unsafePrisma.userMcpServer.groupBy).not.toHaveBeenCalled(); + expect(mocks.unsafePrisma.mcpServerToolCallCount.findMany).not.toHaveBeenCalled(); + expect(body).toEqual({ + servers: [], + allowedMode: 'approved_only', + isOAuthAvailable: true, + }); + }); +}); diff --git a/packages/web/src/app/api/(server)/ee/askmcp/configuration/route.ts b/packages/web/src/app/api/(server)/ee/askmcp/configuration/route.ts new file mode 100644 index 000000000..b4d3949a6 --- /dev/null +++ b/packages/web/src/app/api/(server)/ee/askmcp/configuration/route.ts @@ -0,0 +1,123 @@ +import { apiHandler } from '@/lib/apiHandler'; +import { serviceErrorResponse } from '@/lib/serviceError'; +import { isServiceError } from '@/lib/utils'; +import { hasEntitlement } from '@/lib/entitlements'; +import { withAuth } from '@/middleware/withAuth'; +import { withMinimumOrgRole } from '@/middleware/withMinimumOrgRole'; +import { __unsafePrisma } from '@/prisma'; +import { getMcpFaviconUrl } from '@/ee/features/chat/mcp/utils'; +import type { GetMcpConfigurationResponse, McpServerToolUsageSummary } from '@/ee/features/chat/mcp/types'; +import { OrgRole } from '@sourcebot/db'; +import type { NextRequest } from 'next/server'; + +export const GET = apiHandler(async (_request: NextRequest) => { + const result = await withAuth(async ({ org, role, prisma }) => + withMinimumOrgRole(role, OrgRole.OWNER, async (): Promise => { + const isOAuthAvailable = await hasEntitlement('oauth'); + + const orgServers = await prisma.mcpServer.findMany({ + where: { orgId: org.id }, + orderBy: { createdAt: 'desc' }, + select: { + id: true, + name: true, + sanitizedName: true, + serverUrl: true, + }, + }); + + const serverIds = orgServers.map((server) => server.id); + const connectionCounts = serverIds.length === 0 + ? [] + : await __unsafePrisma.userMcpServer.groupBy({ + by: ['serverId'], + where: { + serverId: { in: serverIds }, + tokens: { not: null }, + server: { orgId: org.id }, + user: { + orgs: { + some: { orgId: org.id }, + }, + }, + }, + _count: { _all: true }, + }); + const countByServerId = new Map( + connectionCounts.map((row) => [row.serverId, row._count._all]), + ); + + const toolCallCountWhere = { + mcpServerId: { in: serverIds }, + mcpServer: { orgId: org.id }, + count: { gt: 0 }, + }; + const toolCallCountRows = serverIds.length === 0 + ? [] + : await __unsafePrisma.mcpServerToolCallCount.findMany({ + where: toolCallCountWhere, + orderBy: [ + { mcpServerId: 'asc' }, + { count: 'desc' }, + ], + select: { + mcpServerId: true, + toolName: true, + count: true, + }, + }); + const toolUsageByServerId = new Map(); + + for (const row of toolCallCountRows) { + const current = toolUsageByServerId.get(row.mcpServerId) ?? { + totalCalls: 0, + usedToolCount: 0, + tools: [], + }; + + current.totalCalls += row.count; + current.usedToolCount += 1; + current.tools.push({ + toolName: row.toolName, + totalCalls: row.count, + usageSharePercent: 0, + }); + toolUsageByServerId.set(row.mcpServerId, current); + } + + for (const usage of toolUsageByServerId.values()) { + usage.tools = usage.tools.map((tool) => ({ + ...tool, + usageSharePercent: usage.totalCalls > 0 + ? (tool.totalCalls / usage.totalCalls) * 100 + : 0, + })); + } + + const servers = orgServers.map((server) => { + const savedConnectionCount = countByServerId.get(server.id) ?? 0; + return { + ...server, + faviconUrl: getMcpFaviconUrl(server.serverUrl, server.name), + savedConnectionCount, + toolUsage: toolUsageByServerId.get(server.id) ?? { + totalCalls: 0, + usedToolCount: 0, + tools: [], + }, + }; + }); + + return { + servers, + allowedMode: 'approved_only', + isOAuthAvailable, + }; + })); + + if (isServiceError(result)) { + return serviceErrorResponse(result); + } + + return Response.json(result); +}); diff --git a/packages/web/src/app/api/(server)/ee/askmcp/connect/route.test.ts b/packages/web/src/app/api/(server)/ee/askmcp/connect/route.test.ts new file mode 100644 index 000000000..6b9561ac6 --- /dev/null +++ b/packages/web/src/app/api/(server)/ee/askmcp/connect/route.test.ts @@ -0,0 +1,314 @@ +import { beforeEach, describe, expect, test, vi } from 'vitest'; +import { NextRequest } from 'next/server'; +import { McpServerClientInfoSource } from '@sourcebot/db'; +import { ErrorCode } from '@/lib/errorCodes'; + +const mocks = vi.hoisted(() => ({ + authContext: undefined as unknown, + hasEntitlement: vi.fn(), + logger: { + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + debug: vi.fn(), + }, + mcpAuth: vi.fn(), + unsafePrisma: { + $transaction: vi.fn(), + }, + captureEvent: vi.fn(), +})); + +vi.mock('server-only', () => ({})); +vi.mock('@/lib/posthog', () => ({ + captureEvent: mocks.captureEvent, +})); +vi.mock('@/lib/entitlements', () => ({ + hasEntitlement: mocks.hasEntitlement, +})); +vi.mock('@/middleware/withAuth', () => ({ + withAuth: vi.fn((callback: (context: unknown) => unknown) => callback(mocks.authContext)), +})); +vi.mock('@/prisma', () => ({ + __unsafePrisma: mocks.unsafePrisma, +})); +vi.mock('@sourcebot/shared', () => ({ + env: { + AUTH_URL: 'https://sourcebot.example.com', + SOURCEBOT_MCP_TOOL_CALL_TIMEOUT_MS: 5000, + }, + createLogger: () => mocks.logger, + encryptOAuthToken: vi.fn((text: string | null | undefined) => text ? `encrypted:${text}` : undefined), + decryptOAuthToken: vi.fn((text: string | null | undefined) => text?.startsWith('encrypted:') ? text.slice('encrypted:'.length) : text), +})); +vi.mock('@ai-sdk/mcp', () => ({ + auth: mocks.mcpAuth, +})); + +const { POST } = await import('./route'); +const { getMcpOAuthReturnToFromState } = await import('@/features/mcp/mcpOAuthReturnTo'); + +function createRequest(body: { serverId: string; returnTo?: string } = { serverId: 'server-1' }) { + return new NextRequest('https://sourcebot.example.com/api/ee/askmcp/connect', { + method: 'POST', + headers: { 'content-type': 'application/json' }, + body: JSON.stringify(body), + }); +} + +function createMalformedJsonRequest() { + return new NextRequest('https://sourcebot.example.com/api/ee/askmcp/connect', { + method: 'POST', + headers: { 'content-type': 'application/json' }, + body: '{"serverId":', + }); +} + +function createTextPlainRequest() { + return new NextRequest('https://sourcebot.example.com/api/ee/askmcp/connect', { + method: 'POST', + headers: { 'content-type': 'text/plain' }, + body: 'server-1', + }); +} + +function createEmptyBodyRequest() { + return new NextRequest('https://sourcebot.example.com/api/ee/askmcp/connect', { + method: 'POST', + }); +} + +function createPrismaMock() { + return { + mcpServer: { + findFirst: vi.fn().mockResolvedValue({ + id: 'server-1', + name: 'Linear', + sanitizedName: 'linear', + serverUrl: 'https://mcp.linear.app/mcp', + clientInfoSource: McpServerClientInfoSource.DYNAMIC, + }), + }, + userMcpServer: { + upsert: vi.fn().mockResolvedValue({ userId: 'user-1', serverId: 'server-1' }), + }, + }; +} + +function createTransactionMock() { + return { + $queryRaw: vi.fn().mockResolvedValue([{ id: 'server-1' }]), + mcpServer: { + findFirst: vi.fn(), + updateMany: vi.fn().mockResolvedValue({ count: 1 }), + }, + userMcpServer: { + findUnique: vi.fn(), + update: vi.fn(), + updateMany: vi.fn(), + }, + }; +} + +beforeEach(() => { + vi.clearAllMocks(); + mocks.hasEntitlement.mockResolvedValue(true); + mocks.captureEvent.mockResolvedValue(undefined); +}); + +describe('POST /api/ee/askmcp/connect', () => { + test.each([ + ['malformed JSON', createMalformedJsonRequest], + ['text/plain body', createTextPlainRequest], + ['empty body', createEmptyBodyRequest], + ])('returns a request body validation error for %s', async (_name, createInvalidRequest) => { + const response = await POST(createInvalidRequest()); + const body = await response.json(); + + expect(response.status).toBe(400); + expect(body).toMatchObject({ + statusCode: 400, + errorCode: ErrorCode.INVALID_REQUEST_BODY, + message: 'Invalid JSON request body.', + }); + expect(mocks.mcpAuth).not.toHaveBeenCalled(); + }); + + test('upserts a nameless user row and performs DCR-capable auth under a row lock', async () => { + const prisma = createPrismaMock(); + const tx = createTransactionMock(); + mocks.authContext = { + org: { id: 1 }, + user: { id: 'user-1' }, + prisma, + }; + mocks.unsafePrisma.$transaction.mockImplementation(async (callback, _options) => callback(tx)); + mocks.mcpAuth.mockImplementation(async (provider, options) => { + expect('saveClientInformation' in provider).toBe(true); + expect(provider.saveClientInformation).toEqual(expect.any(Function)); + expect(options.fetchFn).toEqual(expect.any(Function)); + + await provider.saveClientInformation({ client_id: 'client-1' }); + provider.authorizationUrl = 'https://oauth.example.com/authorize'; + return 'REDIRECT'; + }); + + const response = await POST(createRequest()); + const body = await response.json(); + + expect(mocks.captureEvent).toHaveBeenCalledWith('ask_mcp_connector_connection_started', { + source: 'sourcebot-web-client', + entryPoint: 'unknown', + serverId: 'server-1', + serverName: 'Linear', + serverUrl: 'https://mcp.linear.app/mcp', + sanitizedName: 'linear', + authMode: 'dynamic', + }); + expect(prisma.userMcpServer.upsert).toHaveBeenCalledWith({ + where: { + userId_serverId: { + userId: 'user-1', + serverId: 'server-1', + }, + }, + create: { + userId: 'user-1', + serverId: 'server-1', + }, + update: {}, + }); + expect(mocks.unsafePrisma.$transaction).toHaveBeenCalledWith( + expect.any(Function), + { + maxWait: 10000, + timeout: 10000, + }, + ); + expect(tx.$queryRaw).toHaveBeenCalledOnce(); + expect(tx.mcpServer.updateMany).toHaveBeenCalledWith({ + where: { id: 'server-1', orgId: 1 }, + data: { + clientInfo: 'encrypted:{"client_id":"client-1"}', + clientInfoSource: McpServerClientInfoSource.DYNAMIC, + }, + }); + expect(body).toEqual({ authorizationUrl: 'https://oauth.example.com/authorize' }); + }); + + test('encodes a safe return path into OAuth state', async () => { + const prisma = createPrismaMock(); + const tx = createTransactionMock(); + mocks.authContext = { + org: { id: 1 }, + user: { id: 'user-1' }, + prisma, + }; + mocks.unsafePrisma.$transaction.mockImplementation(async (callback, _options) => callback(tx)); + mocks.mcpAuth.mockImplementation(async (provider) => { + const state = await provider.state(); + expect(getMcpOAuthReturnToFromState(state)).toBe('/chat'); + await provider.saveState(state); + + provider.authorizationUrl = 'https://oauth.example.com/authorize'; + return 'REDIRECT'; + }); + + const response = await POST(createRequest({ serverId: 'server-1', returnTo: '/chat' })); + const body = await response.json(); + + expect(body).toEqual({ authorizationUrl: 'https://oauth.example.com/authorize' }); + expect(tx.userMcpServer.update).toHaveBeenCalledWith({ + where: { + userId_serverId: { userId: 'user-1', serverId: 'server-1' }, + }, + data: { + state: expect.stringContaining('sourcebot_mcp.'), + }, + }); + }); + + test('ignores unsafe return paths', async () => { + const prisma = createPrismaMock(); + const tx = createTransactionMock(); + mocks.authContext = { + org: { id: 1 }, + user: { id: 'user-1' }, + prisma, + }; + mocks.unsafePrisma.$transaction.mockImplementation(async (callback, _options) => callback(tx)); + mocks.mcpAuth.mockImplementation(async (provider) => { + const state = await provider.state(); + expect(getMcpOAuthReturnToFromState(state)).toBeUndefined(); + await provider.saveState(state); + + provider.authorizationUrl = 'https://oauth.example.com/authorize'; + return 'REDIRECT'; + }); + + const response = await POST(createRequest({ serverId: 'server-1', returnTo: 'https://evil.example.com/chat' })); + const body = await response.json(); + + expect(body).toEqual({ authorizationUrl: 'https://oauth.example.com/authorize' }); + expect(tx.userMcpServer.update).toHaveBeenCalledWith({ + where: { + userId_serverId: { userId: 'user-1', serverId: 'server-1' }, + }, + data: { + state: expect.not.stringContaining('sourcebot_mcp.'), + }, + }); + }); + + test('sanitizes external OAuth errors before logging', async () => { + const prisma = createPrismaMock(); + const tx = createTransactionMock(); + mocks.authContext = { + org: { id: 1 }, + user: { id: 'user-1' }, + prisma, + }; + mocks.unsafePrisma.$transaction.mockImplementation(async (callback, _options) => callback(tx)); + mocks.mcpAuth.mockImplementation(async () => { + const error = new Error('invalid_client client_secret=client-secret refresh_token=refresh-token'); + Object.assign(error, { + response: { + status: 400, + body: 'client_secret=client-secret refresh_token=refresh-token', + }, + }); + throw error; + }); + + const response = await POST(createRequest()); + const body = await response.json(); + + expect(response.status).toBe(502); + expect(body).toMatchObject({ + message: 'Could not start connector authorization.', + }); + expect(mocks.logger.warn).toHaveBeenCalledWith('Failed to start connector authorization.', { + serverId: 'server-1', + orgId: 1, + error: { + errorClass: 'Error', + oauthError: 'invalid_client', + statusCode: 400, + }, + }); + expect(JSON.stringify(mocks.logger.warn.mock.calls)).not.toContain('client-secret'); + expect(JSON.stringify(mocks.logger.warn.mock.calls)).not.toContain('refresh-token'); + expect(JSON.stringify(mocks.logger.error.mock.calls)).not.toContain('client-secret'); + expect(JSON.stringify(mocks.logger.error.mock.calls)).not.toContain('refresh-token'); + expect(mocks.captureEvent).toHaveBeenCalledWith('ask_mcp_connector_connection_failed', { + source: 'sourcebot-web-client', + entryPoint: 'unknown', + serverId: 'server-1', + serverName: 'Linear', + serverUrl: 'https://mcp.linear.app/mcp', + sanitizedName: 'linear', + authMode: 'dynamic', + failureReason: 'invalid_client', + }); + }); +}); diff --git a/packages/web/src/app/api/(server)/ee/askmcp/connect/route.ts b/packages/web/src/app/api/(server)/ee/askmcp/connect/route.ts new file mode 100644 index 000000000..a6b3aff02 --- /dev/null +++ b/packages/web/src/app/api/(server)/ee/askmcp/connect/route.ts @@ -0,0 +1,200 @@ +import { auth as mcpAuth } from '@ai-sdk/mcp'; +import { apiHandler } from '@/lib/apiHandler'; +import { withAuth } from '@/middleware/withAuth'; +import { sew } from '@/middleware/sew'; +import { isServiceError } from '@/lib/utils'; +import { serviceErrorResponse, notFound, requestBodySchemaValidationError, ServiceErrorException } from '@/lib/serviceError'; +import { PrismaOAuthClientProvider } from '@/features/mcp/prismaOAuthClientProvider'; +import { NextRequest } from 'next/server'; +import { z } from 'zod'; +import { hasEntitlement } from '@/lib/entitlements'; +import { OAUTH_NOT_SUPPORTED_ERROR_MESSAGE } from '@/ee/features/oauth/constants'; +import { ConnectMcpResponse } from "@/app/api/(server)/ee/askmcp/connect/types"; +import { createLogger, env } from "@sourcebot/shared"; +import { __unsafePrisma } from '@/prisma'; +import { getExternalMcpErrorLogFields } from '@/ee/features/chat/mcp/externalMcpError'; +import { ErrorCode } from '@/lib/errorCodes'; +import { StatusCodes } from 'http-status-codes'; +import { normalizeMcpOAuthReturnTo } from '@/features/mcp/mcpOAuthReturnTo'; +import { captureEvent } from '@/lib/posthog'; +import { getMcpAuthMode, getMcpConnectorEntryPoint, getMcpConnectorFailureReason } from '@/ee/features/chat/mcp/analytics'; + +const bodySchema = z.object({ + serverId: z.string(), + returnTo: z.string().optional(), +}); +const logger = createLogger('mcp-connect'); +const MCP_AUTH_FETCH_TIMEOUT_MS = Math.min(env.SOURCEBOT_MCP_TOOL_CALL_TIMEOUT_MS, 30000); +const MCP_AUTH_TRANSACTION_MAX_WAIT_MS = 10000; +const MCP_AUTH_TRANSACTION_TIMEOUT_MS = MCP_AUTH_FETCH_TIMEOUT_MS + 5000; + +function createTimeoutFetch(timeoutMs: number): typeof fetch { + return async (input, init) => { + const timeoutSignal = AbortSignal.timeout(timeoutMs); + const signal = init?.signal + ? AbortSignal.any([init.signal, timeoutSignal]) + : timeoutSignal; + + return fetch(input, { + ...init, + signal, + }); + }; +} + +export const POST = apiHandler(async (request: NextRequest) => { + if (!(await hasEntitlement('oauth'))) { + return Response.json( + { error: 'access_denied', error_description: OAUTH_NOT_SUPPORTED_ERROR_MESSAGE }, + { status: 403 } + ); + } + + let body: unknown; + try { + body = await request.json(); + } catch { + return serviceErrorResponse({ + statusCode: StatusCodes.BAD_REQUEST, + errorCode: ErrorCode.INVALID_REQUEST_BODY, + message: 'Invalid JSON request body.', + }); + } + + const parsed = bodySchema.safeParse(body); + if (!parsed.success) { + return serviceErrorResponse(requestBodySchemaValidationError(parsed.error)); + } + + const result = await sew(() => + withAuth(async ({ user, org, prisma }) => { + const callbackReturnTo = normalizeMcpOAuthReturnTo(parsed.data.returnTo); + const entryPoint = getMcpConnectorEntryPoint(parsed.data.returnTo); + const mcpServer = await prisma.mcpServer.findFirst({ + where: { id: parsed.data.serverId, orgId: org.id }, + select: { + id: true, + name: true, + sanitizedName: true, + serverUrl: true, + clientInfoSource: true, + }, + }); + if (!mcpServer) { + void captureEvent('ask_mcp_connector_connection_failed', { + source: 'sourcebot-web-client', + entryPoint, + serverId: parsed.data.serverId, + failureReason: 'connector_not_found', + }); + return notFound('Connector not found'); + } + + const eventProperties = { + source: 'sourcebot-web-client' as const, + entryPoint, + serverId: mcpServer.id, + serverName: mcpServer.name, + serverUrl: mcpServer.serverUrl, + sanitizedName: mcpServer.sanitizedName, + authMode: getMcpAuthMode(mcpServer.clientInfoSource), + }; + + void captureEvent('ask_mcp_connector_connection_started', eventProperties); + + await prisma.userMcpServer.upsert({ + where: { + userId_serverId: { + userId: user.id, + serverId: mcpServer.id, + }, + }, + create: { + userId: user.id, + serverId: mcpServer.id, + }, + update: {}, + }); + + const connectResult = await __unsafePrisma.$transaction(async (tx) => { + const lockedRows = await tx.$queryRaw<{ id: string }[]>` + SELECT id + FROM "McpServer" + WHERE id = ${mcpServer.id} AND "orgId" = ${org.id} + FOR UPDATE + `; + + if (lockedRows.length === 0) { + throw new ServiceErrorException(notFound('Connector not found')); + } + + const provider = new PrismaOAuthClientProvider({ + prisma: tx, + clientInvalidationPrisma: tx, + serverId: mcpServer.id, + orgId: org.id, + userId: user.id, + callbackUrl: `${env.AUTH_URL}/api/ee/askmcp/callback`, + callbackReturnTo, + allowClientRegistration: true, + }); + + let authResult: Awaited>; + try { + authResult = await mcpAuth(provider, { + serverUrl: new URL(mcpServer.serverUrl), + fetchFn: createTimeoutFetch(MCP_AUTH_FETCH_TIMEOUT_MS), + }); + } catch (error) { + logger.warn('Failed to start connector authorization.', { + serverId: mcpServer.id, + orgId: org.id, + error: getExternalMcpErrorLogFields(error), + }); + void captureEvent('ask_mcp_connector_connection_failed', { + ...eventProperties, + failureReason: getMcpConnectorFailureReason(error), + }); + throw new ServiceErrorException({ + statusCode: StatusCodes.BAD_GATEWAY, + errorCode: ErrorCode.UNEXPECTED_ERROR, + message: 'Could not start connector authorization.', + }); + } + + return { + authResult, + authorizationUrl: provider.authorizationUrl ?? null, + }; + }, { + maxWait: MCP_AUTH_TRANSACTION_MAX_WAIT_MS, + timeout: MCP_AUTH_TRANSACTION_TIMEOUT_MS, + }); + + if (connectResult.authResult === 'AUTHORIZED') { + // Already has valid tokens (e.g., refreshed) + void captureEvent('ask_mcp_connector_connection_completed', { + ...eventProperties, + alreadyAuthorized: true, + }); + return { authorizationUrl: null } satisfies ConnectMcpResponse; + } + + if (!connectResult.authorizationUrl) { + void captureEvent('ask_mcp_connector_connection_failed', { + ...eventProperties, + failureReason: 'missing_authorization_url', + }); + throw new Error('MCP auth returned REDIRECT without an authorization URL'); + } + + return { authorizationUrl: connectResult.authorizationUrl } satisfies ConnectMcpResponse; + }) + ); + + if (isServiceError(result)) { + return serviceErrorResponse(result); + } + + return Response.json(result); +}); diff --git a/packages/web/src/app/api/(server)/ee/askmcp/connect/types.ts b/packages/web/src/app/api/(server)/ee/askmcp/connect/types.ts new file mode 100644 index 000000000..80281ae17 --- /dev/null +++ b/packages/web/src/app/api/(server)/ee/askmcp/connect/types.ts @@ -0,0 +1,4 @@ +export interface ConnectMcpResponse { + /** The external OAuth authorization URL the browser should navigate to. Null if already authorized. */ + authorizationUrl: string | null; +} \ No newline at end of file diff --git a/packages/web/src/app/api/(server)/ee/askmcp/servers/route.test.ts b/packages/web/src/app/api/(server)/ee/askmcp/servers/route.test.ts new file mode 100644 index 000000000..42417d501 --- /dev/null +++ b/packages/web/src/app/api/(server)/ee/askmcp/servers/route.test.ts @@ -0,0 +1,146 @@ +import { beforeEach, describe, expect, test, vi } from 'vitest'; +import { NextRequest } from 'next/server'; + +const mocks = vi.hoisted(() => ({ + authContext: undefined as unknown, + hasEntitlement: vi.fn(), +})); + +vi.mock('@/lib/posthog', () => ({ + captureEvent: vi.fn(), +})); +vi.mock('@/lib/entitlements', () => ({ + hasEntitlement: mocks.hasEntitlement, +})); +vi.mock('@/middleware/withAuth', () => ({ + withAuth: vi.fn((callback: (context: unknown) => unknown) => callback(mocks.authContext)), +})); +vi.mock('@sourcebot/shared', () => ({ + decryptOAuthToken: vi.fn((value: string) => value), +})); + +const { GET } = await import('./route'); + +function createRequest() { + return new NextRequest('https://sourcebot.example.com/api/ee/askmcp/servers', { method: 'GET' }); +} + +function createPrismaMock() { + return { + mcpServer: { + findMany: vi.fn().mockResolvedValue([ + { + id: 'server-1', + name: 'Linear', + sanitizedName: 'linear', + serverUrl: 'https://mcp.linear.app/mcp', + }, + { + id: 'server-2', + name: 'Sentry', + sanitizedName: 'sentry', + serverUrl: 'https://mcp.sentry.dev/mcp', + }, + { + id: 'server-3', + name: 'GitHub', + sanitizedName: 'github', + serverUrl: 'https://api.githubcopilot.com/mcp', + }, + ]), + }, + userMcpServer: { + findMany: vi.fn().mockResolvedValue([ + { + serverId: 'server-1', + tokens: JSON.stringify({ access_token: 'token', token_type: 'Bearer' }), + tokensExpiresAt: null, + }, + { + serverId: 'server-3', + tokens: JSON.stringify({ access_token: 'expired-token', token_type: 'Bearer' }), + tokensExpiresAt: new Date('2020-01-01T00:00:00.000Z'), + }, + ]), + }, + }; +} + +beforeEach(() => { + vi.clearAllMocks(); + mocks.hasEntitlement.mockResolvedValue(true); +}); + +describe('GET /api/ee/askmcp/servers', () => { + test('returns an empty array when the oauth entitlement is not granted', async () => { + mocks.hasEntitlement.mockResolvedValue(false); + const prisma = createPrismaMock(); + mocks.authContext = { + org: { id: 1 }, + user: { id: 'user-1' }, + prisma, + }; + + const response = await GET(createRequest()); + const body = await response.json(); + + expect(response.status).toBe(200); + expect(body).toEqual([]); + expect(prisma.mcpServer.findMany).not.toHaveBeenCalled(); + expect(prisma.userMcpServer.findMany).not.toHaveBeenCalled(); + }); + + test('lists org servers and merges only the caller token status', async () => { + const prisma = createPrismaMock(); + mocks.authContext = { + org: { id: 1 }, + user: { id: 'user-1' }, + prisma, + }; + + const response = await GET(createRequest()); + const body = await response.json(); + + expect(prisma.mcpServer.findMany).toHaveBeenCalledWith({ + where: { orgId: 1 }, + orderBy: { createdAt: 'desc' }, + select: { + id: true, + name: true, + sanitizedName: true, + serverUrl: true, + }, + }); + expect(prisma.userMcpServer.findMany).toHaveBeenCalledWith({ + where: { userId: 'user-1' }, + select: { + serverId: true, + tokens: true, + tokensExpiresAt: true, + }, + }); + expect(body).toMatchObject([ + { + id: 'server-1', + name: 'Linear', + sanitizedName: 'linear', + isConnected: true, + isAuthExpired: false, + }, + { + id: 'server-2', + name: 'Sentry', + sanitizedName: 'sentry', + isConnected: false, + isAuthExpired: false, + }, + { + id: 'server-3', + name: 'GitHub', + sanitizedName: 'github', + isConnected: false, + isAuthExpired: true, + }, + ]); + }); +}); diff --git a/packages/web/src/app/api/(server)/ee/askmcp/servers/route.ts b/packages/web/src/app/api/(server)/ee/askmcp/servers/route.ts new file mode 100644 index 000000000..8ccb3527d --- /dev/null +++ b/packages/web/src/app/api/(server)/ee/askmcp/servers/route.ts @@ -0,0 +1,80 @@ +import { apiHandler } from '@/lib/apiHandler'; +import { serviceErrorResponse } from '@/lib/serviceError'; +import { isServiceError } from '@/lib/utils'; +import { withAuth } from '@/middleware/withAuth'; +import { hasEntitlement } from '@/lib/entitlements'; +import { getMcpFaviconUrl } from '@/ee/features/chat/mcp/utils'; +import { getStoredMcpConnectionStatus } from '@/ee/features/chat/mcp/connectionStatus'; +import type { NextRequest } from 'next/server'; + +export interface McpServerWithStatus { + id: string; + name: string; + serverUrl: string; + sanitizedName: string; + faviconUrl: string | undefined; + isConnected: boolean; + isAuthExpired: boolean; +} + +export type GetMcpServersResponse = McpServerWithStatus[]; + +export const GET = apiHandler(async (_request: NextRequest) => { + if (!(await hasEntitlement('oauth'))) { + return Response.json([] satisfies GetMcpServersResponse); + } + + const result = await withAuth(async ({ org, user, prisma }) => { + const orgServers = await prisma.mcpServer.findMany({ + where: { orgId: org.id }, + orderBy: { createdAt: 'desc' }, + select: { + id: true, + name: true, + sanitizedName: true, + serverUrl: true, + }, + }); + + const userServers = await prisma.userMcpServer.findMany({ + where: { userId: user.id }, + select: { + serverId: true, + tokens: true, + tokensExpiresAt: true, + }, + }); + const userServerByServerId = new Map(userServers.map((us) => [us.serverId, us])); + + return orgServers.map((server): McpServerWithStatus => { + const userServer = userServerByServerId.get(server.id); + const faviconUrl = getMcpFaviconUrl(server.serverUrl, server.name); + + let isConnected = false; + let isAuthExpired = false; + + const connectionStatus = getStoredMcpConnectionStatus(userServer?.tokens, userServer?.tokensExpiresAt ?? null); + if (connectionStatus.state === 'connected') { + isConnected = true; + } else if (connectionStatus.state === 'expired') { + isAuthExpired = true; + } + + return { + id: server.id, + name: server.name, + serverUrl: server.serverUrl, + sanitizedName: server.sanitizedName, + faviconUrl, + isConnected, + isAuthExpired, + }; + }); + }); + + if (isServiceError(result)) { + return serviceErrorResponse(result); + } + + return Response.json(result); +}); diff --git a/packages/web/src/app/api/(server)/ee/askmcp/tools/route.test.ts b/packages/web/src/app/api/(server)/ee/askmcp/tools/route.test.ts new file mode 100644 index 000000000..79cf8164d --- /dev/null +++ b/packages/web/src/app/api/(server)/ee/askmcp/tools/route.test.ts @@ -0,0 +1,75 @@ +import { beforeEach, describe, expect, test, vi } from 'vitest'; +import { NextRequest } from 'next/server'; + +const mocks = vi.hoisted(() => ({ + authContext: undefined as unknown, + hasEntitlement: vi.fn(), + getMcpToolMetadata: vi.fn(), +})); + +vi.mock('@/lib/posthog', () => ({ + captureEvent: vi.fn(), +})); +vi.mock('@/lib/entitlements', () => ({ + hasEntitlement: mocks.hasEntitlement, +})); +vi.mock('@/middleware/withAuth', () => ({ + withAuth: vi.fn((callback: (context: unknown) => unknown) => callback(mocks.authContext)), +})); +vi.mock('@/ee/features/chat/mcp/mcpToolMetadata', () => ({ + getMcpToolMetadata: mocks.getMcpToolMetadata, +})); + +const { GET } = await import('./route'); + +function createRequest() { + return new NextRequest('https://sourcebot.example.com/api/ee/askmcp/tools', { method: 'GET' }); +} + +beforeEach(() => { + vi.clearAllMocks(); + mocks.hasEntitlement.mockResolvedValue(true); + mocks.getMcpToolMetadata.mockResolvedValue([]); +}); + +describe('GET /api/ee/askmcp/tools', () => { + test('returns tool metadata for the authenticated viewer without owner-only gating', async () => { + const prisma = {}; + mocks.authContext = { + org: { id: 1 }, + user: { id: 'user-1' }, + role: 'MEMBER', + prisma, + }; + mocks.getMcpToolMetadata.mockResolvedValue([ + { + status: 'available', + serverId: 'server-1', + tools: [{ name: 'search' }], + }, + ]); + + const response = await GET(createRequest()); + const body = await response.json(); + + expect(mocks.getMcpToolMetadata).toHaveBeenCalledWith(prisma, 'user-1', 1); + expect(body).toEqual([ + { + status: 'available', + serverId: 'server-1', + tools: [{ name: 'search' }], + }, + ]); + }); + + test('returns access_denied when OAuth is unavailable', async () => { + mocks.hasEntitlement.mockResolvedValue(false); + + const response = await GET(createRequest()); + const body = await response.json(); + + expect(response.status).toBe(403); + expect(body).toMatchObject({ error: 'access_denied' }); + expect(mocks.getMcpToolMetadata).not.toHaveBeenCalled(); + }); +}); diff --git a/packages/web/src/app/api/(server)/ee/askmcp/tools/route.ts b/packages/web/src/app/api/(server)/ee/askmcp/tools/route.ts new file mode 100644 index 000000000..aea01a7e7 --- /dev/null +++ b/packages/web/src/app/api/(server)/ee/askmcp/tools/route.ts @@ -0,0 +1,27 @@ +import { apiHandler } from '@/lib/apiHandler'; +import { serviceErrorResponse } from '@/lib/serviceError'; +import { isServiceError } from '@/lib/utils'; +import { withAuth } from '@/middleware/withAuth'; +import { hasEntitlement } from '@/lib/entitlements'; +import { OAUTH_NOT_SUPPORTED_ERROR_MESSAGE } from '@/ee/features/oauth/constants'; +import { getMcpToolMetadata } from '@/ee/features/chat/mcp/mcpToolMetadata'; +import type { NextRequest } from 'next/server'; + +export const GET = apiHandler(async (_request: NextRequest) => { + if (!(await hasEntitlement('oauth'))) { + return Response.json( + { error: 'access_denied', error_description: OAUTH_NOT_SUPPORTED_ERROR_MESSAGE }, + { status: 403 }, + ); + } + + const result = await withAuth(async ({ org, user, prisma }) => { + return getMcpToolMetadata(prisma, user.id, org.id); + }); + + if (isServiceError(result)) { + return serviceErrorResponse(result); + } + + return Response.json(result); +}); diff --git a/packages/web/src/ee/features/analytics/analyticsContent.tsx b/packages/web/src/ee/features/analytics/analyticsContent.tsx index 91e62254a..ad2e5bffb 100644 --- a/packages/web/src/ee/features/analytics/analyticsContent.tsx +++ b/packages/web/src/ee/features/analytics/analyticsContent.tsx @@ -2,7 +2,7 @@ import { ChartTooltip } from "@/components/ui/chart" import { Area, AreaChart, ResponsiveContainer, XAxis, YAxis } from "recharts" -import { Users, LucideIcon, Search, ArrowRight, Activity, Calendar, MessageCircle, Wrench, Key, Info } from "lucide-react" +import { Users, LucideIcon, Activity, Calendar, Info } from "lucide-react" import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card" import { ChartContainer } from "@/components/ui/chart" import { useQuery } from "@tanstack/react-query" diff --git a/packages/web/src/ee/features/chat/mcp/actions.test.ts b/packages/web/src/ee/features/chat/mcp/actions.test.ts new file mode 100644 index 000000000..1009c3a4f --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/actions.test.ts @@ -0,0 +1,438 @@ +import { beforeEach, describe, expect, test, vi } from 'vitest'; +import { McpServerClientInfoSource, OrgRole } from '@sourcebot/db'; +import { ErrorCode } from '@/lib/errorCodes'; + +const mocks = vi.hoisted(() => ({ + authContext: undefined as unknown, + hasEntitlement: vi.fn(), + encryptOAuthToken: vi.fn((text: string | null | undefined) => text ? `encrypted:${text}` : undefined), + env: { + AUTH_URL: 'https://sourcebot.example.com', + NODE_ENV: 'production', + SOURCEBOT_MCP_TOOL_CALL_TIMEOUT_MS: 5000, + }, + logger: { + error: vi.fn(), + }, + captureEvent: vi.fn(), + unsafePrisma: { + mcpServer: { + deleteMany: vi.fn(), + findFirst: vi.fn(), + }, + userMcpServer: { + deleteMany: vi.fn(), + }, + }, +})); + +vi.mock('server-only', () => ({})); +vi.mock('@/middleware/withAuth', () => ({ + withAuth: vi.fn((callback: (context: unknown) => unknown) => callback(mocks.authContext)), +})); +vi.mock('@/lib/entitlements', () => ({ + hasEntitlement: mocks.hasEntitlement, +})); +vi.mock('@/prisma', () => ({ + __unsafePrisma: mocks.unsafePrisma, +})); +vi.mock('@sourcebot/shared', () => ({ + createLogger: () => mocks.logger, + encryptOAuthToken: mocks.encryptOAuthToken, + env: mocks.env, +})); +vi.mock('@/lib/posthog', () => ({ + captureEvent: mocks.captureEvent, +})); + +const { createMcpServer, createStaticOAuthMcpServer, deleteMcpServer, disconnectMcpServer } = await import('./actions'); + +function createPrismaMock() { + return { + mcpServer: { + findUnique: vi.fn().mockResolvedValue(null), + findFirst: vi.fn().mockResolvedValue(null), + create: vi.fn().mockImplementation(async ({ data }) => ({ + id: 'server-1', + name: data.name, + sanitizedName: data.sanitizedName, + serverUrl: data.serverUrl, + })), + }, + }; +} + +function setAuthContext(role: OrgRole, prisma = createPrismaMock()) { + mocks.authContext = { + org: { id: 1 }, + role, + prisma, + }; + return prisma; +} + +function createStaticOAuthRequest(overrides: Partial<{ + name: string; + serverUrl: string; + clientId: string; + clientSecret: string; +}> = {}) { + return { + name: 'Slack', + serverUrl: 'https://mcp.slack.com/mcp', + clientId: 'client-id', + clientSecret: 'client-secret', + ...overrides, + }; +} + +beforeEach(() => { + vi.clearAllMocks(); + mocks.hasEntitlement.mockResolvedValue(true); + mocks.encryptOAuthToken.mockImplementation((text: string | null | undefined) => text ? `encrypted:${text}` : undefined); + mocks.env.AUTH_URL = 'https://sourcebot.example.com'; + mocks.env.NODE_ENV = 'production'; + mocks.env.SOURCEBOT_MCP_TOOL_CALL_TIMEOUT_MS = 5000; + mocks.captureEvent.mockResolvedValue(undefined); +}); + +describe('createMcpServer', () => { + test('owners add an org MCP server without dynamic client information', async () => { + const prisma = setAuthContext(OrgRole.OWNER); + + const result = await createMcpServer(' Linear ', ' https://mcp.linear.app/mcp '); + + expect(result).toEqual({ + id: 'server-1', + name: 'Linear', + sanitizedName: 'linear', + serverUrl: 'https://mcp.linear.app/mcp', + }); + expect(prisma.mcpServer.create).toHaveBeenCalledWith({ + data: { + name: 'Linear', + sanitizedName: 'linear', + serverUrl: 'https://mcp.linear.app/mcp', + clientInfo: null, + clientInfoSource: McpServerClientInfoSource.DYNAMIC, + orgId: 1, + }, + }); + expect(mocks.captureEvent).toHaveBeenCalledWith('ask_mcp_connector_added', { + source: 'sourcebot-web-client', + entryPoint: 'workspace_settings', + serverId: 'server-1', + serverName: 'Linear', + serverUrl: 'https://mcp.linear.app/mcp', + sanitizedName: 'linear', + authMode: 'dynamic', + }); + }); + + test('members cannot add org MCP servers', async () => { + const prisma = setAuthContext(OrgRole.MEMBER); + + const result = await createMcpServer('Linear', 'https://mcp.linear.app/mcp'); + + expect(result).toMatchObject({ + errorCode: ErrorCode.INSUFFICIENT_PERMISSIONS, + }); + expect(prisma.mcpServer.create).not.toHaveBeenCalled(); + }); + + test('owners cannot add org MCP servers when OAuth is unsupported', async () => { + const prisma = setAuthContext(OrgRole.OWNER); + mocks.hasEntitlement.mockResolvedValue(false); + + const result = await createMcpServer('Linear', 'https://mcp.linear.app/mcp'); + + expect(result).toMatchObject({ + statusCode: 403, + errorCode: ErrorCode.INSUFFICIENT_PERMISSIONS, + }); + expect(prisma.mcpServer.create).not.toHaveBeenCalled(); + }); +}); + +describe('createStaticOAuthMcpServer', () => { + test('owners add a static OAuth MCP server with encrypted client information', async () => { + const prisma = setAuthContext(OrgRole.OWNER); + + const result = await createStaticOAuthMcpServer({ + name: ' Slack ', + serverUrl: 'https://mcp.slack.com/mcp', + clientId: ' client-id ', + clientSecret: ' client-secret ', + }); + + expect(mocks.encryptOAuthToken).toHaveBeenCalledWith(JSON.stringify({ + client_id: 'client-id', + client_secret: 'client-secret', + })); + expect(prisma.mcpServer.create).toHaveBeenCalledWith({ + data: { + name: 'Slack', + sanitizedName: 'slack', + serverUrl: 'https://mcp.slack.com/mcp', + clientInfo: 'encrypted:{"client_id":"client-id","client_secret":"client-secret"}', + clientInfoSource: McpServerClientInfoSource.STATIC, + orgId: 1, + }, + }); + expect(JSON.stringify(result)).not.toContain('client-secret'); + expect(result).toEqual({ + id: 'server-1', + name: 'Slack', + sanitizedName: 'slack', + serverUrl: 'https://mcp.slack.com/mcp', + }); + expect(mocks.captureEvent).toHaveBeenCalledWith('ask_mcp_connector_added', { + source: 'sourcebot-web-client', + entryPoint: 'workspace_settings', + serverId: 'server-1', + serverName: 'Slack', + serverUrl: 'https://mcp.slack.com/mcp', + sanitizedName: 'slack', + authMode: 'static', + }); + }); + + test('members cannot add static OAuth MCP servers', async () => { + const prisma = setAuthContext(OrgRole.MEMBER); + + const result = await createStaticOAuthMcpServer(createStaticOAuthRequest()); + + expect(result).toMatchObject({ + errorCode: ErrorCode.INSUFFICIENT_PERMISSIONS, + }); + expect(prisma.mcpServer.create).not.toHaveBeenCalled(); + }); + + test('rejects static OAuth credentials when production AUTH_URL is not HTTPS', async () => { + const prisma = setAuthContext(OrgRole.OWNER); + mocks.env.AUTH_URL = 'http://sourcebot.example.com'; + + const result = await createStaticOAuthMcpServer(createStaticOAuthRequest()); + + expect(result).toMatchObject({ + statusCode: 400, + errorCode: ErrorCode.INVALID_REQUEST_BODY, + message: 'Static OAuth client credentials require HTTPS in production.', + }); + expect(prisma.mcpServer.create).not.toHaveBeenCalled(); + expect(JSON.stringify(result)).not.toContain('client-secret'); + }); + + test('does not echo client secrets in validation errors', async () => { + const prisma = setAuthContext(OrgRole.OWNER); + + const result = await createStaticOAuthMcpServer({ + name: 'Slack', + serverUrl: 'not-a-url', + clientId: 'client-id', + clientSecret: 'client-secret', + }); + + expect(result).toMatchObject({ + statusCode: 400, + errorCode: ErrorCode.INVALID_REQUEST_BODY, + }); + expect(JSON.stringify(result)).not.toContain('client-secret'); + expect(prisma.mcpServer.create).not.toHaveBeenCalled(); + }); + + test('rejects static OAuth servers with non-HTTPS server URLs', async () => { + const prisma = setAuthContext(OrgRole.OWNER); + + const result = await createStaticOAuthMcpServer(createStaticOAuthRequest({ + serverUrl: 'http://mcp.slack.com/mcp', + })); + + expect(result).toMatchObject({ + statusCode: 400, + errorCode: ErrorCode.INVALID_REQUEST_BODY, + message: 'Invalid connector URL. Must be a valid HTTPS URL.', + }); + expect(prisma.mcpServer.findUnique).not.toHaveBeenCalled(); + expect(prisma.mcpServer.create).not.toHaveBeenCalled(); + expect(mocks.encryptOAuthToken).not.toHaveBeenCalled(); + expect(JSON.stringify(result)).not.toContain('client-secret'); + }); + + test('rejects static OAuth servers with fewer than 3 alphanumeric name characters', async () => { + const prisma = setAuthContext(OrgRole.OWNER); + + const result = await createStaticOAuthMcpServer(createStaticOAuthRequest({ + name: '!!a!', + })); + + expect(result).toMatchObject({ + statusCode: 400, + errorCode: ErrorCode.INVALID_REQUEST_BODY, + message: 'Connector name must contain at least 3 alphanumeric characters.', + }); + expect(prisma.mcpServer.findUnique).not.toHaveBeenCalled(); + expect(prisma.mcpServer.create).not.toHaveBeenCalled(); + expect(mocks.encryptOAuthToken).not.toHaveBeenCalled(); + expect(JSON.stringify(result)).not.toContain('client-secret'); + }); + + test('rejects static OAuth servers with a duplicate URL', async () => { + const prisma = setAuthContext(OrgRole.OWNER); + prisma.mcpServer.findUnique.mockResolvedValue({ id: 'existing-server' }); + + const result = await createStaticOAuthMcpServer(createStaticOAuthRequest()); + + expect(result).toMatchObject({ + statusCode: 409, + errorCode: ErrorCode.MCP_SERVER_ALREADY_EXISTS, + message: 'A connector with URL "https://mcp.slack.com/mcp" already exists.', + }); + expect(prisma.mcpServer.findFirst).not.toHaveBeenCalled(); + expect(prisma.mcpServer.create).not.toHaveBeenCalled(); + expect(mocks.encryptOAuthToken).not.toHaveBeenCalled(); + expect(JSON.stringify(result)).not.toContain('client-secret'); + }); + + test('rejects static OAuth servers with a duplicate sanitized name', async () => { + const prisma = setAuthContext(OrgRole.OWNER); + prisma.mcpServer.findFirst.mockResolvedValue({ id: 'existing-server' }); + + const result = await createStaticOAuthMcpServer(createStaticOAuthRequest({ + name: 'Slack!!!', + })); + + expect(result).toMatchObject({ + statusCode: 409, + errorCode: ErrorCode.MCP_SERVER_ALREADY_EXISTS, + message: 'A connector with a similar name already exists. Please choose a more distinct name.', + }); + expect(prisma.mcpServer.findUnique).toHaveBeenCalledWith({ + where: { + serverUrl_orgId: { + serverUrl: 'https://mcp.slack.com/mcp', + orgId: 1, + }, + }, + select: { id: true }, + }); + expect(prisma.mcpServer.create).not.toHaveBeenCalled(); + expect(mocks.encryptOAuthToken).not.toHaveBeenCalled(); + expect(JSON.stringify(result)).not.toContain('client-secret'); + }); + + test('rejects static OAuth servers when client credential encryption fails', async () => { + const prisma = setAuthContext(OrgRole.OWNER); + mocks.encryptOAuthToken.mockReturnValue(undefined); + + const result = await createStaticOAuthMcpServer(createStaticOAuthRequest()); + + expect(result).toMatchObject({ + statusCode: 500, + errorCode: ErrorCode.UNEXPECTED_ERROR, + message: 'Failed to store OAuth client credentials.', + }); + expect(prisma.mcpServer.create).not.toHaveBeenCalled(); + expect(JSON.stringify(result)).not.toContain('client-secret'); + }); +}); + +describe('deleteMcpServer', () => { + test('owners delete through the narrowly scoped unsafe client', async () => { + setAuthContext(OrgRole.OWNER); + mocks.unsafePrisma.mcpServer.deleteMany.mockResolvedValue({ count: 1 }); + + await expect(deleteMcpServer('server-1')).resolves.toEqual({ success: true }); + expect(mocks.unsafePrisma.mcpServer.deleteMany).toHaveBeenCalledWith({ + where: { + id: 'server-1', + orgId: 1, + }, + }); + expect(mocks.hasEntitlement).not.toHaveBeenCalled(); + }); + + test('members cannot delete org MCP servers', async () => { + setAuthContext(OrgRole.MEMBER); + + const result = await deleteMcpServer('server-1'); + + expect(result).toMatchObject({ + errorCode: ErrorCode.INSUFFICIENT_PERMISSIONS, + }); + expect(mocks.unsafePrisma.mcpServer.deleteMany).not.toHaveBeenCalled(); + }); + + test('owners can delete org MCP servers when OAuth is unsupported', async () => { + setAuthContext(OrgRole.OWNER); + mocks.hasEntitlement.mockResolvedValue(false); + mocks.unsafePrisma.mcpServer.deleteMany.mockResolvedValue({ count: 1 }); + + await expect(deleteMcpServer('server-1')).resolves.toEqual({ success: true }); + + expect(mocks.hasEntitlement).not.toHaveBeenCalled(); + expect(mocks.unsafePrisma.mcpServer.deleteMany).toHaveBeenCalledWith({ + where: { + id: 'server-1', + orgId: 1, + }, + }); + }); +}); + +describe('disconnectMcpServer', () => { + test('disconnects a personal connector and tracks the disconnect', async () => { + const prisma = { + mcpServer: { + findFirst: vi.fn().mockResolvedValue({ + id: 'server-1', + name: 'Linear', + serverUrl: 'https://mcp.linear.app/mcp', + sanitizedName: 'linear', + clientInfoSource: McpServerClientInfoSource.DYNAMIC, + }), + }, + userMcpServer: { + deleteMany: vi.fn().mockResolvedValue({ count: 1 }), + }, + }; + mocks.authContext = { + org: { id: 1 }, + user: { id: 'user-1' }, + prisma, + }; + + await expect(disconnectMcpServer('server-1', 'account_settings')).resolves.toEqual({ success: true }); + + expect(prisma.mcpServer.findFirst).toHaveBeenCalledWith({ + where: { + id: 'server-1', + orgId: 1, + }, + select: { + id: true, + name: true, + serverUrl: true, + sanitizedName: true, + clientInfoSource: true, + }, + }); + expect(prisma.userMcpServer.deleteMany).toHaveBeenCalledWith({ + where: { + serverId: 'server-1', + userId: 'user-1', + }, + }); + expect(mocks.unsafePrisma.mcpServer.findFirst).not.toHaveBeenCalled(); + expect(mocks.unsafePrisma.userMcpServer.deleteMany).not.toHaveBeenCalled(); + expect(mocks.captureEvent).toHaveBeenCalledWith('ask_mcp_connector_disconnected', { + source: 'sourcebot-web-client', + entryPoint: 'account_settings', + serverId: 'server-1', + serverName: 'Linear', + serverUrl: 'https://mcp.linear.app/mcp', + sanitizedName: 'linear', + authMode: 'dynamic', + }); + }); +}); diff --git a/packages/web/src/ee/features/chat/mcp/actions.ts b/packages/web/src/ee/features/chat/mcp/actions.ts new file mode 100644 index 000000000..7ce05ede4 --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/actions.ts @@ -0,0 +1,366 @@ +'use server'; + +import { sew } from '@/middleware/sew'; +import { ErrorCode } from '@/lib/errorCodes'; +import { requestBodySchemaValidationError, ServiceError } from '@/lib/serviceError'; +import { withAuth } from '@/middleware/withAuth'; +import { withMinimumOrgRole } from '@/middleware/withMinimumOrgRole'; +import { __unsafePrisma } from '@/prisma'; +import { isServiceError } from '@/lib/utils'; +import { McpServerClientInfoSource, OrgRole, type PrismaClient } from '@sourcebot/db'; +import { StatusCodes } from 'http-status-codes'; +import { z } from 'zod'; +import { sanitizeMcpServerName } from './utils'; +import { hasEntitlement } from '@/lib/entitlements'; +import { oauthNotSupported } from './errors'; +import { checkMcpServerDcrSupport } from './dcrDiscovery'; +import { encryptOAuthToken, env } from '@sourcebot/shared'; +import { captureEvent } from '@/lib/posthog'; +import { getMcpAuthMode } from './analytics'; +import type { McpConnectorEntryPoint } from '@/lib/posthogEvents'; + +const MCP_DCR_DISCOVERY_TIMEOUT_MS = Math.min(env.SOURCEBOT_MCP_TOOL_CALL_TIMEOUT_MS, 10000); +const createStaticOAuthMcpServerSchema = z.object({ + name: z.string().trim().min(1), + serverUrl: z.string().trim().url(), + clientId: z.string().trim().min(1), + clientSecret: z.string().trim().min(1), +}); + +export type CreateStaticOAuthMcpServerRequest = z.infer; + +export interface CreateStaticOAuthMcpServerResponse { + id: string; + name: string; + sanitizedName: string; + serverUrl: string; +} + +type McpServerPrismaClient = Pick; + +interface PreparedMcpServerCreate { + displayName: string; + normalizedServerUrl: string; + sanitizedName: string; +} + +function createTimeoutFetch(timeoutMs: number): typeof fetch { + return async (input, init) => { + const timeoutSignal = AbortSignal.timeout(timeoutMs); + const signal = init?.signal + ? AbortSignal.any([init.signal, timeoutSignal]) + : timeoutSignal; + + return fetch(input, { + ...init, + signal, + }); + }; +} + +function invalidRequest(message: string): ServiceError { + return { + statusCode: StatusCodes.BAD_REQUEST, + errorCode: ErrorCode.INVALID_REQUEST_BODY, + message, + }; +} + +function assertHttpsAuthUrlInProduction(): ServiceError | undefined { + if (env.NODE_ENV !== 'production') { + return undefined; + } + + if (new URL(env.AUTH_URL).protocol === 'https:') { + return undefined; + } + + return invalidRequest('Static OAuth client credentials require HTTPS in production.'); +} + +async function prepareMcpServerCreate({ + prisma, + orgId, + name, + serverUrl, +}: { + prisma: McpServerPrismaClient; + orgId: number; + name: string; + serverUrl: string; +}): Promise { + const displayName = name.trim(); + const normalizedServerUrl = serverUrl.trim(); + const urlResult = z.string().url().safeParse(normalizedServerUrl); + const protocol = urlResult.success ? new URL(normalizedServerUrl).protocol : undefined; + if (!urlResult.success || protocol !== 'https:') { + return invalidRequest('Invalid connector URL. Must be a valid HTTPS URL.'); + } + + const sanitizedName = sanitizeMcpServerName(displayName); + const alphanumericCount = (sanitizedName.match(/[a-z0-9]/g) ?? []).length; + if (alphanumericCount < 3) { + return invalidRequest('Connector name must contain at least 3 alphanumeric characters.'); + } + + const existingServer = await prisma.mcpServer.findUnique({ + where: { + serverUrl_orgId: { + serverUrl: normalizedServerUrl, + orgId, + }, + }, + select: { id: true }, + }); + if (existingServer) { + return { + statusCode: StatusCodes.CONFLICT, + errorCode: ErrorCode.MCP_SERVER_ALREADY_EXISTS, + message: `A connector with URL "${normalizedServerUrl}" already exists.`, + } satisfies ServiceError; + } + + const existingName = await prisma.mcpServer.findFirst({ + where: { + orgId, + sanitizedName, + }, + select: { id: true }, + }); + if (existingName) { + return { + statusCode: StatusCodes.CONFLICT, + errorCode: ErrorCode.MCP_SERVER_ALREADY_EXISTS, + message: 'A connector with a similar name already exists. Please choose a more distinct name.', + } satisfies ServiceError; + } + + return { + displayName, + normalizedServerUrl, + sanitizedName, + }; +} + +export const checkMcpServerDynamicClientRegistration = async (serverUrl: string) => sew(() => + withAuth(async ({ role }) => + withMinimumOrgRole(role, OrgRole.OWNER, async () => { + if (!(await hasEntitlement('oauth'))) { + return oauthNotSupported(); + } + + const normalizedServerUrl = serverUrl.trim(); + const urlResult = z.string().url().safeParse(normalizedServerUrl); + const protocol = urlResult.success ? new URL(normalizedServerUrl).protocol : undefined; + if (!urlResult.success || protocol !== 'https:') { + return { + statusCode: StatusCodes.BAD_REQUEST, + errorCode: ErrorCode.INVALID_REQUEST_BODY, + message: 'Invalid connector URL. Must be a valid HTTPS URL.', + } satisfies ServiceError; + } + + try { + return await checkMcpServerDcrSupport( + normalizedServerUrl, + createTimeoutFetch(MCP_DCR_DISCOVERY_TIMEOUT_MS), + ); + } catch { + return { + statusCode: StatusCodes.BAD_GATEWAY, + errorCode: ErrorCode.UNEXPECTED_ERROR, + message: 'Could not check whether this connector supports dynamic client registration.', + } satisfies ServiceError; + } + }))); + +export const createStaticOAuthMcpServer = async ( + body: CreateStaticOAuthMcpServerRequest, +) => { + const parsed = createStaticOAuthMcpServerSchema.safeParse(body); + if (!parsed.success) { + return requestBodySchemaValidationError(parsed.error); + } + + return sew(() => + withAuth(async ({ org, role, prisma }) => + withMinimumOrgRole(role, OrgRole.OWNER, async (): Promise => { + if (!(await hasEntitlement('oauth'))) { + return oauthNotSupported(); + } + + const httpsError = assertHttpsAuthUrlInProduction(); + if (httpsError) { + return httpsError; + } + + const preparedServer = await prepareMcpServerCreate({ + prisma, + orgId: org.id, + name: parsed.data.name, + serverUrl: parsed.data.serverUrl, + }); + if (isServiceError(preparedServer)) { + return preparedServer; + } + + const clientInfo = encryptOAuthToken(JSON.stringify({ + client_id: parsed.data.clientId, + client_secret: parsed.data.clientSecret, + })); + if (!clientInfo) { + return { + statusCode: StatusCodes.INTERNAL_SERVER_ERROR, + errorCode: ErrorCode.UNEXPECTED_ERROR, + message: 'Failed to store OAuth client credentials.', + } satisfies ServiceError; + } + + const mcpServer = await prisma.mcpServer.create({ + data: { + name: preparedServer.displayName, + sanitizedName: preparedServer.sanitizedName, + serverUrl: preparedServer.normalizedServerUrl, + clientInfo, + clientInfoSource: McpServerClientInfoSource.STATIC, + orgId: org.id, + }, + }); + + void captureEvent('ask_mcp_connector_added', { + source: 'sourcebot-web-client', + entryPoint: 'workspace_settings', + serverId: mcpServer.id, + serverName: preparedServer.displayName, + serverUrl: mcpServer.serverUrl, + sanitizedName: preparedServer.sanitizedName, + authMode: 'static', + }); + + return { + id: mcpServer.id, + name: preparedServer.displayName, + sanitizedName: preparedServer.sanitizedName, + serverUrl: mcpServer.serverUrl, + }; + }))); +} + +export const createMcpServer = async (name: string, serverUrl: string) => sew(() => + withAuth(async ({ org, role, prisma }) => + withMinimumOrgRole(role, OrgRole.OWNER, async () => { + if (!(await hasEntitlement('oauth'))) { + return oauthNotSupported(); + } + + const preparedServer = await prepareMcpServerCreate({ + prisma, + orgId: org.id, + name, + serverUrl, + }); + if (isServiceError(preparedServer)) { + return preparedServer; + } + + const mcpServer = await prisma.mcpServer.create({ + data: { + name: preparedServer.displayName, + sanitizedName: preparedServer.sanitizedName, + serverUrl: preparedServer.normalizedServerUrl, + clientInfo: null, + clientInfoSource: McpServerClientInfoSource.DYNAMIC, + orgId: org.id, + }, + }); + + void captureEvent('ask_mcp_connector_added', { + source: 'sourcebot-web-client', + entryPoint: 'workspace_settings', + serverId: mcpServer.id, + serverName: preparedServer.displayName, + serverUrl: mcpServer.serverUrl, + sanitizedName: preparedServer.sanitizedName, + authMode: getMcpAuthMode(McpServerClientInfoSource.DYNAMIC), + }); + + return { + id: mcpServer.id, + name: preparedServer.displayName, + sanitizedName: preparedServer.sanitizedName, + serverUrl: mcpServer.serverUrl, + }; + }))); + +export const deleteMcpServer = async (serverId: string) => sew(() => + withAuth(async ({ org, role }) => + withMinimumOrgRole(role, OrgRole.OWNER, async () => { + const result = await __unsafePrisma.mcpServer.deleteMany({ + where: { + id: serverId, + orgId: org.id, + }, + }); + + if (result.count === 0) { + return { + statusCode: StatusCodes.NOT_FOUND, + errorCode: ErrorCode.MCP_SERVER_NOT_FOUND, + message: 'Connector not found', + } satisfies ServiceError; + } + + return { success: true }; + }))); + +export const disconnectMcpServer = async (serverId: string, entryPoint: McpConnectorEntryPoint) => sew(() => + withAuth(async ({ org, user, prisma }) => { + const server = await prisma.mcpServer.findFirst({ + where: { + id: serverId, + orgId: org.id, + }, + select: { + id: true, + name: true, + serverUrl: true, + sanitizedName: true, + clientInfoSource: true, + }, + }); + + if (!server) { + return { + statusCode: StatusCodes.NOT_FOUND, + errorCode: ErrorCode.MCP_SERVER_NOT_FOUND, + message: 'Connector not found', + } satisfies ServiceError; + } + + const result = await prisma.userMcpServer.deleteMany({ + where: { + serverId, + userId: user.id, + }, + }); + + if (result.count === 0) { + return { + statusCode: StatusCodes.NOT_FOUND, + errorCode: ErrorCode.MCP_SERVER_NOT_FOUND, + message: 'No connection found for this connector.', + } satisfies ServiceError; + } + + void captureEvent('ask_mcp_connector_disconnected', { + source: 'sourcebot-web-client', + entryPoint, + serverId: server.id, + serverName: server.name, + serverUrl: server.serverUrl, + sanitizedName: server.sanitizedName, + authMode: getMcpAuthMode(server.clientInfoSource), + }); + + return { success: true }; + })); diff --git a/packages/web/src/ee/features/chat/mcp/analytics.ts b/packages/web/src/ee/features/chat/mcp/analytics.ts new file mode 100644 index 000000000..b21fe7848 --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/analytics.ts @@ -0,0 +1,39 @@ +import { McpServerClientInfoSource } from '@sourcebot/db'; +import type { McpConnectorAuthMode, McpConnectorEntryPoint } from '@/lib/posthogEvents'; +import { getExternalMcpErrorLogFields } from './externalMcpError'; + +export function getMcpConnectorEntryPoint(returnTo: string | undefined): McpConnectorEntryPoint { + if (returnTo?.startsWith('/chat')) { + return 'chat'; + } + if (returnTo?.startsWith('/settings/accountAskAgent')) { + return 'account_settings'; + } + if (returnTo?.startsWith('/settings/workspaceAskAgent')) { + return 'workspace_settings'; + } + + return 'unknown'; +} + +export function getMcpAuthMode(clientInfoSource: McpServerClientInfoSource): McpConnectorAuthMode { + return clientInfoSource === McpServerClientInfoSource.STATIC ? 'static' : 'dynamic'; +} + +export function getMcpConnectorFailureReason(error: unknown): string { + const fields = getExternalMcpErrorLogFields(error); + if (fields.reason) { + return fields.reason; + } + if (fields.oauthError) { + return fields.oauthError; + } + if (fields.statusCode) { + return `status_${fields.statusCode}`; + } + if (fields.errorClass) { + return fields.errorClass; + } + + return 'unknown'; +} diff --git a/packages/web/src/ee/features/chat/mcp/components/connectMcpButton.tsx b/packages/web/src/ee/features/chat/mcp/components/connectMcpButton.tsx new file mode 100644 index 000000000..53bff998e --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/components/connectMcpButton.tsx @@ -0,0 +1,38 @@ +'use client'; + +import { LoadingButton } from '@/components/ui/loading-button'; +import { ExternalLink } from 'lucide-react'; +import type { ButtonProps } from '@/components/ui/button'; +import { useConnectMcp } from '@/ee/features/chat/mcp/hooks/useConnectMcp'; + +interface ConnectMcpButtonProps { + serverId: string; + isConnected?: boolean; + isAuthExpired?: boolean; + size?: ButtonProps['size']; + variant?: ButtonProps['variant']; + returnTo?: string; + className?: string; +} + +export function ConnectMcpButton({ serverId, isConnected, isAuthExpired, size, variant, returnTo, className }: ConnectMcpButtonProps) { + const { connect, loadingServerId } = useConnectMcp({ returnTo }); + const loading = loadingServerId === serverId; + + const isSuggested = !isConnected && !isAuthExpired; + const buttonLabel = isSuggested ? "Connect" : "Reconnect"; + const defaultVariant = isConnected ? "outline" as const : undefined; + + return ( + connect(serverId)} + loading={loading} + variant={variant ?? defaultVariant} + size={size} + className={className} + > + {buttonLabel} + {!isSuggested && } + + ); +} diff --git a/packages/web/src/ee/features/chat/mcp/components/connectorCard.test.tsx b/packages/web/src/ee/features/chat/mcp/components/connectorCard.test.tsx new file mode 100644 index 000000000..5e843d857 --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/components/connectorCard.test.tsx @@ -0,0 +1,93 @@ +import { afterEach, describe, expect, test } from 'vitest'; +import { cleanup, fireEvent, render, screen } from '@testing-library/react'; +import { ConnectorCard } from './connectorCard'; +import type { McpServerToolUsageSummary, ServerToolsEntry } from '@/ee/features/chat/mcp/types'; + +afterEach(() => { + cleanup(); +}); + +function availableEntry(): Extract { + return { + status: 'available', + serverId: 'server-1', + tools: [ + { name: 'search_issues', title: 'Search issues' }, + { name: 'get_issue' }, + { name: 'create_issue' }, + ], + }; +} + +function usageSummary(): McpServerToolUsageSummary { + return { + totalCalls: 6, + usedToolCount: 2, + tools: [ + { toolName: 'search_issues', totalCalls: 4, usageSharePercent: 66.666 }, + { toolName: 'get_issue', totalCalls: 2, usageSharePercent: 33.333 }, + ], + }; +} + +describe('ConnectorCard', () => { + test('shows only one expanded tools or usage panel at a time', () => { + render( + Connected
} + actionButtons={null} + />, + ); + + const toolsTrigger = screen.getByRole('button', { name: /3 tools/ }); + const usageTrigger = screen.getByRole('button', { name: /6 tool calls/ }); + + expect(toolsTrigger.getAttribute('aria-controls')).toBeTruthy(); + expect(usageTrigger.getAttribute('aria-controls')).toBeTruthy(); + + fireEvent.click(toolsTrigger); + + expect(screen.getByRole('button', { name: 'Search issues' })).toBeTruthy(); + expect(document.getElementById(toolsTrigger.getAttribute('aria-controls') ?? '')).toBeTruthy(); + expect(screen.queryByText('Lifetime tool usage')).toBeNull(); + + fireEvent.click(usageTrigger); + + expect(screen.getByText('Lifetime tool usage')).toBeTruthy(); + expect(document.getElementById(usageTrigger.getAttribute('aria-controls') ?? '')).toBeTruthy(); + expect(screen.queryByRole('button', { name: 'Search issues' })).toBeNull(); + + fireEvent.click(toolsTrigger); + + expect(screen.getByRole('button', { name: 'Search issues' })).toBeTruthy(); + expect(screen.queryByText('Lifetime tool usage')).toBeNull(); + }); + + test('hides usage disclosure for connectors with no tool calls', () => { + render( + Connected
} + actionButtons={null} + />, + ); + + expect(screen.getByRole('button', { name: /3 tools/ })).toBeTruthy(); + expect(screen.queryByRole('button', { name: /0 tool calls/ })).toBeNull(); + }); +}); diff --git a/packages/web/src/ee/features/chat/mcp/components/connectorCard.tsx b/packages/web/src/ee/features/chat/mcp/components/connectorCard.tsx new file mode 100644 index 000000000..3fc7feaf7 --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/components/connectorCard.tsx @@ -0,0 +1,107 @@ +'use client'; + +import { useId, useState, type ReactNode } from 'react'; +import { Card, CardContent } from '@/components/ui/card'; +import { ConnectorRowInfo } from '@/ee/features/chat/mcp/components/connectorRowInfo'; +import { ConnectorToolList, ConnectorToolTrigger } from '@/ee/features/chat/mcp/components/connectorToolDisclosure'; +import { ConnectorToolUsageList, ConnectorToolUsageTrigger } from '@/ee/features/chat/mcp/components/connectorToolUsageDisclosure'; +import type { McpServerToolUsageSummary, ServerToolsEntry } from '@/ee/features/chat/mcp/types'; + +interface ConnectorCardProps { + faviconUrl: string | undefined; + name: string; + serverUrl: string; + + isConnected: boolean; + isAuthExpired?: boolean; + isOAuthAvailable?: boolean; + isStatusUnavailable?: boolean; + toolEntry?: ServerToolsEntry; + toolUsage?: McpServerToolUsageSummary; + isToolsLoading?: boolean; + isToolsError?: boolean; + onRetryTools?: () => void; + + statusBadge: ReactNode; + actionButtons: ReactNode; +} + +export function ConnectorCard({ + faviconUrl, + name, + serverUrl, + isConnected, + isAuthExpired, + isOAuthAvailable, + isStatusUnavailable, + toolEntry, + toolUsage, + isToolsLoading = false, + isToolsError = false, + onRetryTools, + statusBadge, + actionButtons, +}: ConnectorCardProps) { + const [openPanel, setOpenPanel] = useState<'tools' | 'usage' | null>(null); + const panelIdPrefix = useId(); + const toolsPanelId = `${panelIdPrefix}-tools`; + const usagePanelId = `${panelIdPrefix}-usage`; + const availableToolEntry = toolEntry?.status === 'available' ? toolEntry : undefined; + const hasToolList = !!availableToolEntry; + const hasToolUsage = (toolUsage?.totalCalls ?? 0) > 0; + const isToolListOpen = openPanel === 'tools'; + const isToolUsageOpen = hasToolUsage && openPanel === 'usage'; + const isLoadingToolsForServer = isConnected && !availableToolEntry && isToolsLoading; + + return ( + + +
+ +
+ {statusBadge} + setOpenPanel(open && hasToolList ? 'tools' : null)} + onRetry={onRetryTools} + /> + {hasToolUsage && toolUsage && ( + setOpenPanel(open ? 'usage' : null)} + /> + )} +
+
+
+ {actionButtons} +
+
+ + {hasToolUsage && toolUsage && ( + + )} +
+
+ ); +} diff --git a/packages/web/src/ee/features/chat/mcp/components/connectorRowInfo.tsx b/packages/web/src/ee/features/chat/mcp/components/connectorRowInfo.tsx new file mode 100644 index 000000000..b7dd504f9 --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/components/connectorRowInfo.tsx @@ -0,0 +1,41 @@ +import { McpFavicon } from "./mcpFavicon"; +import { cn } from "@/lib/utils"; + +export function getDisplayServerUrl(serverUrl: string) { + try { + const url = new URL(serverUrl); + return `${url.host}${url.pathname}${url.search}`.replace(/\/$/, ""); + } catch { + return serverUrl; + } +} + +interface ConnectorRowInfoProps { + faviconUrl: string | undefined; + name: string; + serverUrl: string; + children?: React.ReactNode; + size?: 'sm' | 'default'; +} + +export function ConnectorRowInfo({ faviconUrl, name, serverUrl, children, size = 'default' }: ConnectorRowInfoProps) { + return ( + <> +
+ +
+
+

+ {name || serverUrl} +

+

+ {getDisplayServerUrl(serverUrl)} +

+ {children} +
+ + ); +} diff --git a/packages/web/src/ee/features/chat/mcp/components/connectorToolDisclosure.test.tsx b/packages/web/src/ee/features/chat/mcp/components/connectorToolDisclosure.test.tsx new file mode 100644 index 000000000..8fc018779 --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/components/connectorToolDisclosure.test.tsx @@ -0,0 +1,166 @@ +import { afterEach, describe, expect, test, vi } from 'vitest'; +import { cleanup, fireEvent, render, screen } from '@testing-library/react'; +import { ConnectorToolList, ConnectorToolTrigger } from './connectorToolDisclosure'; +import type { ServerToolsEntry } from '@/ee/features/chat/mcp/types'; + +afterEach(() => { + cleanup(); +}); + +function renderToolTrigger(props: React.ComponentProps) { + return render(); +} + +function availableEntry(overrides: Partial> = {}): Extract { + return { + status: 'available', + serverId: 'server-1', + tools: [ + { name: 'search', title: 'Search', description: 'Search issues', annotations: { readOnlyHint: true } }, + { name: 'delete_issue', description: 'Delete an issue', annotations: { destructiveHint: true, idempotentHint: true } }, + ], + ...overrides, + }; +} + +describe('ConnectorToolTrigger', () => { + test('renders an expandable count for available tools', () => { + renderToolTrigger({ + isConnected: true, + toolEntry: availableEntry(), + isOpen: false, + }); + + expect(screen.getByRole('button', { name: /2 tools/ })).toBeTruthy(); + }); + + test('uses plus count language only for list truncation', () => { + renderToolTrigger({ + isConnected: true, + toolEntry: availableEntry({ tools: [{ name: 'search' }], truncated: true }), + }); + + expect(screen.getByRole('button', { name: /1\+ tools/ })).toBeTruthy(); + }); + + test('renders unavailable state before connection-specific states', () => { + renderToolTrigger({ + isConnected: false, + isOAuthAvailable: false, + }); + + expect(screen.getByText('Tools unavailable')).toBeTruthy(); + expect(screen.queryByText('Connect to see tools')).toBeNull(); + }); + + test('renders actionable labels for disconnected and expired auth states', () => { + const { rerender } = render( + , + ); + + expect(screen.getByText('Connect to see tools')).toBeTruthy(); + + rerender( + , + ); + + expect(screen.getByText('Reconnect to see tools')).toBeTruthy(); + }); + + test('renders loading and retryable error states for connected servers', () => { + const onRetry = vi.fn(); + const { rerender } = render( + , + ); + + expect(screen.getByText('Loading tools...')).toBeTruthy(); + + rerender( + , + ); + + expect(screen.getByText('Tools timed out')).toBeTruthy(); + fireEvent.click(screen.getByRole('button', { name: /Retry/ })); + expect(onRetry).toHaveBeenCalledTimes(1); + }); + + test('maps auth_failed errors to reconnect language', () => { + renderToolTrigger({ + isConnected: true, + toolEntry: { status: 'error', serverId: 'server-1', reason: 'auth_failed' }, + }); + + expect(screen.getByText('Reconnect to see tools')).toBeTruthy(); + }); +}); + +describe('ConnectorToolList', () => { + test('renders compact tool badges and expands detail on click', () => { + render( + , + ); + + // Both tool badges are visible + expect(screen.getByRole('button', { name: 'Search' })).toBeTruthy(); + expect(screen.getByRole('button', { name: 'delete_issue' })).toBeTruthy(); + + // No detail shown yet + expect(screen.queryByText('Search issues')).toBeNull(); + expect(screen.queryByText('Read-only')).toBeNull(); + + // Click to expand detail + fireEvent.click(screen.getByRole('button', { name: 'Search' })); + expect(screen.getByText('Search issues')).toBeTruthy(); + expect(screen.getByText('search')).toBeTruthy(); + expect(screen.getByText('Read-only')).toBeTruthy(); + + // Click another tool — previous detail closes, new one opens + fireEvent.click(screen.getByRole('button', { name: 'delete_issue' })); + expect(screen.queryByText('Search issues')).toBeNull(); + expect(screen.getByText('Delete an issue')).toBeTruthy(); + expect(screen.getByText('Destructive')).toBeTruthy(); + expect(screen.getByText('Idempotent')).toBeTruthy(); + + // Click same tool again to collapse + fireEvent.click(screen.getByRole('button', { name: 'delete_issue' })); + expect(screen.queryByText('Delete an issue')).toBeNull(); + }); + + test('renders an empty-tools message for available servers with no tools', () => { + render( + , + ); + + expect(screen.getByText('No tools exposed by this connector.')).toBeTruthy(); + }); + + test('clears selected tool detail when closed', () => { + const { rerender } = render( + , + ); + + fireEvent.click(screen.getByRole('button', { name: 'Search' })); + expect(screen.getByText('Search issues')).toBeTruthy(); + + rerender( + , + ); + rerender( + , + ); + + expect(screen.queryByText('Search issues')).toBeNull(); + }); + + test('does not render list content for non-available entries', () => { + render( + , + ); + + expect(screen.queryByText('No tools exposed by this connector.')).toBeNull(); + }); +}); diff --git a/packages/web/src/ee/features/chat/mcp/components/connectorToolDisclosure.tsx b/packages/web/src/ee/features/chat/mcp/components/connectorToolDisclosure.tsx new file mode 100644 index 000000000..659fc9702 --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/components/connectorToolDisclosure.tsx @@ -0,0 +1,216 @@ +'use client'; + +import { useEffect, useState } from 'react'; +import { Badge } from '@/components/ui/badge'; +import { cn } from '@/lib/utils'; +import { pluralize } from '@/ee/features/chat/mcp/utils'; +import type { ServerToolsEntry, ToolMetadataErrorReason, ToolSummary } from '@/ee/features/chat/mcp/types'; +import { ChevronDownIcon, RefreshCwIcon, WrenchIcon } from 'lucide-react'; + +function getErrorLabel(reason: ToolMetadataErrorReason) { + switch (reason) { + case 'timeout': + return 'Tools timed out'; + case 'auth_failed': + return 'Reconnect to see tools'; + case 'unsupported': + return 'Tools unsupported'; + case 'connection_failed': + case 'unknown': + return 'Tools unavailable'; + } +} + +function getToolCountLabel(entry: Extract) { + const countLabel = `${entry.tools.length}${entry.truncated ? '+' : ''}`; + const nounCount = entry.truncated ? 2 : entry.tools.length; + return `${countLabel} ${pluralize(nounCount, 'tool')}`; +} + +interface ConnectorToolTriggerProps { + isConnected: boolean; + isAuthExpired?: boolean; + isOAuthAvailable?: boolean; + isStatusUnavailable?: boolean; + toolEntry?: ServerToolsEntry; + isLoading?: boolean; + isToolsQueryError?: boolean; + isOpen?: boolean; + controlsId?: string; + onOpenChange?: (open: boolean) => void; + onRetry?: () => void; +} + +export function ConnectorToolTrigger({ + isConnected, + isAuthExpired = false, + isOAuthAvailable = true, + isStatusUnavailable = false, + toolEntry, + isLoading = false, + isToolsQueryError = false, + isOpen = false, + controlsId, + onOpenChange, + onRetry, +}: ConnectorToolTriggerProps) { + const availableEntry = toolEntry?.status === 'available' ? toolEntry : undefined; + const errorEntry = toolEntry?.status === 'error' ? toolEntry : undefined; + const canExpand = !!availableEntry; + + if (canExpand) { + return ( + + ); + } + + let label = 'Tools unavailable'; + let canRetry = false; + + if (!isOAuthAvailable || isStatusUnavailable) { + label = 'Tools unavailable'; + } else if (!isConnected && isAuthExpired) { + label = 'Reconnect to see tools'; + } else if (!isConnected) { + label = 'Connect to see tools'; + } else if (isLoading) { + label = 'Loading tools...'; + } else if (errorEntry) { + label = getErrorLabel(errorEntry.reason); + canRetry = true; + } else if (isToolsQueryError) { + label = 'Tools unavailable'; + canRetry = true; + } + + return ( + + + {label} + {canRetry && onRetry && ( + + )} + + ); +} + +function ToolHintBadges({ tool }: { tool: ToolSummary }) { + const annotations = tool.annotations; + if (!annotations) { + return null; + } + + return ( + <> + {annotations.readOnlyHint === true && ( + + Read-only + + )} + {annotations.destructiveHint === true && ( + + Destructive + + )} + {annotations.idempotentHint === true && ( + + Idempotent + + )} + + ); +} + +function ToolDetail({ tool }: { tool: ToolSummary }) { + const displayName = tool.title ?? tool.name; + + return ( +
+
+ {displayName} + {tool.title && tool.title !== tool.name && ( + {tool.name} + )} + +
+ {tool.description && ( +

{tool.description}

+ )} +
+ ); +} + +interface ConnectorToolListProps { + toolEntry?: ServerToolsEntry; + isOpen?: boolean; + id?: string; +} + +export function ConnectorToolList({ toolEntry, isOpen = true, id }: ConnectorToolListProps) { + const [selectedTool, setSelectedTool] = useState(null); + + useEffect(() => { + if (!isOpen) { + setSelectedTool(null); + } + }, [isOpen]); + + if (!isOpen || toolEntry?.status !== 'available') { + return null; + } + + const activeTool = toolEntry.tools.find((t) => t.name === selectedTool); + + return ( +
+ {toolEntry.tools.length === 0 ? ( +

No tools exposed by this connector.

+ ) : ( +
+
+ {toolEntry.tools.map((tool) => { + const displayName = tool.title ?? tool.name; + const isSelected = selectedTool === tool.name; + + return ( + + ); + })} +
+ {activeTool && ( + + )} +
+ )} +
+ ); +} diff --git a/packages/web/src/ee/features/chat/mcp/components/connectorToolUsageDisclosure.test.tsx b/packages/web/src/ee/features/chat/mcp/components/connectorToolUsageDisclosure.test.tsx new file mode 100644 index 000000000..d81da94d5 --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/components/connectorToolUsageDisclosure.test.tsx @@ -0,0 +1,87 @@ +import { afterEach, describe, expect, test, vi } from 'vitest'; +import { cleanup, fireEvent, render, screen } from '@testing-library/react'; +import { ConnectorToolUsageList, ConnectorToolUsageTrigger } from './connectorToolUsageDisclosure'; +import type { McpServerToolUsageSummary, ServerToolsEntry } from '@/ee/features/chat/mcp/types'; + +afterEach(() => { + cleanup(); +}); + +function usageSummary(overrides: Partial = {}): McpServerToolUsageSummary { + return { + totalCalls: 6, + usedToolCount: 2, + tools: [ + { toolName: 'search_issues', totalCalls: 4, usageSharePercent: 66.666 }, + { toolName: 'get_issue', totalCalls: 2, usageSharePercent: 33.333 }, + ], + ...overrides, + }; +} + +function availableEntry(): Extract { + return { + status: 'available', + serverId: 'server-1', + tools: [ + { name: 'search_issues', title: 'Search issues' }, + { name: 'get_issue' }, + { name: 'create_issue' }, + ], + }; +} + +describe('ConnectorToolUsageTrigger', () => { + test('renders total tool calls and toggles open state', () => { + const onOpenChange = vi.fn(); + render( + , + ); + + fireEvent.click(screen.getByRole('button', { name: /6 tool calls/ })); + + expect(onOpenChange).toHaveBeenCalledWith(true); + }); +}); + +describe('ConnectorToolUsageList', () => { + test('renders used tools with usage bars and footer', () => { + render( + , + ); + + expect(screen.getByText('Lifetime tool usage')).toBeTruthy(); + expect(screen.getByText('Search issues')).toBeTruthy(); + expect(screen.getByText('search_issues')).toBeTruthy(); + expect(screen.getByText('get_issue')).toBeTruthy(); + expect(screen.getByText('6 total tool calls across 2 of 3 tools')).toBeTruthy(); + }); + + test('renders empty usage state', () => { + render( + , + ); + + expect(screen.getByText('No tool calls yet.')).toBeTruthy(); + }); + + test('does not render when closed', () => { + render( + , + ); + + expect(screen.queryByText('Lifetime tool usage')).toBeNull(); + }); +}); diff --git a/packages/web/src/ee/features/chat/mcp/components/connectorToolUsageDisclosure.tsx b/packages/web/src/ee/features/chat/mcp/components/connectorToolUsageDisclosure.tsx new file mode 100644 index 000000000..3f4f78c53 --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/components/connectorToolUsageDisclosure.tsx @@ -0,0 +1,113 @@ +'use client'; + +import { cn } from '@/lib/utils'; +import { + formatCount, + formatUsageSharePercent, + pluralize, +} from '@/ee/features/chat/mcp/utils'; +import type { McpServerToolUsageSummary, ServerToolsEntry } from '@/ee/features/chat/mcp/types'; +import { BarChart3Icon, ChevronDownIcon } from 'lucide-react'; + +interface ConnectorToolUsageTriggerProps { + toolUsage: McpServerToolUsageSummary; + isOpen?: boolean; + controlsId?: string; + onOpenChange?: (open: boolean) => void; +} + +export function ConnectorToolUsageTrigger({ + toolUsage, + isOpen = false, + controlsId, + onOpenChange, +}: ConnectorToolUsageTriggerProps) { + return ( + + ); +} + +interface ConnectorToolUsageListProps { + toolUsage: McpServerToolUsageSummary; + toolEntry?: Extract; + isOpen?: boolean; + id?: string; +} + +export function ConnectorToolUsageList({ + toolUsage, + toolEntry, + isOpen = true, + id, +}: ConnectorToolUsageListProps) { + if (!isOpen) { + return null; + } + + const topToolTotal = toolUsage.tools[0]?.totalCalls ?? 0; + const toolByName = new Map(toolEntry?.tools.map((tool) => [tool.name, tool]) ?? []); + + return ( +
+

Lifetime tool usage

+ {toolUsage.totalCalls === 0 || toolUsage.tools.length === 0 ? ( +

No tool calls yet.

+ ) : ( + <> +
+ {toolUsage.tools.map((tool) => { + const toolMetadata = toolByName.get(tool.toolName); + const displayName = toolMetadata?.title ?? tool.toolName; + const barWidth = topToolTotal > 0 + ? Math.min(100, (tool.totalCalls / topToolTotal) * 100) + : 0; + + return ( +
+
+ + {displayName} + + + {formatCount(tool.totalCalls)} ({formatUsageSharePercent(tool.usageSharePercent)}) + +
+
+
0 ? '2px' : undefined, + }} + /> +
+ {toolMetadata?.title && toolMetadata.title !== tool.toolName && ( +

+ {tool.toolName} +

+ )} +
+ ); + })} +
+

+ {formatCount(toolUsage.totalCalls)} total tool calls across{' '} + {toolEntry + ? `${formatCount(toolUsage.usedToolCount)} of ${formatCount(toolEntry.tools.length)} ${pluralize(toolEntry.tools.length, 'tool')}` + : `${formatCount(toolUsage.usedToolCount)} used ${pluralize(toolUsage.usedToolCount, 'tool')}`} +

+ + )} +
+ ); +} diff --git a/packages/web/src/ee/features/chat/mcp/components/mcpFavicon.tsx b/packages/web/src/ee/features/chat/mcp/components/mcpFavicon.tsx new file mode 100644 index 000000000..2220fc516 --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/components/mcpFavicon.tsx @@ -0,0 +1,24 @@ +'use client'; + +import { Plug } from "lucide-react"; +import { useState } from "react"; + +interface McpFaviconProps { + faviconUrl: string | undefined; + className?: string; +} + +export const McpFavicon = ({ faviconUrl, className = "w-4 h-4" }: McpFaviconProps) => { + const [failed, setFailed] = useState(false); + if (faviconUrl && !failed) { + return ( + setFailed(true)} + className={`${className} flex-shrink-0`} + alt="" + /> + ); + } + return ; +}; \ No newline at end of file diff --git a/packages/web/src/ee/features/chat/mcp/connectionStatus.test.ts b/packages/web/src/ee/features/chat/mcp/connectionStatus.test.ts new file mode 100644 index 000000000..d0c1b538b --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/connectionStatus.test.ts @@ -0,0 +1,81 @@ +import { beforeEach, describe, expect, test, vi } from 'vitest'; +import type { OAuthTokens } from '@ai-sdk/mcp'; + +const decryptOAuthToken = vi.hoisted(() => vi.fn()); + +vi.mock('@sourcebot/shared', () => ({ + decryptOAuthToken, +})); + +const { getStoredMcpConnectionStatus, isTokenExpiredWithNoRefresh } = await import('./connectionStatus'); + +const PAST = new Date('2020-01-01'); +const FUTURE = new Date('2099-01-01'); +const TOKEN_NO_REFRESH: OAuthTokens = { access_token: 'tok', token_type: 'Bearer' }; +const TOKEN_WITH_REFRESH: OAuthTokens = { access_token: 'tok', token_type: 'Bearer', refresh_token: 'ref' }; + +beforeEach(() => { + decryptOAuthToken.mockReset(); + decryptOAuthToken.mockImplementation((value: string) => value); +}); + +describe('isTokenExpiredWithNoRefresh', () => { + test('returns true when an access token is expired and has no refresh token', () => { + expect(isTokenExpiredWithNoRefresh(TOKEN_NO_REFRESH, PAST)).toBe(true); + }); + + test('returns false when a refresh token is present', () => { + expect(isTokenExpiredWithNoRefresh(TOKEN_WITH_REFRESH, PAST)).toBe(false); + }); + + test('returns false when there is no stored expiration', () => { + expect(isTokenExpiredWithNoRefresh(TOKEN_NO_REFRESH, null)).toBe(false); + }); + + test('returns false when the access token has not expired', () => { + expect(isTokenExpiredWithNoRefresh(TOKEN_NO_REFRESH, FUTURE)).toBe(false); + }); +}); + +describe('getStoredMcpConnectionStatus', () => { + test('returns not_connected when no encrypted tokens are stored', () => { + expect(getStoredMcpConnectionStatus(null, null)).toEqual({ state: 'not_connected' }); + }); + + test('returns not_connected when tokens cannot be decrypted', () => { + decryptOAuthToken.mockReturnValueOnce(null); + + expect(getStoredMcpConnectionStatus('encrypted', null)).toEqual({ state: 'not_connected' }); + }); + + test('returns not_connected when decrypted tokens are malformed', () => { + decryptOAuthToken.mockReturnValueOnce('not json'); + + expect(getStoredMcpConnectionStatus('encrypted', null)).toEqual({ state: 'not_connected' }); + }); + + test('returns not_connected when decrypted tokens are missing required OAuth fields', () => { + expect(getStoredMcpConnectionStatus(JSON.stringify({ token_type: 'Bearer' }), null)).toEqual({ state: 'not_connected' }); + expect(getStoredMcpConnectionStatus(JSON.stringify({ access_token: 'tok' }), null)).toEqual({ state: 'not_connected' }); + }); + + test('returns not_connected when optional OAuth fields have unexpected types', () => { + expect(getStoredMcpConnectionStatus(JSON.stringify({ + access_token: 'tok', + token_type: 'Bearer', + refresh_token: true, + }), null)).toEqual({ state: 'not_connected' }); + }); + + test('returns expired when an access token has expired and cannot be refreshed', () => { + const status = getStoredMcpConnectionStatus(JSON.stringify(TOKEN_NO_REFRESH), PAST); + + expect(status).toEqual({ state: 'expired', tokens: TOKEN_NO_REFRESH }); + }); + + test('returns connected when a token can be used or refreshed', () => { + const status = getStoredMcpConnectionStatus(JSON.stringify(TOKEN_WITH_REFRESH), PAST); + + expect(status).toEqual({ state: 'connected', tokens: TOKEN_WITH_REFRESH }); + }); +}); diff --git a/packages/web/src/ee/features/chat/mcp/connectionStatus.ts b/packages/web/src/ee/features/chat/mcp/connectionStatus.ts new file mode 100644 index 000000000..89428b015 --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/connectionStatus.ts @@ -0,0 +1,74 @@ +import type { OAuthTokens } from '@ai-sdk/mcp'; +import { decryptOAuthToken } from '@sourcebot/shared'; + +export type StoredMcpConnectionStatus = + | { state: 'connected'; tokens: OAuthTokens } + | { state: 'expired'; tokens: OAuthTokens } + | { state: 'not_connected' }; + +export function isTokenExpiredWithNoRefresh(tokens: OAuthTokens, tokensExpiresAt: Date | null): boolean { + if (tokens.refresh_token) { + return false; + } + if (!tokensExpiresAt) { + return false; + } + return new Date() > tokensExpiresAt; +} + +function isRecord(value: unknown): value is Record { + return typeof value === 'object' && value !== null; +} + +function parseStoredOAuthTokens(value: string): OAuthTokens | undefined { + const parsed = JSON.parse(value); + if (!isRecord(parsed)) { + return undefined; + } + if (typeof parsed.access_token !== 'string' || typeof parsed.token_type !== 'string') { + return undefined; + } + if (parsed.refresh_token !== undefined && typeof parsed.refresh_token !== 'string') { + return undefined; + } + if (parsed.expires_in !== undefined && typeof parsed.expires_in !== 'number') { + return undefined; + } + if (parsed.scope !== undefined && typeof parsed.scope !== 'string') { + return undefined; + } + if (parsed.id_token !== undefined && typeof parsed.id_token !== 'string') { + return undefined; + } + + return parsed as OAuthTokens; +} + +export function getStoredMcpConnectionStatus( + encryptedTokens: string | null | undefined, + tokensExpiresAt: Date | null, +): StoredMcpConnectionStatus { + if (!encryptedTokens) { + return { state: 'not_connected' }; + } + + try { + const decrypted = decryptOAuthToken(encryptedTokens); + if (!decrypted) { + return { state: 'not_connected' }; + } + + const tokens = parseStoredOAuthTokens(decrypted); + if (!tokens) { + return { state: 'not_connected' }; + } + + if (isTokenExpiredWithNoRefresh(tokens, tokensExpiresAt)) { + return { state: 'expired', tokens }; + } + + return { state: 'connected', tokens }; + } catch { + return { state: 'not_connected' }; + } +} diff --git a/packages/web/src/ee/features/chat/mcp/dcrDiscovery.test.ts b/packages/web/src/ee/features/chat/mcp/dcrDiscovery.test.ts new file mode 100644 index 000000000..194a2a815 --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/dcrDiscovery.test.ts @@ -0,0 +1,217 @@ +import { describe, expect, test, vi } from 'vitest'; +import { checkMcpServerDcrSupport } from './dcrDiscovery'; + +function jsonResponse(body: unknown) { + return new Response(JSON.stringify(body), { + status: 200, + headers: { 'content-type': 'application/json' }, + }); +} + +function notFoundResponse() { + return new Response('Not found', { status: 404 }); +} + +function deferredResponse() { + let resolve!: (response: Response) => void; + const promise = new Promise((resolvePromise) => { + resolve = resolvePromise; + }); + + return { promise, resolve }; +} + +describe('checkMcpServerDcrSupport', () => { + test('returns supported when authorization server metadata advertises a registration endpoint', async () => { + const fetchMock = vi.fn(async (input: string | URL | Request) => { + const url = input.toString(); + if (url === 'https://mcp.example.com/.well-known/oauth-protected-resource/mcp') { + return jsonResponse({ authorization_servers: ['https://auth.example.com'] }); + } + if (url === 'https://auth.example.com/.well-known/oauth-authorization-server') { + return jsonResponse({ registration_endpoint: 'https://auth.example.com/register' }); + } + return notFoundResponse(); + }) as unknown as typeof fetch; + + await expect(checkMcpServerDcrSupport('https://mcp.example.com/mcp', fetchMock)).resolves.toEqual({ + supportsDcr: true, + isKnown: true, + authorizationServerUrl: 'https://auth.example.com', + registrationEndpoint: 'https://auth.example.com/register', + }); + }); + + test('returns unsupported when authorization server metadata does not advertise a registration endpoint', async () => { + const fetchMock = vi.fn(async (input: string | URL | Request) => { + const url = input.toString(); + if (url === 'https://mcp.slack.com/.well-known/oauth-protected-resource') { + return jsonResponse({ authorization_servers: ['https://mcp.slack.com'] }); + } + if (url === 'https://mcp.slack.com/.well-known/oauth-authorization-server') { + return jsonResponse({ + authorization_endpoint: 'https://slack.com/oauth/v2_user/authorize', + token_endpoint: 'https://slack.com/api/oauth.v2.user.access', + }); + } + return notFoundResponse(); + }) as unknown as typeof fetch; + + await expect(checkMcpServerDcrSupport('https://mcp.slack.com/mcp', fetchMock)).resolves.toEqual({ + supportsDcr: false, + isKnown: true, + authorizationServerUrl: 'https://mcp.slack.com', + }); + }); + + test('falls back to the resource metadata URL from a bearer challenge', async () => { + const fetchMock = vi.fn(async (input: string | URL | Request) => { + const url = input.toString(); + if (url === 'https://auth.example.com/.well-known/oauth-authorization-server') { + return jsonResponse({ registration_endpoint: 'https://auth.example.com/register' }); + } + if (url.includes('/.well-known/')) { + return notFoundResponse(); + } + if (url === 'https://mcp.example.com/mcp') { + return new Response('', { + status: 401, + headers: { + 'www-authenticate': 'Bearer resource_metadata="https://metadata.example.com/oauth-protected-resource"', + }, + }); + } + if (url === 'https://metadata.example.com/oauth-protected-resource') { + return jsonResponse({ authorization_servers: ['https://auth.example.com'] }); + } + return notFoundResponse(); + }) as unknown as typeof fetch; + + const result = await checkMcpServerDcrSupport('https://mcp.example.com/mcp', fetchMock); + + expect(result.supportsDcr).toBe(true); + expect(result.isKnown).toBe(true); + }); + + test('ignores non-bearer authenticate challenges', async () => { + const fetchMock = vi.fn(async (input: string | URL | Request) => { + const url = input.toString(); + if (url.includes('/.well-known/')) { + return notFoundResponse(); + } + if (url === 'https://mcp.example.com/mcp') { + return new Response('', { + status: 401, + headers: { + 'www-authenticate': 'Basic realm="mcp"', + }, + }); + } + return notFoundResponse(); + }) as unknown as typeof fetch; + + await expect(checkMcpServerDcrSupport('https://mcp.example.com/mcp', fetchMock)).resolves.toEqual({ + supportsDcr: true, + isKnown: false, + authorizationServerUrl: 'https://mcp.example.com/mcp', + }); + }); + + test('ignores malformed bearer resource metadata URLs', async () => { + const fetchMock = vi.fn(async (input: string | URL | Request) => { + const url = input.toString(); + if (url.includes('/.well-known/')) { + return notFoundResponse(); + } + if (url === 'https://mcp.example.com/mcp') { + return new Response('', { + status: 401, + headers: { + 'www-authenticate': 'Bearer resource_metadata="not a url"', + }, + }); + } + return notFoundResponse(); + }) as unknown as typeof fetch; + + await expect(checkMcpServerDcrSupport('https://mcp.example.com/mcp', fetchMock)).resolves.toEqual({ + supportsDcr: true, + isKnown: false, + authorizationServerUrl: 'https://mcp.example.com/mcp', + }); + }); + + test('ignores bearer resource metadata parameters without quotes', async () => { + const fetchMock = vi.fn(async (input: string | URL | Request) => { + const url = input.toString(); + if (url.includes('/.well-known/')) { + return notFoundResponse(); + } + if (url === 'https://mcp.example.com/mcp') { + return new Response('', { + status: 401, + headers: { + 'www-authenticate': 'Bearer resource_metadata=https://metadata.example.com/oauth-protected-resource', + }, + }); + } + return notFoundResponse(); + }) as unknown as typeof fetch; + + await expect(checkMcpServerDcrSupport('https://mcp.example.com/mcp', fetchMock)).resolves.toEqual({ + supportsDcr: true, + isKnown: false, + authorizationServerUrl: 'https://mcp.example.com/mcp', + }); + }); + + test('starts authorization server metadata candidate requests concurrently while preserving priority', async () => { + const pathScopedOAuthMetadata = deferredResponse(); + const rootOAuthMetadata = deferredResponse(); + const pathScopedOidcMetadata = deferredResponse(); + const nestedOidcMetadata = deferredResponse(); + const fetchMock = vi.fn(async (input: string | URL | Request) => { + const url = input.toString(); + if (url === 'https://mcp.example.com/.well-known/oauth-protected-resource/mcp') { + return jsonResponse({ authorization_servers: ['https://auth.example.com/tenant'] }); + } + if (url === 'https://auth.example.com/.well-known/oauth-authorization-server/tenant') { + return pathScopedOAuthMetadata.promise; + } + if (url === 'https://auth.example.com/.well-known/oauth-authorization-server') { + return rootOAuthMetadata.promise; + } + if (url === 'https://auth.example.com/.well-known/openid-configuration/tenant') { + return pathScopedOidcMetadata.promise; + } + if (url === 'https://auth.example.com/tenant/.well-known/openid-configuration') { + return nestedOidcMetadata.promise; + } + return notFoundResponse(); + }); + + const resultPromise = checkMcpServerDcrSupport('https://mcp.example.com/mcp', fetchMock as unknown as typeof fetch); + await vi.waitFor(() => { + const requestedUrls = fetchMock.mock.calls.map(([input]) => input.toString()); + + expect(requestedUrls).toContain('https://auth.example.com/.well-known/oauth-authorization-server/tenant'); + expect(requestedUrls).toContain('https://auth.example.com/.well-known/oauth-authorization-server'); + expect(requestedUrls).toContain('https://auth.example.com/.well-known/openid-configuration/tenant'); + expect(requestedUrls).toContain('https://auth.example.com/tenant/.well-known/openid-configuration'); + }); + + rootOAuthMetadata.resolve(jsonResponse({ registration_endpoint: 'https://auth.example.com/register' })); + pathScopedOidcMetadata.resolve(notFoundResponse()); + nestedOidcMetadata.resolve(notFoundResponse()); + await Promise.resolve(); + + pathScopedOAuthMetadata.resolve(notFoundResponse()); + + await expect(resultPromise).resolves.toEqual({ + supportsDcr: true, + isKnown: true, + authorizationServerUrl: 'https://auth.example.com/tenant', + registrationEndpoint: 'https://auth.example.com/register', + }); + }); +}); diff --git a/packages/web/src/ee/features/chat/mcp/dcrDiscovery.ts b/packages/web/src/ee/features/chat/mcp/dcrDiscovery.ts new file mode 100644 index 000000000..286883d50 --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/dcrDiscovery.ts @@ -0,0 +1,206 @@ +import { z } from 'zod'; + +const MCP_PROTOCOL_VERSION = '2025-11-25'; + +const protectedResourceMetadataSchema = z.object({ + authorization_servers: z.array(z.string().url()).optional(), +}).passthrough(); + +const authorizationServerMetadataSchema = z.object({ + registration_endpoint: z.string().url().optional(), +}).passthrough(); + +export interface McpServerDcrSupport { + supportsDcr: boolean; + isKnown: boolean; + authorizationServerUrl?: string; + registrationEndpoint?: string; +} + +function getMetadataHeaders() { + return { + Accept: 'application/json', + 'MCP-Protocol-Version': MCP_PROTOCOL_VERSION, + }; +} + +function buildProtectedResourceMetadataUrls(serverUrl: URL): URL[] { + const urls: URL[] = []; + const pathname = serverUrl.pathname.endsWith('/') + ? serverUrl.pathname.slice(0, -1) + : serverUrl.pathname; + + if (pathname && pathname !== '/') { + urls.push(new URL(`/.well-known/oauth-protected-resource${pathname}`, serverUrl.origin)); + } + + urls.push(new URL('/.well-known/oauth-protected-resource', serverUrl.origin)); + return urls; +} + +function buildAuthorizationServerMetadataUrls(authorizationServerUrl: URL): URL[] { + const hasPath = authorizationServerUrl.pathname !== '/'; + + if (!hasPath) { + return [ + new URL('/.well-known/oauth-authorization-server', authorizationServerUrl.origin), + new URL('/.well-known/openid-configuration', authorizationServerUrl.origin), + ]; + } + + const pathname = authorizationServerUrl.pathname.endsWith('/') + ? authorizationServerUrl.pathname.slice(0, -1) + : authorizationServerUrl.pathname; + + return [ + new URL(`/.well-known/oauth-authorization-server${pathname}`, authorizationServerUrl.origin), + new URL('/.well-known/oauth-authorization-server', authorizationServerUrl.origin), + new URL(`/.well-known/openid-configuration${pathname}`, authorizationServerUrl.origin), + new URL(`${pathname}/.well-known/openid-configuration`, authorizationServerUrl.origin), + ]; +} + +function normalizeUrlForOutput(url: URL): string { + return url.toString().replace(/\/$/, ''); +} + +function extractResourceMetadataUrl(response: Response): URL | undefined { + const header = response.headers.get('www-authenticate'); + if (!header) { + return undefined; + } + + if (!header.toLowerCase().startsWith('bearer ')) { + return undefined; + } + + const match = header.match(/resource_metadata="([^"]+)"/); + if (!match) { + return undefined; + } + + try { + return new URL(match[1]); + } catch { + return undefined; + } +} + +async function fetchJson(url: URL, fetchFn: typeof fetch): Promise { + const response = await fetchFn(url, { headers: getMetadataHeaders() }); + + if (!response.ok) { + return undefined; + } + + return response.json(); +} + +async function fetchMetadataByPriority( + urls: URL[], + fetchFn: typeof fetch, + schema: z.ZodType, +): Promise { + const metadataPromises = urls.map(async (url) => { + try { + const json = await fetchJson(url, fetchFn); + const metadata = schema.safeParse(json); + return metadata.success ? metadata.data : undefined; + } catch { + return undefined; + } + }); + + for (const metadataPromise of metadataPromises) { + const metadata = await metadataPromise; + if (metadata) { + return metadata; + } + } + + return undefined; +} + +async function discoverProtectedResourceMetadata(serverUrl: URL, fetchFn: typeof fetch) { + const challengeMetadataPromise = (async () => { + try { + const response = await fetchFn(serverUrl, { headers: getMetadataHeaders() }); + const resourceMetadataUrl = extractResourceMetadataUrl(response); + if (!resourceMetadataUrl) { + return undefined; + } + + const json = await fetchJson(resourceMetadataUrl, fetchFn); + const metadata = protectedResourceMetadataSchema.safeParse(json); + return metadata.success ? metadata.data : undefined; + } catch { + return undefined; + } + })(); + + const wellKnownMetadata = await fetchMetadataByPriority( + buildProtectedResourceMetadataUrls(serverUrl), + fetchFn, + protectedResourceMetadataSchema, + ); + if (wellKnownMetadata) { + return wellKnownMetadata; + } + + return challengeMetadataPromise; +} + +async function discoverAuthorizationServerMetadata(authorizationServerUrl: URL, fetchFn: typeof fetch) { + return fetchMetadataByPriority( + buildAuthorizationServerMetadataUrls(authorizationServerUrl), + fetchFn, + authorizationServerMetadataSchema, + ); +} + +export async function checkMcpServerDcrSupport(serverUrl: string, fetchFn: typeof fetch = fetch): Promise { + const parsedServerUrl = new URL(serverUrl); + const protectedResourceMetadata = await discoverProtectedResourceMetadata(parsedServerUrl, fetchFn); + const authorizationServerUrls = protectedResourceMetadata?.authorization_servers?.length + ? protectedResourceMetadata.authorization_servers + : [parsedServerUrl.toString()]; + + let foundAuthorizationServerMetadata = false; + let firstAuthorizationServerUrl: URL | undefined; + for (const authorizationServer of authorizationServerUrls) { + const authorizationServerUrl = new URL(authorizationServer); + firstAuthorizationServerUrl ??= authorizationServerUrl; + const authorizationServerMetadata = await discoverAuthorizationServerMetadata(authorizationServerUrl, fetchFn); + if (!authorizationServerMetadata) { + continue; + } + + foundAuthorizationServerMetadata = true; + if (authorizationServerMetadata.registration_endpoint) { + return { + supportsDcr: true, + isKnown: true, + authorizationServerUrl: normalizeUrlForOutput(authorizationServerUrl), + registrationEndpoint: authorizationServerMetadata.registration_endpoint, + }; + } + } + + if (foundAuthorizationServerMetadata) { + return { + supportsDcr: false, + isKnown: true, + authorizationServerUrl: firstAuthorizationServerUrl + ? normalizeUrlForOutput(firstAuthorizationServerUrl) + : undefined, + }; + } + + return { + supportsDcr: true, + isKnown: false, + authorizationServerUrl: firstAuthorizationServerUrl + ? normalizeUrlForOutput(firstAuthorizationServerUrl) + : undefined, + }; +} diff --git a/packages/web/src/ee/features/chat/mcp/errors.ts b/packages/web/src/ee/features/chat/mcp/errors.ts new file mode 100644 index 000000000..12a0c79a9 --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/errors.ts @@ -0,0 +1,10 @@ +import { ErrorCode } from '@/lib/errorCodes'; +import { ServiceError } from '@/lib/serviceError'; +import { OAUTH_NOT_SUPPORTED_ERROR_MESSAGE } from '@/ee/features/oauth/constants'; +import { StatusCodes } from 'http-status-codes'; + +export const oauthNotSupported = (): ServiceError => ({ + statusCode: StatusCodes.FORBIDDEN, + errorCode: ErrorCode.INSUFFICIENT_PERMISSIONS, + message: OAUTH_NOT_SUPPORTED_ERROR_MESSAGE, +}); diff --git a/packages/web/src/ee/features/chat/mcp/externalMcpError.test.ts b/packages/web/src/ee/features/chat/mcp/externalMcpError.test.ts new file mode 100644 index 000000000..5f51433b5 --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/externalMcpError.test.ts @@ -0,0 +1,66 @@ +import { describe, expect, test } from 'vitest'; +import { getExternalMcpErrorLogFields } from './externalMcpError'; + +describe('getExternalMcpErrorLogFields', () => { + test('does not include raw error messages or response bodies', () => { + class OAuthProviderError extends Error { + statusCode = 401; + response = { + status: 401, + body: JSON.stringify({ + error: 'invalid_client', + error_description: 'client_secret=client-secret refresh_token=refresh-token', + }), + }; + } + const error = new OAuthProviderError('invalid_client client_secret=client-secret'); + + const fields = getExternalMcpErrorLogFields(error); + + expect(fields).toEqual({ + errorClass: 'OAuthProviderError', + errorName: 'Error', + oauthError: 'invalid_client', + statusCode: 401, + }); + expect(JSON.stringify(fields)).not.toContain('client-secret'); + expect(JSON.stringify(fields)).not.toContain('refresh-token'); + }); + + test('drops unsafe custom names', () => { + const fields = getExternalMcpErrorLogFields({ + name: 'client_secret=client-secret', + status: 502, + }); + + expect(fields).toEqual({ + errorClass: 'Object', + statusCode: 502, + }); + expect(JSON.stringify(fields)).not.toContain('client-secret'); + }); + + test('preserves known safe diagnostic reasons without raw messages', () => { + const fields = getExternalMcpErrorLogFields( + new Error('Incompatible auth server: does not support dynamic client registration'), + ); + + expect(fields).toEqual({ + errorClass: 'Error', + reason: 'dynamic_client_registration_unsupported', + }); + expect(JSON.stringify(fields)).not.toContain('Incompatible auth server'); + }); + + test('finds allowlisted OAuth codes anywhere in a message', () => { + const fields = getExternalMcpErrorLogFields( + new Error('Request failed at invalid_grant after token exchange'), + ); + + expect(fields).toEqual({ + errorClass: 'Error', + oauthError: 'invalid_grant', + }); + expect(JSON.stringify(fields)).not.toContain('Request failed'); + }); +}); diff --git a/packages/web/src/ee/features/chat/mcp/externalMcpError.ts b/packages/web/src/ee/features/chat/mcp/externalMcpError.ts new file mode 100644 index 000000000..4894a317d --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/externalMcpError.ts @@ -0,0 +1,174 @@ +interface SafeExternalMcpErrorFields { + errorClass: string; + errorName?: string; + oauthError?: string; + reason?: string; + statusCode?: number; +} + +const OAUTH_ERROR_CODES = new Set([ + 'invalid_request', + 'invalid_client', + 'invalid_grant', + 'unauthorized_client', + 'unsupported_grant_type', + 'invalid_scope', + 'server_error', + 'temporarily_unavailable', +]); + +function isRecord(value: unknown): value is Record { + return typeof value === 'object' && value !== null; +} + +function safeIdentifier(value: unknown): string | undefined { + if (typeof value !== 'string') { + return undefined; + } + + if (!/^[A-Za-z0-9_.:-]{1,80}$/.test(value)) { + return undefined; + } + + return value; +} + +function numericStatus(value: unknown): number | undefined { + if (typeof value !== 'number' || !Number.isInteger(value)) { + return undefined; + } + + if (value < 100 || value > 599) { + return undefined; + } + + return value; +} + +function getStatusCode(error: unknown): number | undefined { + if (!isRecord(error)) { + return undefined; + } + + return numericStatus(error.statusCode) + ?? numericStatus(error.status) + ?? (isRecord(error.response) ? numericStatus(error.response.status) : undefined); +} + +function safeOAuthErrorCode(value: unknown): string | undefined { + const identifier = safeIdentifier(value); + if (!identifier) { + return undefined; + } + + const normalized = identifier.toLowerCase(); + return OAUTH_ERROR_CODES.has(normalized) ? normalized : undefined; +} + +function getErrorMessage(error: unknown): string | undefined { + if (error instanceof Error) { + return error.message; + } + + return isRecord(error) && typeof error.message === 'string' ? error.message : undefined; +} + +function getConstructorOAuthErrorCode(error: unknown): string | undefined { + if (!isRecord(error)) { + return undefined; + } + + const constructor = error.constructor; + if (!isRecord(constructor)) { + return undefined; + } + + return safeOAuthErrorCode(constructor.errorCode); +} + +function getBodyOAuthErrorCode(body: unknown): string | undefined { + if (typeof body !== 'string' || body.length > 4096) { + return undefined; + } + + try { + const parsed = JSON.parse(body); + return isRecord(parsed) ? safeOAuthErrorCode(parsed.error) : undefined; + } catch { + return undefined; + } +} + +function getMessageOAuthErrorCode(error: unknown): string | undefined { + const tokens = getErrorMessage(error)?.match(/\b[a-z_]{3,40}\b/g); + return tokens?.find((token) => OAUTH_ERROR_CODES.has(token)); +} + +function getOAuthErrorCode(error: unknown): string | undefined { + if (!isRecord(error)) { + return undefined; + } + + return safeOAuthErrorCode(error.error) + ?? safeOAuthErrorCode(error.code) + ?? safeOAuthErrorCode(error.errorCode) + ?? getConstructorOAuthErrorCode(error) + ?? getBodyOAuthErrorCode(error.body) + ?? (isRecord(error.response) ? getBodyOAuthErrorCode(error.response.body) : undefined) + ?? getMessageOAuthErrorCode(error); +} + +function getSafeReason(error: unknown): string | undefined { + const message = getErrorMessage(error)?.toLowerCase(); + if (!message) { + return undefined; + } + + if (message.includes('does not support dynamic client registration')) { + return 'dynamic_client_registration_unsupported'; + } + if (message.includes('does not support grant type')) { + return 'unsupported_grant_type'; + } + if (message.includes('does not support response type')) { + return 'unsupported_response_type'; + } + if (message.includes('does not support code challenge method') || message.includes('does not support s256 code challenge')) { + return 'unsupported_code_challenge_method'; + } + if (message.includes('oauth state parameter mismatch')) { + return 'oauth_state_mismatch'; + } + if (message.includes('oauth client information must be saveable') || message.includes('existing oauth client information is required')) { + return 'missing_oauth_client_information'; + } + + return undefined; +} + +/** + * Returns log-safe metadata for errors thrown by external MCP/OAuth libraries. + * + * Do not log raw error objects, messages, stacks, response bodies, request bodies, + * or causes from these boundaries. A malicious or misconfigured provider can echo + * client secrets or tokens into OAuth error bodies. + */ +export function getExternalMcpErrorLogFields(error: unknown): SafeExternalMcpErrorFields { + const errorClass = error instanceof Error + ? safeIdentifier(error.constructor.name) ?? 'Error' + : safeIdentifier(isRecord(error) ? error.constructor?.name : undefined) ?? 'UnknownExternalMcpError'; + const errorName = error instanceof Error + ? safeIdentifier(error.name) + : safeIdentifier(isRecord(error) ? error.name : undefined); + const oauthError = getOAuthErrorCode(error); + const reason = getSafeReason(error); + const statusCode = getStatusCode(error); + + return { + errorClass, + ...(errorName && errorName !== errorClass ? { errorName } : {}), + ...(oauthError ? { oauthError } : {}), + ...(reason ? { reason } : {}), + ...(statusCode ? { statusCode } : {}), + }; +} diff --git a/packages/web/src/ee/features/chat/mcp/hooks/useConnectMcp.ts b/packages/web/src/ee/features/chat/mcp/hooks/useConnectMcp.ts new file mode 100644 index 000000000..6cb452c65 --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/hooks/useConnectMcp.ts @@ -0,0 +1,43 @@ +'use client'; + +import { useState } from 'react'; +import { useToast } from '@/components/hooks/use-toast'; +import { useQueryClient } from '@tanstack/react-query'; +import { connectMcpToAsk } from '@/app/api/(client)/client'; +import { invalidateMcpConfigurationQueries } from '@/ee/features/chat/mcp/queryKeys'; +import { isServiceError } from '@/lib/utils'; + +interface UseConnectMcpOptions { + returnTo?: string; +} + +export function useConnectMcp(options?: UseConnectMcpOptions) { + const [loadingServerId, setLoadingServerId] = useState(null); + const { toast } = useToast(); + const queryClient = useQueryClient(); + + const connect = async (serverId: string) => { + setLoadingServerId(serverId); + const result = await connectMcpToAsk({ serverId, returnTo: options?.returnTo }); + + if (isServiceError(result)) { + toast({ + description: `Failed to connect connector. ${result.message}`, + }); + setLoadingServerId(null); + return; + } + + if (result.authorizationUrl) { + window.location.href = result.authorizationUrl; + } else { + toast({ + description: 'Connector is already connected.', + }); + await invalidateMcpConfigurationQueries(queryClient); + setLoadingServerId(null); + } + }; + + return { connect, loadingServerId }; +} diff --git a/packages/web/src/ee/features/chat/mcp/hooks/useMcpToolMetadata.ts b/packages/web/src/ee/features/chat/mcp/hooks/useMcpToolMetadata.ts new file mode 100644 index 000000000..b86eb66b5 --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/hooks/useMcpToolMetadata.ts @@ -0,0 +1,70 @@ +'use client'; + +import { useEffect, useMemo, useRef } from 'react'; +import { useQuery, useQueryClient } from '@tanstack/react-query'; +import { getMcpServerTools } from '@/app/api/(client)/client'; +import { isServiceError } from '@/lib/utils'; +import { mcpQueryKeys } from '@/ee/features/chat/mcp/queryKeys'; +import type { ServerToolsEntry } from '@/ee/features/chat/mcp/types'; + +const EMPTY_TOOL_ENTRIES: ServerToolsEntry[] = []; + +export function useMcpToolMetadata(isOAuthAvailable: boolean, connectedServerCount: number) { + const queryClient = useQueryClient(); + const lastAuthFailureInvalidatedAtRef = useRef(0); + const { + data: toolEntries = EMPTY_TOOL_ENTRIES, + isLoading: isToolsLoading, + isError: isToolsError, + refetch: refetchTools, + dataUpdatedAt: toolsDataUpdatedAt, + } = useQuery({ + queryKey: mcpQueryKeys.tools, + queryFn: async () => { + const result = await getMcpServerTools(); + if (isServiceError(result)) { + throw new Error("Failed to load connector tools"); + } + if (!Array.isArray(result)) { + throw new Error("Unexpected response from connector tools endpoint"); + } + return result; + }, + enabled: isOAuthAvailable && connectedServerCount > 0, + staleTime: 5 * 60 * 1000, + gcTime: 30 * 60 * 1000, + refetchOnWindowFocus: false, + }); + + const toolsByServerId = useMemo(() => { + const map = new Map(); + for (const entry of toolEntries) { + map.set(entry.serverId, entry); + } + return map; + }, [toolEntries]); + + useEffect(() => { + if (toolsDataUpdatedAt === 0) { + return; + } + if (lastAuthFailureInvalidatedAtRef.current === toolsDataUpdatedAt) { + return; + } + if (!toolEntries.some((entry) => entry.status === 'error' && entry.reason === 'auth_failed')) { + return; + } + + lastAuthFailureInvalidatedAtRef.current = toolsDataUpdatedAt; + void queryClient.invalidateQueries({ queryKey: mcpQueryKeys.serversWithStatus }); + void queryClient.invalidateQueries({ queryKey: mcpQueryKeys.configuration }); + }, [queryClient, toolEntries, toolsDataUpdatedAt]); + + return { + toolEntries, + toolsByServerId, + isToolsLoading, + isToolsError, + refetchTools, + }; +} diff --git a/packages/web/src/ee/features/chat/mcp/mcpClientFactory.test.ts b/packages/web/src/ee/features/chat/mcp/mcpClientFactory.test.ts new file mode 100644 index 000000000..4333449bc --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/mcpClientFactory.test.ts @@ -0,0 +1,108 @@ +import { expect, test, describe, vi } from 'vitest'; +import { prisma } from '@/__mocks__/prisma'; +import type { OAuthTokens } from '@ai-sdk/mcp'; + +// --- Mocks --- + +vi.mock('@sourcebot/shared', () => ({ + createLogger: () => ({ + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + debug: vi.fn(), + }), + env: { AUTH_URL: 'https://sourcebot.example.com' }, + decryptOAuthToken: vi.fn((s: string) => s), +})); + +vi.mock('server-only', () => ({ default: vi.fn() })); + +vi.mock('@/features/mcp/prismaOAuthClientProvider', () => ({ + PrismaOAuthClientProvider: vi.fn(), +})); + +vi.mock('@modelcontextprotocol/sdk/client/streamableHttp.js', () => ({ + StreamableHTTPClientTransport: vi.fn(), +})); + +// Import after mocks are set up +const { getConnectedMcpClients } = await import('./mcpClientFactory'); + +// --- Helpers --- + +const PAST = new Date('2020-01-01'); +const TOKEN_NO_REFRESH: OAuthTokens = { access_token: 'tok', token_type: 'Bearer' }; +const TOKEN_WITH_REFRESH: OAuthTokens = { access_token: 'tok', token_type: 'Bearer', refresh_token: 'ref' }; + +function makeUserServer(overrides: { + tokens?: OAuthTokens; + tokensExpiresAt?: Date | null; + orgId?: number; +}) { + return { + serverId: 'srv-1', + userId: 'user-1', + tokens: JSON.stringify(overrides.tokens ?? TOKEN_NO_REFRESH), + tokensExpiresAt: overrides.tokensExpiresAt ?? null, + server: { + orgId: overrides.orgId ?? 1, + name: 'MyServer', + sanitizedName: 'myserver', + serverUrl: 'https://example.com/mcp', + }, + }; +} + +// --- getConnectedMcpClients --- + +describe('getConnectedMcpClients', () => { + test('skips server when access token expired and no refresh token', async () => { + prisma.userMcpServer.findMany.mockResolvedValue([ + makeUserServer({ tokens: TOKEN_NO_REFRESH, tokensExpiresAt: PAST }), + ] as never); + + const result = await getConnectedMcpClients(prisma, 'user-1', 1); + expect(result).toHaveLength(0); + }); + + test('includes server when refresh_token present even if access token expired', async () => { + prisma.userMcpServer.findMany.mockResolvedValue([ + makeUserServer({ tokens: TOKEN_WITH_REFRESH, tokensExpiresAt: PAST }), + ] as never); + + const result = await getConnectedMcpClients(prisma, 'user-1', 1); + expect(result).toHaveLength(1); + }); + + test('includes server when tokensExpiresAt is null', async () => { + prisma.userMcpServer.findMany.mockResolvedValue([ + makeUserServer({ tokensExpiresAt: null }), + ] as never); + + const result = await getConnectedMcpClients(prisma, 'user-1', 1); + expect(result).toHaveLength(1); + }); + + test('skips server belonging to a different org', async () => { + prisma.userMcpServer.findMany.mockResolvedValue([ + makeUserServer({ orgId: 999 }), + ] as never); + + const result = await getConnectedMcpClients(prisma, 'user-1', 1); + expect(result).toHaveLength(0); + }); + + test('returns server metadata from the user MCP server row', async () => { + prisma.userMcpServer.findMany.mockResolvedValue([ + makeUserServer({ tokens: TOKEN_WITH_REFRESH }), + ] as never); + + const result = await getConnectedMcpClients(prisma, 'user-1', 1); + expect(result[0]).toMatchObject({ + serverId: 'srv-1', + serverName: 'MyServer', + sanitizedName: 'myserver', + serverUrl: 'https://example.com/mcp', + }); + }); +}); diff --git a/packages/web/src/ee/features/chat/mcp/mcpClientFactory.ts b/packages/web/src/ee/features/chat/mcp/mcpClientFactory.ts new file mode 100644 index 000000000..b74d710ec --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/mcpClientFactory.ts @@ -0,0 +1,100 @@ +import { createLogger, env } from '@sourcebot/shared'; +import { PrismaOAuthClientProvider } from '@/features/mcp/prismaOAuthClientProvider'; +import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'; +import type { PrismaClient } from '@sourcebot/db'; +import { getExternalMcpErrorLogFields } from './externalMcpError'; +import { getStoredMcpConnectionStatus } from './connectionStatus'; + +const logger = createLogger('mcp-client-factory'); + +export interface McpToolSet { + serverId: string; + serverName: string; + sanitizedName: string; + serverUrl: string; + transport: StreamableHTTPClientTransport; +} + +/** + * Creates authenticated transports for all external MCP servers the user has valid credentials for. + * Skips servers with clearly expired tokens and no refresh token. + * Does NOT connect — connection is deferred to createMCPClient. + */ +export async function getConnectedMcpClients(prisma: PrismaClient, userId: string, orgId: number): Promise { + const userServers = await prisma.userMcpServer.findMany({ + where: { + userId, + tokens: { not: null }, + server: { + orgId, + clientInfo: { not: null }, + }, + }, + select: { + serverId: true, + tokens: true, + tokensExpiresAt: true, + server: { + select: { + orgId: true, + name: true, + sanitizedName: true, + serverUrl: true, + }, + }, + }, + }); + + const clients: McpToolSet[] = []; + + for (const userServer of userServers) { + // Skip servers that don't belong to the current org. + if (userServer.server.orgId !== orgId) { + continue; + } + + const serverName = userServer.server.name; + + try { + const connectionStatus = getStoredMcpConnectionStatus(userServer.tokens, userServer.tokensExpiresAt); + if (connectionStatus.state === 'not_connected') { + logger.warn(`Could not decrypt tokens for MCP server ${serverName}, skipping.`); + continue; + } + + if (connectionStatus.state === 'expired') { + logger.warn(`Access token for MCP server ${serverName} is expired and has no refresh token. User ${userId} needs to re-authorize.`); + continue; + } + + const provider = new PrismaOAuthClientProvider({ + prisma, + serverId: userServer.serverId, + orgId, + userId, + callbackUrl: `${env.AUTH_URL}/api/ee/askmcp/callback`, + }); + + const transport = new StreamableHTTPClientTransport( + new URL(userServer.server.serverUrl), + { authProvider: provider }, + ); + + clients.push({ + serverId: userServer.serverId, + serverName, + sanitizedName: userServer.server.sanitizedName, + serverUrl: userServer.server.serverUrl, + transport, + }); + } catch (error) { + logger.error('Failed to prepare MCP server transport.', { + serverId: userServer.serverId, + sanitizedName: userServer.server.sanitizedName, + error: getExternalMcpErrorLogFields(error), + }); + } + } + + return clients; +} diff --git a/packages/web/src/ee/features/chat/mcp/mcpToolMetadata.test.ts b/packages/web/src/ee/features/chat/mcp/mcpToolMetadata.test.ts new file mode 100644 index 000000000..292b7ae7d --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/mcpToolMetadata.test.ts @@ -0,0 +1,166 @@ +import { beforeEach, describe, expect, test, vi } from 'vitest'; +import type { PrismaClient } from '@sourcebot/db'; + +const createMCPClient = vi.hoisted(() => vi.fn()); +const getConnectedMcpClients = vi.hoisted(() => vi.fn()); +const loggerWarn = vi.hoisted(() => vi.fn()); + +vi.mock('@ai-sdk/mcp', () => ({ + createMCPClient, +})); + +vi.mock('@sourcebot/shared', () => ({ + createLogger: () => ({ + warn: loggerWarn, + error: vi.fn(), + info: vi.fn(), + debug: vi.fn(), + }), + env: { + SOURCEBOT_MCP_TOOL_CALL_TIMEOUT_MS: 60000, + }, +})); + +vi.mock('./mcpClientFactory', () => ({ + getConnectedMcpClients, +})); + +const { getMcpToolMetadata } = await import('./mcpToolMetadata'); + +function makeConnectedClient(serverId = 'server-1') { + return { + serverId, + serverName: 'Linear', + sanitizedName: 'linear', + serverUrl: 'https://linear.example/mcp', + transport: { + close: vi.fn().mockResolvedValue(undefined), + }, + }; +} + +beforeEach(() => { + createMCPClient.mockReset(); + getConnectedMcpClients.mockReset(); + loggerWarn.mockReset(); +}); + +describe('getMcpToolMetadata', () => { + test('returns sanitized tool summaries for connected servers', async () => { + const connectedClient = makeConnectedClient(); + const mcpClient = { + listTools: vi.fn().mockResolvedValue({ + tools: [ + { + name: 'lookup', + title: 'Lookup', + description: 'Find issues\nquickly', + annotations: { + readOnlyHint: true, + destructiveHint: false, + idempotentHint: true, + unknownHint: true, + }, + inputSchema: { type: 'object' }, + }, + ], + }), + close: vi.fn().mockResolvedValue(undefined), + }; + getConnectedMcpClients.mockResolvedValue([connectedClient]); + createMCPClient.mockResolvedValue(mcpClient); + + const result = await getMcpToolMetadata({} as PrismaClient, 'user-1', 1); + + expect(result).toEqual([ + { + status: 'available', + serverId: 'server-1', + tools: [ + { + name: 'lookup', + title: 'Lookup', + description: 'Find alert(1) issues quickly', + annotations: { + readOnlyHint: true, + destructiveHint: false, + idempotentHint: true, + }, + }, + ], + }, + ]); + expect(mcpClient.close).toHaveBeenCalledTimes(1); + expect(connectedClient.transport.close).toHaveBeenCalledTimes(1); + }); + + test('truncates very large tool lists', async () => { + const connectedClient = makeConnectedClient(); + const tools = Array.from({ length: 201 }, (_, index) => ({ + name: `tool-${index}`, + description: 'x'.repeat(600), + inputSchema: { type: 'object' }, + })); + const mcpClient = { + listTools: vi.fn().mockResolvedValue({ tools }), + close: vi.fn().mockResolvedValue(undefined), + }; + getConnectedMcpClients.mockResolvedValue([connectedClient]); + createMCPClient.mockResolvedValue(mcpClient); + + const result = await getMcpToolMetadata({} as PrismaClient, 'user-1', 1); + const entry = result[0]; + + expect(entry.status).toBe('available'); + if (entry.status === 'available') { + expect(entry.tools).toHaveLength(200); + expect(entry.truncated).toBe(true); + expect(entry.tools[0].description).toHaveLength(500); + } + }); + + test('does not mark the list truncated when only text fields are shortened', async () => { + const connectedClient = makeConnectedClient(); + const mcpClient = { + listTools: vi.fn().mockResolvedValue({ + tools: [ + { + name: 'tool', + description: 'x'.repeat(600), + inputSchema: { type: 'object' }, + }, + ], + }), + close: vi.fn().mockResolvedValue(undefined), + }; + getConnectedMcpClients.mockResolvedValue([connectedClient]); + createMCPClient.mockResolvedValue(mcpClient); + + const result = await getMcpToolMetadata({} as PrismaClient, 'user-1', 1); + const entry = result[0]; + + expect(entry.status).toBe('available'); + if (entry.status === 'available') { + expect(entry.truncated).toBeUndefined(); + expect(entry.tools[0].description).toHaveLength(500); + } + }); + + test('maps safe auth failures without throwing the whole response', async () => { + const connectedClient = makeConnectedClient(); + getConnectedMcpClients.mockResolvedValue([connectedClient]); + createMCPClient.mockRejectedValue(Object.assign(new Error('unauthorized'), { statusCode: 401 })); + + const result = await getMcpToolMetadata({} as PrismaClient, 'user-1', 1); + + expect(result).toEqual([ + { + status: 'error', + serverId: 'server-1', + reason: 'auth_failed', + }, + ]); + expect(connectedClient.transport.close).toHaveBeenCalledTimes(1); + expect(loggerWarn).toHaveBeenCalled(); + }); +}); diff --git a/packages/web/src/ee/features/chat/mcp/mcpToolMetadata.ts b/packages/web/src/ee/features/chat/mcp/mcpToolMetadata.ts new file mode 100644 index 000000000..3a43298ba --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/mcpToolMetadata.ts @@ -0,0 +1,269 @@ +import { createMCPClient, type MCPClient } from '@ai-sdk/mcp'; +import { createLogger, env } from '@sourcebot/shared'; +import type { PrismaClient } from '@sourcebot/db'; +import { getConnectedMcpClients, type McpToolSet } from './mcpClientFactory'; +import { getExternalMcpErrorLogFields } from './externalMcpError'; +import type { + GetMcpToolsResponse, + ServerToolsEntry, + ToolMetadataErrorReason, + ToolSummary, +} from './types'; + +const logger = createLogger('mcp-tool-metadata'); + +const MCP_TOOL_METADATA_FETCH_CONCURRENCY = 4; +const MCP_TOOL_METADATA_TIMEOUT_MS = Math.min(env.SOURCEBOT_MCP_TOOL_CALL_TIMEOUT_MS, 10000); +const MCP_TOOL_METADATA_MAX_TOOLS = 200; +const MCP_TOOL_METADATA_MAX_NAME_LENGTH = 128; +const MCP_TOOL_METADATA_MAX_TITLE_LENGTH = 160; +const MCP_TOOL_METADATA_MAX_DESCRIPTION_LENGTH = 500; + +class ToolMetadataTimeoutError extends Error { + constructor() { + super(`MCP tool metadata fetch timed out after ${MCP_TOOL_METADATA_TIMEOUT_MS}ms`); + this.name = 'ToolMetadataTimeoutError'; + } +} + +type ListToolsResult = Awaited>; +type ToolDefinition = ListToolsResult['tools'][number]; + +function removeControlCharacters(value: string): string { + return value.replace(/[\u0000-\u001F\u007F-\u009F]/g, ' '); +} + +function removeHtmlTags(value: string): string { + return value.replace(/<[^>]*>/g, ''); +} + +function normalizeWhitespace(value: string): string { + return value.replace(/\s+/g, ' ').trim(); +} + +function sanitizeText(value: unknown, maxLength: number): string | undefined { + if (typeof value !== 'string') { + return undefined; + } + + const sanitized = normalizeWhitespace(removeControlCharacters(removeHtmlTags(value))); + if (!sanitized) { + return undefined; + } + + return sanitized.length > maxLength ? sanitized.slice(0, maxLength).trimEnd() : sanitized; +} + +function sanitizeAnnotations(tool: ToolDefinition): ToolSummary['annotations'] { + const annotations = tool.annotations; + if (!annotations) { + return undefined; + } + + const sanitized: ToolSummary['annotations'] = {}; + if (typeof annotations.readOnlyHint === 'boolean') { + sanitized.readOnlyHint = annotations.readOnlyHint; + } + if (typeof annotations.destructiveHint === 'boolean') { + sanitized.destructiveHint = annotations.destructiveHint; + } + if (typeof annotations.idempotentHint === 'boolean') { + sanitized.idempotentHint = annotations.idempotentHint; + } + + return Object.keys(sanitized).length > 0 ? sanitized : undefined; +} + +function sanitizeTool(tool: ToolDefinition): ToolSummary { + const toolWithOptionalTitle = tool as ToolDefinition & { + title?: unknown; + annotations?: ToolDefinition['annotations'] & { title?: unknown }; + }; + const name = sanitizeText(tool.name, MCP_TOOL_METADATA_MAX_NAME_LENGTH) ?? 'unnamed_tool'; + const title = sanitizeText( + toolWithOptionalTitle.title ?? toolWithOptionalTitle.annotations?.title, + MCP_TOOL_METADATA_MAX_TITLE_LENGTH, + ); + const description = sanitizeText(tool.description, MCP_TOOL_METADATA_MAX_DESCRIPTION_LENGTH); + const annotations = sanitizeAnnotations(tool); + + return { + name, + ...(title ? { title } : {}), + ...(description ? { description } : {}), + ...(annotations ? { annotations } : {}), + }; +} + +async function withTimeout( + promise: Promise, + onTimeout: () => Promise, + onLateResolve?: (value: T) => Promise, +): Promise { + promise.catch(() => undefined); + + return new Promise((resolve, reject) => { + let didTimeout = false; + const timeoutId = setTimeout(() => { + didTimeout = true; + onTimeout().catch(() => undefined); + reject(new ToolMetadataTimeoutError()); + }, MCP_TOOL_METADATA_TIMEOUT_MS); + + promise.then( + (value) => { + if (didTimeout) { + onLateResolve?.(value).catch(() => undefined); + return; + } + clearTimeout(timeoutId); + resolve(value); + }, + (error) => { + if (didTimeout) { + return; + } + clearTimeout(timeoutId); + reject(error); + }, + ); + }); +} + +function getToolMetadataErrorReason(error: unknown): ToolMetadataErrorReason { + if (error instanceof ToolMetadataTimeoutError) { + return 'timeout'; + } + + const fields = getExternalMcpErrorLogFields(error); + if ( + fields.oauthError === 'invalid_grant' || + fields.oauthError === 'invalid_client' || + fields.oauthError === 'unauthorized_client' || + fields.statusCode === 401 || + fields.statusCode === 403 + ) { + return 'auth_failed'; + } + + if ( + fields.reason === 'dynamic_client_registration_unsupported' || + fields.reason === 'unsupported_grant_type' || + fields.reason === 'unsupported_response_type' || + fields.reason === 'unsupported_code_challenge_method' || + fields.statusCode === 404 || + fields.statusCode === 405 + ) { + return 'unsupported'; + } + + const message = error instanceof Error ? error.message.toLowerCase() : ''; + if (message.includes('does not support tools') || message.includes('does not support http transport')) { + return 'unsupported'; + } + + if (fields.statusCode || fields.errorClass === 'TypeError') { + return 'connection_failed'; + } + + return 'unknown'; +} + +async function cleanupMcpClient(mcpClient: MCPClient | undefined, { transport }: McpToolSet) { + // Timeout handlers close the transport immediately to interrupt the in-flight request. + // This final cleanup may close it again; transports are expected to tolerate that. + await Promise.allSettled([ + mcpClient?.close(), + transport.close(), + ]); +} + +async function fetchToolsForClient(client: McpToolSet): Promise { + let mcpClient: MCPClient | undefined; + + try { + mcpClient = await withTimeout( + createMCPClient({ transport: client.transport }), + async () => { + await client.transport.close(); + }, + async (lateClient) => { + await lateClient.close(); + }, + ); + + const result = await withTimeout( + mcpClient.listTools(), + async () => { + await client.transport.close(); + }, + ); + + const tools = result.tools.slice(0, MCP_TOOL_METADATA_MAX_TOOLS).map(sanitizeTool); + const nextCursor = (result as ListToolsResult & { nextCursor?: unknown }).nextCursor; + const truncated = result.tools.length > MCP_TOOL_METADATA_MAX_TOOLS || typeof nextCursor === 'string'; + + return { + status: 'available', + serverId: client.serverId, + tools, + ...(truncated ? { truncated } : {}), + }; + } catch (error) { + const reason = getToolMetadataErrorReason(error); + logger.warn('Failed to load MCP tool metadata.', { + serverId: client.serverId, + sanitizedName: client.sanitizedName, + reason, + error: getExternalMcpErrorLogFields(error), + }); + + return { + status: 'error', + serverId: client.serverId, + reason, + }; + } finally { + await cleanupMcpClient(mcpClient, client); + } +} + +async function fetchToolsBatch(clients: McpToolSet[]): Promise { + const settled = await Promise.allSettled(clients.map((client) => fetchToolsForClient(client))); + return settled.map((result, index) => { + if (result.status === 'fulfilled') { + return result.value; + } + + // Defensive: fetchToolsForClient should catch per-server failures and resolve. + const client = clients[index]; + logger.warn('Failed to load MCP tool metadata.', { + serverId: client.serverId, + sanitizedName: client.sanitizedName, + reason: 'unknown' satisfies ToolMetadataErrorReason, + error: getExternalMcpErrorLogFields(result.reason), + }); + + return { + status: 'error', + serverId: client.serverId, + reason: 'unknown', + }; + }); +} + +export async function getMcpToolMetadata( + prisma: PrismaClient, + userId: string, + orgId: number, +): Promise { + const clients = await getConnectedMcpClients(prisma, userId, orgId); + const results: ServerToolsEntry[] = []; + + for (let index = 0; index < clients.length; index += MCP_TOOL_METADATA_FETCH_CONCURRENCY) { + const batch = clients.slice(index, index + MCP_TOOL_METADATA_FETCH_CONCURRENCY); + results.push(...await fetchToolsBatch(batch)); + } + + return results; +} diff --git a/packages/web/src/ee/features/chat/mcp/mcpToolRegistry.test.ts b/packages/web/src/ee/features/chat/mcp/mcpToolRegistry.test.ts new file mode 100644 index 000000000..20918f066 --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/mcpToolRegistry.test.ts @@ -0,0 +1,185 @@ +import { expect, test, describe } from 'vitest'; +import { buildMcpToolRegistry, searchMcpTools, McpToolRegistryEntry } from './mcpToolRegistry'; + +// Helper to create a mock tool record matching the MCPClient['tools'] return type. +function createToolRecord(tools: Record) { + const record: Record = {}; + for (const [name, tool] of Object.entries(tools)) { + record[name] = { + description: tool.description, + execute: tool.execute ?? (() => {}), + inputSchema: {}, + }; + } + // eslint-disable-next-line @typescript-eslint/no-explicit-any + return record as any; +} + +describe('buildMcpToolRegistry', () => { + test('extracts serverName from namespaced tool name', () => { + const tools = createToolRecord({ + 'mcp_linear__list_issues': { description: 'List issues' }, + }); + + const registry = buildMcpToolRegistry(tools); + + expect(registry).toEqual([ + { name: 'mcp_linear__list_issues', description: 'List issues', serverName: 'linear' }, + ]); + }); + + test('handles underscores in server name', () => { + const tools = createToolRecord({ + 'mcp_my_server__get_data': { description: 'Get data' }, + }); + + const registry = buildMcpToolRegistry(tools); + + expect(registry[0].serverName).toBe('my_server'); + }); + + test('defaults missing description to empty string', () => { + const tools = createToolRecord({ + 'mcp_linear__list_issues': { description: undefined }, + }); + + const registry = buildMcpToolRegistry(tools); + + expect(registry[0].description).toBe(''); + }); + + test('non-matching tool name yields empty serverName', () => { + const tools = createToolRecord({ + 'some_random_tool': { description: 'A tool' }, + }); + + const registry = buildMcpToolRegistry(tools); + + expect(registry[0].serverName).toBe(''); + }); + + test('empty tools record returns empty array', () => { + const registry = buildMcpToolRegistry(createToolRecord({})); + + expect(registry).toEqual([]); + }); +}); + +describe('searchMcpTools', () => { + // Shared registry for most tests. + const registry: McpToolRegistryEntry[] = [ + { name: 'mcp_linear__list_issues', description: 'List all issues in a project', serverName: 'linear' }, + { name: 'mcp_linear__create_issue', description: 'Create a new issue', serverName: 'linear' }, + { name: 'mcp_linear__update_issue', description: 'Update an existing issue', serverName: 'linear' }, + { name: 'mcp_github__search_repos', description: 'Search repositories on GitHub', serverName: 'github' }, + { name: 'mcp_pg__run_query', description: 'Run a database query', serverName: 'pg' }, + { name: 'mcp_slack__send_message', description: 'Send a message to a Slack channel', serverName: 'slack' }, + { name: 'mcp_jira__create_ticket', description: 'Create a new Jira ticket', serverName: 'jira' }, + ]; + + test('exact name match returns single result', () => { + const results = searchMcpTools('mcp_linear__list_issues', registry); + + expect(results).toEqual([ + { name: 'mcp_linear__list_issues', description: 'List all issues in a project', serverName: 'linear' }, + ]); + }); + + test('token matching on tool name', () => { + const results = searchMcpTools('list issues', registry); + + expect(results.length).toBeGreaterThan(0); + expect(results[0].name).toBe('mcp_linear__list_issues'); + }); + + test('synonym expansion: "find" matches tools with "list"', () => { + const results = searchMcpTools('find issues', registry); + + expect(results.length).toBeGreaterThan(0); + const names = results.map(r => r.name); + expect(names).toContain('mcp_linear__list_issues'); + }); + + test('synonym expansion: "add" matches tools with "create"', () => { + const results = searchMcpTools('add ticket', registry); + + expect(results.length).toBeGreaterThan(0); + const names = results.map(r => r.name); + expect(names).toContain('mcp_jira__create_ticket'); + }); + + test('reverse expansion: canonical "list" expands to synonyms', () => { + // "list" is canonical and expands to "find", "get", "fetch", "search", etc. + const results = searchMcpTools('list repos', registry); + + expect(results.length).toBeGreaterThan(0); + const names = results.map(r => r.name); + // "search_repos" should match because "list" expands to "search" + expect(names).toContain('mcp_github__search_repos'); + }); + + test('higher-scoring entries come first', () => { + // "create issue" should score higher for create_issue than for list_issues + const results = searchMcpTools('create issue', registry); + + expect(results.length).toBeGreaterThan(1); + // The first result should be the one that matches both tokens + expect(results[0].name).toBe('mcp_linear__create_issue'); + }); + + test('topK limits results', () => { + const results = searchMcpTools('issue', registry, 2); + + expect(results.length).toBeLessThanOrEqual(2); + }); + + test('default topK is 5', () => { + // All 7 entries match "mcp" as a substring, but we need tokens > 2 chars + // Use a query that matches many entries + const largeRegistry: McpToolRegistryEntry[] = Array.from({ length: 10 }, (_, i) => ({ + name: `mcp_server__tool_${i}`, + description: `Tool number ${i} for testing`, + serverName: 'server', + })); + + const results = searchMcpTools('tool testing', largeRegistry); + + expect(results.length).toBeLessThanOrEqual(5); + }); + + test('short/empty query fallback returns first topK entries', () => { + // "do it" — all tokens are <= 2 chars after filtering + const results = searchMcpTools('do it', registry); + + expect(results).toEqual(registry.slice(0, 5)); + }); + + test('empty string query fallback returns first topK entries', () => { + const results = searchMcpTools('', registry); + + expect(results).toEqual(registry.slice(0, 5)); + }); + + test('returns empty array when no tokens match', () => { + const results = searchMcpTools('xyznonexistent', registry); + + expect(results).toEqual([]); + }); + + test('search matches in description, not just name', () => { + const results = searchMcpTools('database', registry); + + expect(results.length).toBeGreaterThan(0); + expect(results[0].name).toBe('mcp_pg__run_query'); + }); + + test('tokens shorter than 3 chars are filtered out', () => { + // "do a list" → only "list" survives (length > 2) + const results = searchMcpTools('do a list', registry); + + expect(results.length).toBeGreaterThan(0); + // Should still find results via the "list" token + const names = results.map(r => r.name); + expect(names).toContain('mcp_linear__list_issues'); + }); +}); diff --git a/packages/web/src/ee/features/chat/mcp/mcpToolRegistry.ts b/packages/web/src/ee/features/chat/mcp/mcpToolRegistry.ts new file mode 100644 index 000000000..431710e9e --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/mcpToolRegistry.ts @@ -0,0 +1,99 @@ +import type { MCPClient } from '@ai-sdk/mcp'; + +export interface McpToolRegistryEntry { + name: string; + description: string; + serverName: string; +} + +type McpToolRecord = Awaited>; + +// Synonym map for common action words. Expands query tokens so that e.g. +// "find tickets" matches a tool named "list_issues". +// Module-level constant — built once at server startup, never re-created. +const SYNONYM_MAP: Record = { + list: ['find', 'get', 'fetch', 'retrieve', 'search', 'show', 'query', 'read'], + create: ['make', 'add', 'post', 'open', 'new', 'submit', 'write'], + update: ['edit', 'modify', 'change', 'patch', 'set'], + delete: ['remove', 'destroy', 'archive', 'close'], + send: ['post', 'publish', 'notify', 'message'], + issue: ['ticket', 'bug', 'task', 'item', 'work'], + comment: ['note', 'reply', 'respond'], + user: ['member', 'person', 'assignee'], + project: ['repo', 'repository', 'workspace'], +}; + +// Reverse lookup: synonym → canonical token. Built once from SYNONYM_MAP. +const REVERSE_SYNONYMS: Record = {}; +for (const [canonical, synonyms] of Object.entries(SYNONYM_MAP)) { + for (const synonym of synonyms) { + REVERSE_SYNONYMS[synonym] = canonical; + } +} + +function expandTokens(tokens: string[]): string[] { + const expanded = new Set(tokens); + for (const token of tokens) { + const canonical = REVERSE_SYNONYMS[token]; + if (canonical) { + expanded.add(canonical); + } + const synonyms = SYNONYM_MAP[token]; + if (synonyms) { + for (const s of synonyms) { + expanded.add(s); + } + } + } + return Array.from(expanded); +} + +export function buildMcpToolRegistry(tools: McpToolRecord): McpToolRegistryEntry[] { + return Object.entries(tools).map(([name, tool]) => { + const match = name.match(/^mcp_(.+?)__/); + const serverName = match ? match[1] : ''; + return { + name, + description: tool.description ?? '', + serverName, + }; + }); +} + +export function searchMcpTools( + query: string, + registry: McpToolRegistryEntry[], + topK = 5, +): McpToolRegistryEntry[] { + // Fast path: if the query is an exact tool name, return it directly. + const exactMatch = registry.find(e => e.name === query); + if (exactMatch) { + return [exactMatch]; + } + + const rawTokens = query + .toLowerCase() + .split(/\W+/) + .filter(t => t.length > 2); + + // If no meaningful tokens remain (e.g. query is "do it" — all tokens <= 2 chars), + // fall back to returning the first topK tools rather than returning nothing. + // We could potentially return nothing or return another tool that will help search better + // in the future. + if (rawTokens.length === 0) { + return registry.slice(0, topK); + } + + const tokens = expandTokens(rawTokens); + + return registry + .map(entry => { + const haystack = `${entry.name} ${entry.description}`.toLowerCase(); + const score = tokens.filter(t => haystack.includes(t)).length; + return { entry, score }; + }) + .filter(({ score }) => score > 0) + .sort((a, b) => b.score - a.score) + .slice(0, topK) + .map(({ entry }) => entry); +} \ No newline at end of file diff --git a/packages/web/src/ee/features/chat/mcp/mcpToolSets.test.ts b/packages/web/src/ee/features/chat/mcp/mcpToolSets.test.ts new file mode 100644 index 000000000..6205ca844 --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/mcpToolSets.test.ts @@ -0,0 +1,471 @@ +import { expect, test, describe, vi, beforeEach } from 'vitest'; +import { Prisma } from '@sourcebot/db'; +import type { McpToolSet } from './mcpClientFactory'; + +// --- Mocks --- + +const mockCreateMCPClient = vi.fn(); +const mockLogger = vi.hoisted(() => ({ + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + debug: vi.fn(), +})); +const mockToolCallCountUpsert = vi.hoisted(() => vi.fn()); +const mockToolCallCountUpdate = vi.hoisted(() => vi.fn()); +const mockCaptureEvent = vi.hoisted(() => vi.fn()); + +vi.mock('@ai-sdk/mcp', () => ({ + createMCPClient: (...args: unknown[]) => mockCreateMCPClient(...args), +})); + +vi.mock('@sourcebot/shared', () => ({ + createLogger: () => mockLogger, + env: { + SOURCEBOT_MCP_TOOL_CALL_TIMEOUT_MS: 5000, + }, +})); + +vi.mock('@/prisma', () => ({ + __unsafePrisma: { + mcpServerToolCallCount: { + upsert: mockToolCallCountUpsert, + update: mockToolCallCountUpdate, + }, + }, +})); + +vi.mock('@/lib/posthog', () => ({ + captureEvent: mockCaptureEvent, +})); + +vi.mock('ai', () => ({ + jsonSchema: vi.fn((schema: unknown, opts: unknown) => ({ schema, ...(opts as object) })), +})); + +// --- Helpers --- + +interface MockToolDef { + name: string; + description?: string; + inputSchema?: Record; + annotations?: Record; +} + +function createMockMcpClient(toolDefs: MockToolDef[]) { + const toolRecord: Record; description: string | undefined; inputSchema: unknown }> = {}; + for (const def of toolDefs) { + toolRecord[def.name] = { + execute: vi.fn().mockResolvedValue({ content: [{ type: 'text', text: 'result' }] }), + description: def.description, + inputSchema: def.inputSchema ?? {}, + }; + } + + return { + listTools: vi.fn().mockResolvedValue({ tools: toolDefs }), + toolsFromDefinitions: vi.fn().mockReturnValue(toolRecord), + close: vi.fn().mockResolvedValue(undefined), + tools: vi.fn().mockResolvedValue(toolRecord), + }; +} + +function createMockClient(overrides: Partial & { serverName: string }): McpToolSet { + return { + serverId: 'server-id', + sanitizedName: overrides.serverName.toLowerCase(), + serverUrl: `https://${overrides.serverName.toLowerCase()}.example.com/mcp`, + transport: {} as McpToolSet['transport'], + ...overrides, + }; +} + +// --- Tests --- + +// Import after mocks are set up +const { getMcpTools } = await import('./mcpToolSets'); + +beforeEach(() => { + vi.clearAllMocks(); + mockToolCallCountUpsert.mockResolvedValue({}); + mockToolCallCountUpdate.mockResolvedValue({}); + mockCaptureEvent.mockResolvedValue(undefined); +}); + +describe('getMcpTools', () => { + test('single server with single tool produces correctly namespaced key', async () => { + const mockClient = createMockMcpClient([ + { name: 'list_issues', description: 'List issues' }, + ]); + mockCreateMCPClient.mockResolvedValue(mockClient); + + const result = await getMcpTools([ + createMockClient({ serverName: 'Linear' }), + ]); + + expect(Object.keys(result.tools)).toEqual(['mcp_linear__list_issues']); + expect(result.failedServers).toEqual([]); + }); + + test('multiple servers produce tools with distinct prefixes', async () => { + const linearClient = createMockMcpClient([ + { name: 'list_issues', description: 'List issues' }, + ]); + const githubClient = createMockMcpClient([ + { name: 'search_repos', description: 'Search repos' }, + ]); + + mockCreateMCPClient + .mockResolvedValueOnce(linearClient) + .mockResolvedValueOnce(githubClient); + + const result = await getMcpTools([ + createMockClient({ serverName: 'Linear' }), + createMockClient({ serverName: 'GitHub' }), + ]); + + const toolNames = Object.keys(result.tools); + expect(toolNames).toContain('mcp_linear__list_issues'); + expect(toolNames).toContain('mcp_github__search_repos'); + }); + + test('read-only tool does NOT get needsApproval', async () => { + const mockClient = createMockMcpClient([ + { name: 'list_issues', description: 'List issues', annotations: { readOnlyHint: true } }, + ]); + mockCreateMCPClient.mockResolvedValue(mockClient); + + const result = await getMcpTools([ + createMockClient({ serverName: 'Linear' }), + ]); + + const tool = result.tools['mcp_linear__list_issues']; + expect(tool).toBeDefined(); + expect('needsApproval' in tool).toBe(false); + }); + + test('non-read-only tool gets needsApproval: true', async () => { + const mockClient = createMockMcpClient([ + { name: 'create_issue', description: 'Create issue' }, + ]); + mockCreateMCPClient.mockResolvedValue(mockClient); + + const result = await getMcpTools([ + createMockClient({ serverName: 'Linear' }), + ]); + + const tool = result.tools['mcp_linear__create_issue']; + expect(tool).toBeDefined(); + expect(tool).toHaveProperty('needsApproval', true); + }); + + test('failed server connection adds to failedServers array', async () => { + const error = new Error('Connection refused client_secret=client-secret access_token=access-token'); + Object.assign(error, { + response: { + status: 502, + body: 'client_secret=client-secret access_token=access-token', + }, + }); + mockCreateMCPClient.mockRejectedValue(error); + + const result = await getMcpTools([ + createMockClient({ serverName: 'BrokenServer' }), + ]); + + expect(result.failedServers).toEqual(['BrokenServer']); + expect(Object.keys(result.tools)).toEqual([]); + expect(mockLogger.error).toHaveBeenCalledWith('Failed to get tools from MCP server.', { + serverId: 'server-id', + sanitizedName: 'brokenserver', + error: { + errorClass: 'Error', + statusCode: 502, + }, + }); + expect(JSON.stringify(mockLogger.error.mock.calls)).not.toContain('client-secret'); + expect(JSON.stringify(mockLogger.error.mock.calls)).not.toContain('access-token'); + }); + + test('failed server does not prevent other servers from working', async () => { + const goodClient = createMockMcpClient([ + { name: 'list_issues', description: 'List issues' }, + ]); + + mockCreateMCPClient + .mockRejectedValueOnce(new Error('Connection refused')) + .mockResolvedValueOnce(goodClient); + + const result = await getMcpTools([ + createMockClient({ serverName: 'BrokenServer' }), + createMockClient({ serverName: 'Linear' }), + ]); + + expect(result.failedServers).toEqual(['BrokenServer']); + expect(Object.keys(result.tools)).toEqual(['mcp_linear__list_issues']); + }); + + test('generates favicon URL from server URL origin', async () => { + const mockClient = createMockMcpClient([ + { name: 'tool', description: 'A tool' }, + ]); + mockCreateMCPClient.mockResolvedValue(mockClient); + + const result = await getMcpTools([ + createMockClient({ serverName: 'Linear', serverUrl: 'https://api.linear.app/mcp' }), + ]); + + expect(result.serverFaviconUrls['linear']).toBe( + 'https://www.google.com/s2/favicons?domain=https://api.linear.app&sz=32' + ); + }); + + test('cleanup function calls close on all clients', async () => { + const client1 = createMockMcpClient([{ name: 'tool1', description: 'Tool 1' }]); + const client2 = createMockMcpClient([{ name: 'tool2', description: 'Tool 2' }]); + + mockCreateMCPClient + .mockResolvedValueOnce(client1) + .mockResolvedValueOnce(client2); + + const result = await getMcpTools([ + createMockClient({ serverName: 'Server1' }), + createMockClient({ serverName: 'Server2' }), + ]); + + await result.cleanup(); + + expect(client1.close).toHaveBeenCalledOnce(); + expect(client2.close).toHaveBeenCalledOnce(); + }); + + test('cleanup handles errors in close gracefully', async () => { + const client1 = createMockMcpClient([{ name: 'tool1', description: 'Tool 1' }]); + const client2 = createMockMcpClient([{ name: 'tool2', description: 'Tool 2' }]); + client1.close.mockRejectedValue(new Error('Close failed')); + + mockCreateMCPClient + .mockResolvedValueOnce(client1) + .mockResolvedValueOnce(client2); + + const result = await getMcpTools([ + createMockClient({ serverName: 'Server1' }), + createMockClient({ serverName: 'Server2' }), + ]); + + // Should not throw + await expect(result.cleanup()).resolves.toBeUndefined(); + expect(client2.close).toHaveBeenCalledOnce(); + }); + + test('empty clients array returns empty result', async () => { + const result = await getMcpTools([]); + + expect(result.tools).toEqual({}); + expect(result.failedServers).toEqual([]); + expect(result.serverFaviconUrls).toEqual({}); + expect(typeof result.cleanup).toBe('function'); + }); + + test('tool schema validation rejects invalid input', async () => { + const mockClient = createMockMcpClient([ + { + name: 'create_issue', + description: 'Create issue', + inputSchema: { + type: 'object', + properties: { title: { type: 'string' } }, + }, + }, + ]); + mockCreateMCPClient.mockResolvedValue(mockClient); + + const result = await getMcpTools([ + createMockClient({ serverName: 'Linear' }), + ]); + + const tool = result.tools['mcp_linear__create_issue']; + // The inputSchema should have a validate function from our jsonSchema mock + const schema = tool.inputSchema as { validate?: (value: unknown) => Promise<{ success: boolean; error?: Error }> }; + expect(schema.validate).toBeDefined(); + + if (schema.validate) { + // Valid input + const validResult = await schema.validate({ title: 'My Issue' }); + expect(validResult.success).toBe(true); + + // Invalid input (extra property not allowed because additionalProperties: false) + const invalidResult = await schema.validate({ title: 'My Issue', bogus: 'field' }); + expect(invalidResult.success).toBe(false); + } + }); + + test('tool execute wrapper propagates non-timeout errors', async () => { + const originalError = new Error('External API failed'); + const mockClient = createMockMcpClient([ + { name: 'create_issue', description: 'Create issue' }, + ]); + // Override the execute to reject + const toolRecord = mockClient.toolsFromDefinitions(); + toolRecord['create_issue'].execute.mockRejectedValue(originalError); + + mockCreateMCPClient.mockResolvedValue(mockClient); + + const result = await getMcpTools([ + createMockClient({ serverName: 'Linear' }), + ]); + + const tool = result.tools['mcp_linear__create_issue']; + await expect( + tool.execute({}, { messages: [], toolCallId: 'test' }) + ).rejects.toThrow('External API failed'); + expect(mockToolCallCountUpsert).not.toHaveBeenCalled(); + expect(mockToolCallCountUpdate).not.toHaveBeenCalled(); + expect(mockCaptureEvent).toHaveBeenCalledWith('ask_mcp_tool_call_completed', expect.objectContaining({ + serverName: 'Linear', + serverUrl: 'https://linear.example.com/mcp', + toolName: 'create_issue', + qualifiedToolName: 'mcp_linear__create_issue', + success: false, + failureReason: 'Error', + })); + }); + + test('tool execute wrapper increments the raw tool call counter after success', async () => { + const mockClient = createMockMcpClient([ + { name: 'create_issue', description: 'Create issue' }, + ]); + mockCreateMCPClient.mockResolvedValue(mockClient); + + const result = await getMcpTools([ + createMockClient({ serverId: 'server-linear', serverName: 'Linear' }), + ]); + + const tool = result.tools['mcp_linear__create_issue']; + await expect( + tool.execute({ title: 'My Issue' }, { messages: [], toolCallId: 'test' }) + ).resolves.toEqual({ content: [{ type: 'text', text: 'result' }] }); + + expect(mockToolCallCountUpsert).toHaveBeenCalledWith({ + where: { + mcpServerId_toolName: { + mcpServerId: 'server-linear', + toolName: 'create_issue', + }, + }, + create: { + mcpServerId: 'server-linear', + toolName: 'create_issue', + count: 1, + }, + update: { + count: { increment: 1 }, + }, + }); + expect(mockToolCallCountUpdate).not.toHaveBeenCalled(); + expect(mockCaptureEvent).toHaveBeenCalledWith('ask_mcp_tool_call_completed', expect.objectContaining({ + source: 'sourcebot-ask-agent', + serverId: 'server-linear', + serverName: 'Linear', + serverUrl: 'https://linear.example.com/mcp', + toolName: 'create_issue', + qualifiedToolName: 'mcp_linear__create_issue', + success: true, + })); + }); + + test('tool execute wrapper includes analytics context in tool completion events', async () => { + const mockClient = createMockMcpClient([ + { name: 'create_issue', description: 'Create issue' }, + ]); + mockCreateMCPClient.mockResolvedValue(mockClient); + + const result = await getMcpTools([ + createMockClient({ serverId: 'server-linear', serverName: 'Linear' }), + ], { + chatId: 'chat-id', + traceId: 'trace-id', + source: 'sourcebot-ask-agent', + }); + + const tool = result.tools['mcp_linear__create_issue']; + await tool.execute({ title: 'My Issue' }, { messages: [], toolCallId: 'test' }); + + expect(mockCaptureEvent).toHaveBeenCalledWith('ask_mcp_tool_call_completed', expect.objectContaining({ + chatId: 'chat-id', + traceId: 'trace-id', + source: 'sourcebot-ask-agent', + })); + }); + + test('tool execute wrapper waits for the counter increment before resolving', async () => { + let resolveCounter: (() => void) | undefined; + mockToolCallCountUpsert.mockImplementationOnce(() => new Promise((resolve) => { + resolveCounter = resolve; + })); + + const mockClient = createMockMcpClient([ + { name: 'list_issues', description: 'List issues', annotations: { readOnlyHint: true } }, + ]); + mockCreateMCPClient.mockResolvedValue(mockClient); + + const result = await getMcpTools([ + createMockClient({ serverId: 'server-linear', serverName: 'Linear' }), + ]); + + const tool = result.tools['mcp_linear__list_issues']; + const execution = tool.execute({}, { messages: [], toolCallId: 'test' }); + let didResolve = false; + const observedExecution = execution.then((value) => { + didResolve = true; + return value; + }); + + await vi.waitFor(() => { + expect(mockToolCallCountUpsert).toHaveBeenCalledTimes(1); + }); + await Promise.resolve(); + + expect(resolveCounter).toBeDefined(); + expect(didResolve).toBe(false); + + resolveCounter?.(); + + await expect(observedExecution).resolves.toEqual({ content: [{ type: 'text', text: 'result' }] }); + expect(didResolve).toBe(true); + }); + + test('tool execute wrapper retries with an atomic update after a unique conflict', async () => { + const uniqueConflict = new Prisma.PrismaClientKnownRequestError('Unique constraint failed', { + code: 'P2002', + clientVersion: '0', + }); + mockToolCallCountUpsert.mockRejectedValueOnce(uniqueConflict); + + const mockClient = createMockMcpClient([ + { name: 'create_issue', description: 'Create issue' }, + ]); + mockCreateMCPClient.mockResolvedValue(mockClient); + + const result = await getMcpTools([ + createMockClient({ serverId: 'server-linear', serverName: 'Linear' }), + ]); + + const tool = result.tools['mcp_linear__create_issue']; + await expect( + tool.execute({ title: 'My Issue' }, { messages: [], toolCallId: 'test' }) + ).resolves.toEqual({ content: [{ type: 'text', text: 'result' }] }); + + expect(mockToolCallCountUpdate).toHaveBeenCalledWith({ + where: { + mcpServerId_toolName: { + mcpServerId: 'server-linear', + toolName: 'create_issue', + }, + }, + data: { + count: { increment: 1 }, + }, + }); + }); +}); diff --git a/packages/web/src/ee/features/chat/mcp/mcpToolSets.ts b/packages/web/src/ee/features/chat/mcp/mcpToolSets.ts new file mode 100644 index 000000000..4e249247b --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/mcpToolSets.ts @@ -0,0 +1,259 @@ +import { createMCPClient, type MCPClient } from '@ai-sdk/mcp'; +import { McpToolSet } from './mcpClientFactory'; +import { createLogger, env } from '@sourcebot/shared'; +import Ajv from 'ajv'; +import { jsonSchema, ToolExecutionOptions } from 'ai'; +import type { JSONSchema7, JSONSchema7Definition } from 'json-schema'; +import { getExternalMcpErrorLogFields } from './externalMcpError'; +import { getMcpFaviconUrl } from './utils'; +import { __unsafePrisma } from '@/prisma'; +import { Prisma } from '@sourcebot/db'; +import { captureEvent } from '@/lib/posthog'; +import type { AskMcpAnalyticsSource } from '@/lib/posthogEvents'; + +const logger = createLogger('mcp-tool-sets'); +const ajv = new Ajv({ allErrors: true, strict: false }); + +class McpToolTimeoutError extends Error { + constructor(toolName: string, timeoutMs: number) { + super(`MCP tool "${toolName}" timed out after ${timeoutMs}ms`); + this.name = 'McpToolTimeoutError'; + } +} + +async function incrementMcpToolCallCounter(serverId: string, toolName: string) { + try { + await __unsafePrisma.mcpServerToolCallCount.upsert({ + where: { + mcpServerId_toolName: { + mcpServerId: serverId, + toolName, + }, + }, + create: { + mcpServerId: serverId, + toolName, + count: 1, + }, + update: { + count: { increment: 1 }, + }, + }); + } catch (error) { + if (!(error instanceof Prisma.PrismaClientKnownRequestError) || error.code !== 'P2002') { + throw error; + } + + await __unsafePrisma.mcpServerToolCallCount.update({ + where: { + mcpServerId_toolName: { + mcpServerId: serverId, + toolName, + }, + }, + data: { + count: { increment: 1 }, + }, + }); + } +} + +export interface McpToolsResult { + tools: Record>[string]>; + failedServers: string[]; + serverFaviconUrls: Record; + cleanup: () => Promise; +} + +interface McpToolsAnalyticsContext { + chatId?: string; + traceId?: string; + source: AskMcpAnalyticsSource; +} + +function getMcpToolFailureReason(error: unknown): string { + if (error instanceof McpToolTimeoutError) { + return 'timeout'; + } + + const fields = getExternalMcpErrorLogFields(error); + if (fields.reason) { + return fields.reason; + } + if (fields.oauthError) { + return fields.oauthError; + } + if (fields.statusCode) { + return `status_${fields.statusCode}`; + } + if (fields.errorClass) { + return fields.errorClass; + } + + return 'unknown'; +} + +/** + * Creates MCPClients from authenticated transports, retrieves their tools, + * and returns a namespaced tool record + cleanup function. + */ +export async function getMcpTools(clients: McpToolSet[], analyticsContext?: McpToolsAnalyticsContext): Promise { + const allTools: McpToolsResult['tools'] = {}; + const failedServers: string[] = []; + const serverFaviconUrls: Record = {}; + const mcpClients: MCPClient[] = []; + + const connectionTimeoutMs = env.SOURCEBOT_MCP_TOOL_CALL_TIMEOUT_MS; + + for (const { serverId, serverName, sanitizedName, serverUrl, transport } of clients) { + try { + const mcpClient = await Promise.race([ + createMCPClient({ transport }), + new Promise((_, reject) => + setTimeout(() => reject(new Error(`Connection to MCP server "${serverName}" timed out after ${connectionTimeoutMs}ms`)), connectionTimeoutMs) + ), + ]); + mcpClients.push(mcpClient); + + const toolDefinitions = await Promise.race([ + mcpClient.listTools(), + new Promise((_, reject) => + setTimeout(() => reject(new Error(`Listing tools from MCP server "${serverName}" timed out after ${connectionTimeoutMs}ms`)), connectionTimeoutMs) + ), + ]); + const tools = mcpClient.toolsFromDefinitions(toolDefinitions); + const prefix = `mcp_${sanitizedName}`; + + for (const [toolName, tool] of Object.entries(tools)) { + const def = toolDefinitions.tools.find(t => t.name === toolName); + const isReadOnly = (def?.annotations as Record | undefined)?.readOnlyHint === true; + + // The @ai-sdk/mcp library sets additionalProperties: false in the JSON schema + // sent to the model, but does NOT provide a validate function — so the AI SDK + // skips server-side validation entirely. We compile the schema with ajv to + // enforce parameter names at runtime, which allows experimental_repairToolCall + // to fire on InvalidToolInputError. + const rawSchema = def?.inputSchema ?? { type: 'object', properties: {} }; + const schema = { + ...rawSchema, + type: 'object' as const, + properties: (rawSchema.properties ?? {}) as Record, + additionalProperties: false, + } satisfies JSONSchema7; + const validate = ajv.compile(schema); + const validProperties = Object.keys(schema.properties); + const validatedInputSchema = jsonSchema(schema, { + validate: async (value: unknown) => { + if (validate(value)) { + return { success: true as const, value }; + } + return { + success: false as const, + error: new Error( + `${ajv.errorsText(validate.errors)}. The valid parameter names for this tool are: [${validProperties.join(', ')}]` + ), + }; + }, + }); + + const originalExecute = tool.execute; + const qualifiedName = `${prefix}__${toolName}`; + const timeoutMs = env.SOURCEBOT_MCP_TOOL_CALL_TIMEOUT_MS; + + const executeWithTimeout = (async (input: unknown, options: ToolExecutionOptions) => { + const startTime = Date.now(); + const timeoutSignal = AbortSignal.timeout(timeoutMs); + const combinedSignal = options.abortSignal + ? AbortSignal.any([options.abortSignal, timeoutSignal]) + : timeoutSignal; + let success = false; + let failureReason: string | undefined; + + try { + const result = await originalExecute(input, { + ...options, + abortSignal: combinedSignal, + }); + + // Await the analytics write before returning the tool result so a later + // denied approval cannot end the turn before earlier reads are counted. + await incrementMcpToolCallCounter(serverId, toolName).catch((error) => { + logger.warn('Failed to increment MCP tool call counter', { + serverId, + toolName: qualifiedName, + error: error instanceof Error ? error.message : String(error), + }); + }); + + success = true; + return result; + } catch (error) { + if (timeoutSignal.aborted) { + logger.warn(`MCP tool "${qualifiedName}" timed out after ${timeoutMs}ms`); + const timeoutError = new McpToolTimeoutError(qualifiedName, timeoutMs); + failureReason = getMcpToolFailureReason(timeoutError); + throw timeoutError; + } + failureReason = getMcpToolFailureReason(error); + throw error; + } finally { + void captureEvent('ask_mcp_tool_call_completed', { + chatId: analyticsContext?.chatId, + traceId: analyticsContext?.traceId, + source: analyticsContext?.source ?? 'sourcebot-ask-agent', + serverId, + serverName, + serverUrl, + toolName, + qualifiedToolName: qualifiedName, + success, + durationMs: Date.now() - startTime, + ...(failureReason ? { failureReason } : {}), + }); + } + }) as typeof originalExecute; + + allTools[qualifiedName] = { + ...tool, + execute: executeWithTimeout, + // The @ai-sdk/mcp package bundles its own copy of @ai-sdk/provider-utils, + // so its Schema isn't structurally identical to the workspace copy. + // The runtime shape is the same; cast through `any` to bridge the duplicate + // type identity (the two FlexibleSchema types differ only by their internal + // schemaSymbol brand). + // eslint-disable-next-line @typescript-eslint/no-explicit-any + inputSchema: validatedInputSchema as any, + ...(isReadOnly ? {} : { needsApproval: true }), + }; + } + + const faviconUrl = getMcpFaviconUrl(serverUrl, serverName); + if (faviconUrl) { + serverFaviconUrls[sanitizedName] = faviconUrl; + } + } catch (error) { + logger.error('Failed to get tools from MCP server.', { + serverId, + sanitizedName, + error: getExternalMcpErrorLogFields(error), + }); + failedServers.push(serverName); + } + } + + const cleanup = async () => { + await Promise.allSettled( + mcpClients.map(async (client) => { + try { + await client.close(); + } catch (error) { + logger.error('Error closing MCP client.', { + error: getExternalMcpErrorLogFields(error), + }); + } + }) + ); + }; + + return { tools: allTools, failedServers, serverFaviconUrls, cleanup }; +} diff --git a/packages/web/src/ee/features/chat/mcp/prefabMcpServers.test.ts b/packages/web/src/ee/features/chat/mcp/prefabMcpServers.test.ts new file mode 100644 index 000000000..7ac0face8 --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/prefabMcpServers.test.ts @@ -0,0 +1,55 @@ +import { describe, expect, test } from 'vitest'; +import { + getAvailablePrefabMcpServers, + normalizeMcpServerUrlForComparison, + PREFAB_MCP_SERVERS, +} from './prefabMcpServers'; + +describe('prefab MCP servers', () => { + test('ships the supported prefab servers', () => { + expect(PREFAB_MCP_SERVERS).toEqual([ + { + id: 'atlassian', + name: 'Atlassian', + serverUrl: 'https://mcp.atlassian.com/v1/mcp/authv2', + }, + { + id: 'linear', + name: 'Linear', + serverUrl: 'https://mcp.linear.app/mcp', + }, + { + id: 'posthog', + name: 'PostHog', + serverUrl: 'https://mcp.posthog.com/mcp', + }, + { + id: 'slack', + name: 'Slack', + serverUrl: 'https://mcp.slack.com/mcp', + }, + ]); + }); + + test('keeps prefab servers sorted alphabetically by name', () => { + const sortedNames = PREFAB_MCP_SERVERS.map((server) => server.name).sort((a, b) => a.localeCompare(b)); + + expect(PREFAB_MCP_SERVERS.map((server) => server.name)).toEqual(sortedNames); + }); + + test('hides already configured prefab servers after URL normalization', () => { + const availableServers = getAvailablePrefabMcpServers(['https://mcp.slack.com/mcp/']); + + expect(availableServers.map((server) => server.id)).toEqual(['atlassian', 'linear', 'posthog']); + }); + + test('hides the Atlassian prefab entry when the shared endpoint is configured', () => { + const availableServers = getAvailablePrefabMcpServers(['https://mcp.atlassian.com/v1/mcp/authv2/']); + + expect(availableServers.map((server) => server.id)).toEqual(['linear', 'posthog', 'slack']); + }); + + test('normalizes server URLs for duplicate comparisons', () => { + expect(normalizeMcpServerUrlForComparison(' HTTPS://MCP.SLACK.COM/mcp/#connect ')).toBe('https://mcp.slack.com/mcp'); + }); +}); diff --git a/packages/web/src/ee/features/chat/mcp/prefabMcpServers.ts b/packages/web/src/ee/features/chat/mcp/prefabMcpServers.ts new file mode 100644 index 000000000..8c11e195e --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/prefabMcpServers.ts @@ -0,0 +1,52 @@ +export interface PrefabMcpServer { + id: string; + name: string; + serverUrl: string; +} + +const prefabMcpServers = [ + { + id: "atlassian", + name: "Atlassian", + serverUrl: "https://mcp.atlassian.com/v1/mcp/authv2", + }, + { + id: "linear", + name: "Linear", + serverUrl: "https://mcp.linear.app/mcp", + }, + { + id: "posthog", + name: "PostHog", + serverUrl: "https://mcp.posthog.com/mcp", + }, + { + id: "slack", + name: "Slack", + serverUrl: "https://mcp.slack.com/mcp", + }, +] satisfies PrefabMcpServer[]; + +export const PREFAB_MCP_SERVERS = [...prefabMcpServers].sort((a, b) => a.name.localeCompare(b.name)); + +export function normalizeMcpServerUrlForComparison(serverUrl: string): string { + const trimmedServerUrl = serverUrl.trim(); + + try { + const url = new URL(trimmedServerUrl); + url.hash = ""; + return url.toString().replace(/\/$/, ""); + } catch { + return trimmedServerUrl.toLowerCase().replace(/\/$/, ""); + } +} + +export function getAvailablePrefabMcpServers(configuredServerUrls: string[]): PrefabMcpServer[] { + const configuredServerUrlSet = new Set( + configuredServerUrls.map((serverUrl) => normalizeMcpServerUrlForComparison(serverUrl)), + ); + + return PREFAB_MCP_SERVERS.filter((server) => ( + !configuredServerUrlSet.has(normalizeMcpServerUrlForComparison(server.serverUrl)) + )); +} diff --git a/packages/web/src/ee/features/chat/mcp/queryKeys.test.ts b/packages/web/src/ee/features/chat/mcp/queryKeys.test.ts new file mode 100644 index 000000000..56caa6340 --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/queryKeys.test.ts @@ -0,0 +1,17 @@ +import { describe, expect, test, vi } from 'vitest'; +import type { QueryClient } from '@tanstack/react-query'; +import { invalidateMcpConfigurationQueries, mcpQueryKeys } from './queryKeys'; + +describe('invalidateMcpConfigurationQueries', () => { + test('invalidates admin configuration, account MCP server status, and tool metadata', async () => { + const queryClient = { + invalidateQueries: vi.fn().mockResolvedValue(undefined), + } as unknown as QueryClient; + + await invalidateMcpConfigurationQueries(queryClient); + + expect(queryClient.invalidateQueries).toHaveBeenCalledWith({ queryKey: mcpQueryKeys.configuration }); + expect(queryClient.invalidateQueries).toHaveBeenCalledWith({ queryKey: mcpQueryKeys.serversWithStatus }); + expect(queryClient.invalidateQueries).toHaveBeenCalledWith({ queryKey: mcpQueryKeys.tools }); + }); +}); diff --git a/packages/web/src/ee/features/chat/mcp/queryKeys.ts b/packages/web/src/ee/features/chat/mcp/queryKeys.ts new file mode 100644 index 000000000..40575dae1 --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/queryKeys.ts @@ -0,0 +1,15 @@ +import type { QueryClient } from '@tanstack/react-query'; + +export const mcpQueryKeys = { + serversWithStatus: ['mcpServersWithStatus'] as const, + configuration: ['mcpConfiguration'] as const, + tools: ['mcpTools'] as const, +}; + +export async function invalidateMcpConfigurationQueries(queryClient: QueryClient) { + await Promise.all([ + queryClient.invalidateQueries({ queryKey: mcpQueryKeys.configuration }), + queryClient.invalidateQueries({ queryKey: mcpQueryKeys.serversWithStatus }), + queryClient.invalidateQueries({ queryKey: mcpQueryKeys.tools }), + ]); +} diff --git a/packages/web/src/ee/features/chat/mcp/types.ts b/packages/web/src/ee/features/chat/mcp/types.ts new file mode 100644 index 000000000..0d1c099ae --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/types.ts @@ -0,0 +1,53 @@ +export interface McpConfigurationServer { + id: string; + name: string; + serverUrl: string; + sanitizedName: string; + faviconUrl: string | undefined; + savedConnectionCount: number; + toolUsage: McpServerToolUsageSummary; +} + +export type McpConfigurationAllowedMode = 'approved_only'; + +export interface McpToolUsageEntry { + toolName: string; + totalCalls: number; + usageSharePercent: number; +} + +export interface McpServerToolUsageSummary { + totalCalls: number; + usedToolCount: number; + tools: McpToolUsageEntry[]; +} + +export interface GetMcpConfigurationResponse { + servers: McpConfigurationServer[]; + allowedMode: McpConfigurationAllowedMode; + isOAuthAvailable: boolean; +} + +export interface ToolSummary { + name: string; + title?: string; + description?: string; + annotations?: { + readOnlyHint?: boolean; + destructiveHint?: boolean; + idempotentHint?: boolean; + }; +} + +export type ToolMetadataErrorReason = + | 'timeout' + | 'auth_failed' + | 'connection_failed' + | 'unsupported' + | 'unknown'; + +export type ServerToolsEntry = + | { status: 'available'; serverId: string; tools: ToolSummary[]; truncated?: boolean } + | { status: 'error'; serverId: string; reason: ToolMetadataErrorReason }; + +export type GetMcpToolsResponse = ServerToolsEntry[]; diff --git a/packages/web/src/ee/features/chat/mcp/utils.test.ts b/packages/web/src/ee/features/chat/mcp/utils.test.ts new file mode 100644 index 000000000..d3c887fc7 --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/utils.test.ts @@ -0,0 +1,50 @@ +import { expect, test, describe } from 'vitest'; +import { getMcpFaviconUrl, sanitizeMcpServerName } from './utils'; + +describe('sanitizeMcpServerName', () => { + test('lowercases ASCII letters', () => { + expect(sanitizeMcpServerName('MyServer')).toBe('myserver'); + }); + + test('replaces special characters with underscores', () => { + expect(sanitizeMcpServerName('My Server!')).toBe('my_server_'); + }); + + test('preserves digits', () => { + expect(sanitizeMcpServerName('server123')).toBe('server123'); + }); + + test('replaces spaces and hyphens', () => { + expect(sanitizeMcpServerName('my-cool server')).toBe('my_cool_server'); + }); + + test('handles empty string', () => { + expect(sanitizeMcpServerName('')).toBe(''); + }); + + test('replaces unicode characters with underscores', () => { + expect(sanitizeMcpServerName('Ñoño')).toBe('_o_o'); + }); + + test('replaces all special characters', () => { + expect(sanitizeMcpServerName('@#$%')).toBe('____'); + }); + + test('returns already sanitized name unchanged', () => { + expect(sanitizeMcpServerName('linear')).toBe('linear'); + }); +}); + +describe('getMcpFaviconUrl', () => { + test('returns a Google favicon URL for a valid server URL', () => { + expect(getMcpFaviconUrl('https://mcp.linear.app/mcp')).toBe('https://www.google.com/s2/favicons?domain=https://mcp.linear.app&sz=32'); + }); + + test('returns a local Atlassian icon for the Atlassian prefab server', () => { + expect(getMcpFaviconUrl('https://mcp.atlassian.com/v1/mcp/authv2', 'Atlassian')).toMatch(/^data:image\/svg\+xml,/); + }); + + test('returns undefined for a malformed server URL', () => { + expect(getMcpFaviconUrl('not a url')).toBeUndefined(); + }); +}); diff --git a/packages/web/src/ee/features/chat/mcp/utils.ts b/packages/web/src/ee/features/chat/mcp/utils.ts new file mode 100644 index 000000000..5d1453cbb --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/utils.ts @@ -0,0 +1,81 @@ +/** + * Sanitizes an MCP server name into a lowercase alphanumeric string suitable + * for use as a tool-name prefix (e.g. "My Server!" → "my_server_"). + * + * This is used to namespace MCP tools (mcp_{sanitizedName}__{toolName}) and + * to key favicon maps. Must be kept consistent everywhere — collisions on + * this value are prevented at server-creation time. + */ +export function sanitizeMcpServerName(name: string): string { + return name.toLowerCase().replace(/[^a-z0-9]/g, '_'); +} + +export function pluralize(count: number, singular: string, plural = `${singular}s`) { + return count === 1 ? singular : plural; +} + +const standardNumberFormatter = new Intl.NumberFormat(); +const compactNumberFormatter = new Intl.NumberFormat(undefined, { + notation: "compact", + maximumFractionDigits: 1, +}); + +export function formatCount(count: number) { + if (count >= 10_000) { + return compactNumberFormatter.format(count); + } + return standardNumberFormatter.format(count); +} + +export function formatUsageSharePercent(percent: number) { + if (percent <= 0) { + return "0%"; + } + if (percent < 1) { + return "<1%"; + } + if (percent < 10) { + return `${percent.toFixed(1).replace(/\.0$/, "")}%`; + } + return `${Math.round(percent)}%`; +} + +function createMcpIconDataUri(svg: string): string { + return `data:image/svg+xml,${encodeURIComponent(svg)}`; +} + +const atlassianIconSvg = ` + + + + + + + + + + + + + +`; + +const knownMcpFaviconUrlsBySanitizedName: Record = { + atlassian: createMcpIconDataUri(atlassianIconSvg), +}; + +export function getMcpFaviconUrl(serverUrl: string, serverName?: string): string | undefined { + if (serverName) { + const knownFaviconUrl = knownMcpFaviconUrlsBySanitizedName[sanitizeMcpServerName(serverName)]; + if (knownFaviconUrl) { + return knownFaviconUrl; + } + } + + try { + const origin = new URL(serverUrl).origin; + return `https://www.google.com/s2/favicons?domain=${origin}&sz=32`; + } catch { + return undefined; + } +} diff --git a/packages/web/src/features/chat/agent.test.ts b/packages/web/src/features/chat/agent.test.ts new file mode 100644 index 000000000..e7984e655 --- /dev/null +++ b/packages/web/src/features/chat/agent.test.ts @@ -0,0 +1,266 @@ +import { beforeEach, describe, expect, test, vi } from 'vitest'; +import type { ModelMessage } from 'ai'; +import type { SBChatMessage, SBChatMessagePart } from './types'; + +const mockLogger = vi.hoisted(() => ({ + debug: vi.fn(), + error: vi.fn(), + info: vi.fn(), + warn: vi.fn(), +})); + +const mockAi = vi.hoisted(() => ({ + convertToModelMessages: vi.fn(), + createUIMessageStream: vi.fn(), + latestCreateUIMessageStreamOptions: undefined as undefined | { + execute: (args: { writer: unknown }) => Promise | void; + }, + streamText: vi.fn(), +})); + +vi.mock('@sourcebot/shared', () => ({ + createLogger: () => mockLogger, + env: { + SOURCEBOT_CHAT_FILE_MAX_CHARACTERS: 4000, + SOURCEBOT_CHAT_MAX_STEP_COUNT: 8, + SOURCEBOT_CHAT_MODEL_TEMPERATURE: 0, + SOURCEBOT_TELEMETRY_PII_COLLECTION_ENABLED: 'false', + }, + getDBConnectionString: () => 'postgresql://sourcebot:sourcebot@db.example.com:5432/sourcebot', +})); + +vi.mock('server-only', () => ({})); + +vi.mock('@/ee/features/chat/mcp/mcpClientFactory', () => ({ + getConnectedMcpClients: vi.fn(), +})); + +vi.mock('@/ee/features/chat/mcp/mcpToolRegistry', () => ({ + buildMcpToolRegistry: vi.fn(() => []), + searchMcpTools: vi.fn(() => []), +})); + +vi.mock('@/ee/features/chat/mcp/mcpToolSets', () => ({ + getMcpTools: vi.fn(), +})); + +vi.mock('@/features/git', () => ({ + getFileSource: vi.fn(), +})); + +vi.mock('@/features/tools', () => { + const createToolDefinition = (name: string) => ({ + name, + title: name, + description: `${name} description`, + inputSchema: {}, + isReadOnly: true, + isIdempotent: true, + execute: vi.fn(), + }); + + return { + findSymbolDefinitionsDefinition: createToolDefinition('find_symbol_definitions'), + findSymbolReferencesDefinition: createToolDefinition('find_symbol_references'), + getDiffDefinition: createToolDefinition('get_diff'), + globDefinition: createToolDefinition('glob'), + grepDefinition: createToolDefinition('grep'), + listCommitsDefinition: createToolDefinition('list_commits'), + listReposDefinition: createToolDefinition('list_repos'), + listTreeDefinition: createToolDefinition('list_tree'), + readFileDefinition: createToolDefinition('read_file'), + toVercelAITool: vi.fn((definition: { name: string }) => ({ + name: definition.name, + })), + }; +}); + +vi.mock('@/lib/entitlements', () => ({ + hasEntitlement: vi.fn(() => false), +})); + +vi.mock('@/lib/posthog', () => ({ + captureEvent: vi.fn(), +})); + +vi.mock('ai', async (importOriginal) => { + const actual = await importOriginal(); + return { + ...actual, + convertToModelMessages: mockAi.convertToModelMessages, + createUIMessageStream: mockAi.createUIMessageStream, + streamText: mockAi.streamText, + }; +}); + +const { createMessageStream } = await import('./agent'); + +const listReposInput = { + sort: 'name', + page: 1, + perPage: 30, + direction: 'asc', +} as const; + +const dynamicApprovalRespondedPart = { + type: 'dynamic-tool', + toolName: 'mcp_linear__save_issue', + toolCallId: 'tool-call-1', + state: 'approval-responded', + input: { title: 'Issue' }, + approval: { id: 'approval-1', approved: true }, +} satisfies SBChatMessagePart; + +const staticApprovalRespondedPart = { + type: 'tool-list_repos', + toolCallId: 'tool-call-2', + state: 'approval-responded', + input: listReposInput, + approval: { id: 'approval-2', approved: true }, +} satisfies SBChatMessagePart; + +const createUserMessage = (): SBChatMessage => ({ + id: 'user-message', + role: 'user', + parts: [ + { + type: 'text', + text: 'Create an issue', + }, + ], +}); + +const createAssistantMessage = (parts: SBChatMessagePart[]): SBChatMessage => ({ + id: 'assistant-message', + role: 'assistant', + parts, +}); + +const createFakeStreamResult = () => ({ + response: Promise.resolve(new Response()), + totalUsage: Promise.resolve({ + inputTokens: 1, + outputTokens: 1, + totalTokens: 2, + }), + toUIMessageStream: vi.fn((options?: { onFinish?: () => Promise | void }) => { + void options?.onFinish?.(); + return {}; + }), +}); + +const runCreateMessageStream = async (messages: SBChatMessage[]) => { + const convertedLastTurn: ModelMessage = { + role: 'assistant', + content: 'converted-last-turn', + }; + mockAi.convertToModelMessages.mockResolvedValue([convertedLastTurn]); + mockAi.streamText.mockReturnValue(createFakeStreamResult()); + + const props = { + chatId: 'chat-id', + messages, + selectedRepos: [], + prisma: {}, + model: {}, + modelName: 'test-model', + onFinish: vi.fn(), + onError: () => 'error', + } as unknown as Parameters[0]; + + await createMessageStream(props); + + const execute = mockAi.latestCreateUIMessageStreamOptions?.execute; + if (!execute) { + throw new Error('Expected createUIMessageStream to capture execute callback.'); + } + + await execute({ + writer: { + merge: vi.fn(), + write: vi.fn(), + }, + }); + + const streamTextArgs = mockAi.streamText.mock.calls.at(-1)?.[0]; + if (!streamTextArgs || typeof streamTextArgs !== 'object' || !('messages' in streamTextArgs)) { + throw new Error('Expected streamText to be called with messages.'); + } + + return streamTextArgs.messages as ModelMessage[]; +}; + +beforeEach(() => { + vi.clearAllMocks(); + mockAi.latestCreateUIMessageStreamOptions = undefined; + mockAi.createUIMessageStream.mockImplementation((options: typeof mockAi.latestCreateUIMessageStreamOptions) => { + mockAi.latestCreateUIMessageStreamOptions = options; + return {}; + }); +}); + +describe('createMessageStream approval continuation', () => { + test.each([ + ['dynamic', dynamicApprovalRespondedPart], + ['static', staticApprovalRespondedPart], + ])('preserves the full last turn for %s approval responses', async (_kind, approvalPart) => { + const assistantMessage = createAssistantMessage([ + { + type: 'step-start', + }, + { + type: 'text', + text: 'I have everything I need. Let me now create the issue.', + }, + approvalPart, + ]); + + const streamTextMessages = await runCreateMessageStream([ + createUserMessage(), + assistantMessage, + ]); + + expect(mockAi.convertToModelMessages).toHaveBeenCalledWith( + [assistantMessage], + { ignoreIncompleteToolCalls: true } + ); + expect(streamTextMessages).toEqual([ + { + role: 'user', + content: 'Create an issue', + }, + { + role: 'assistant', + content: 'converted-last-turn', + }, + ]); + }); + + test('does not treat untagged latest approval-continuation text as a prior assistant answer', async () => { + const assistantMessage = createAssistantMessage([ + { + type: 'step-start', + }, + { + type: 'text', + text: 'I have everything I need. Let me now create the Linear issue!', + }, + dynamicApprovalRespondedPart, + ]); + + const streamTextMessages = await runCreateMessageStream([ + createUserMessage(), + assistantMessage, + ]); + + expect(streamTextMessages).not.toContainEqual({ + role: 'assistant', + content: [ + { + type: 'text', + text: 'I have everything I need. Let me now create the Linear issue!', + }, + ], + }); + }); +}); diff --git a/packages/web/src/features/chat/agent.ts b/packages/web/src/features/chat/agent.ts index 0efb706fc..f4bd96854 100644 --- a/packages/web/src/features/chat/agent.ts +++ b/packages/web/src/features/chat/agent.ts @@ -1,22 +1,31 @@ import { SBChatMessage, SBChatMessageMetadata } from "@/features/chat/types"; -import { getAnswerPartFromAssistantMessage } from "@/features/chat/utils"; import { getFileSource } from '@/features/git'; import { isServiceError } from "@/lib/utils"; import { LanguageModelV3 as AISDKLanguageModelV3 } from "@ai-sdk/provider"; import { ProviderOptions } from "@ai-sdk/provider-utils"; +import type { PrismaClient } from "@sourcebot/db"; import { createLogger, env } from "@sourcebot/shared"; import { + convertToModelMessages, createUIMessageStream, JSONValue, LanguageModel, ModelMessage, StopCondition, streamText, StreamTextResult, UIMessageStreamOnFinishCallback, UIMessageStreamOptions, - UIMessageStreamWriter + UIMessageStreamWriter, + tool, + Tool, + NoSuchToolError, } from "ai"; +import { z } from "zod"; import { randomUUID } from "crypto"; import _dedent from "dedent"; import { ANSWER_TAG, FILE_REFERENCE_PREFIX } from "./constants"; import { Source } from "./types"; -import { addLineNumbers, fileReferenceToString } from "./utils"; +import { addLineNumbers, fileReferenceToString, getAnswerPartFromAssistantMessage, getTurnProgressState } from "./utils"; import { createTools } from "./tools"; +import { getConnectedMcpClients } from "@/ee/features/chat/mcp/mcpClientFactory"; +import { getMcpTools, McpToolsResult } from "@/ee/features/chat/mcp/mcpToolSets"; +import { buildMcpToolRegistry, McpToolRegistryEntry, searchMcpTools } from "@/ee/features/chat/mcp/mcpToolRegistry"; +import { hasEntitlement } from '@/lib/entitlements'; const dedent = _dedent.withOptions({ alignValues: true }); @@ -36,6 +45,10 @@ interface CreateMessageStreamResponseProps { chatId: string; messages: SBChatMessage[]; selectedRepos: string[]; + prisma: PrismaClient; + // When undefined, MCP tools are disabled entirely (e.g. programmatic callers like askCodebase). + // When an array, MCP tools are enabled for all servers not in the list. + disabledMcpServerIds?: string[]; model: AISDKLanguageModelV3; modelName: string; onFinish: UIMessageStreamOnFinishCallback; @@ -43,6 +56,8 @@ interface CreateMessageStreamResponseProps { modelProviderOptions?: Record>; modelTemperature?: number; metadata?: Partial; + userId?: string; + orgId?: number; } export const createMessageStream = async ({ @@ -50,12 +65,16 @@ export const createMessageStream = async ({ messages, metadata, selectedRepos, + prisma, + disabledMcpServerIds, model, modelName, modelProviderOptions, modelTemperature, onFinish, onError, + userId, + orgId, }: CreateMessageStreamResponseProps) => { const latestMessage = messages[messages.length - 1]; const sources = latestMessage.parts @@ -66,8 +85,10 @@ export const createMessageStream = async ({ // Extract user messages and assistant answers. // We will use this as the context we carry between messages. - const messageHistory = - messages.map((message): ModelMessage | undefined => { + // Server requests always receive persisted messages between client streams, so evaluate them in the ready state. + const incomingTurnProgress = getTurnProgressState({ messages, status: 'ready' }); + let messageHistory: ModelMessage[] = + messages.map((message, index): ModelMessage | undefined => { if (message.role === 'user') { return { role: 'user', @@ -76,7 +97,10 @@ export const createMessageStream = async ({ } if (message.role === 'assistant') { - const answerPart = getAnswerPartFromAssistantMessage(message, false); + const isLatestIncompleteAssistantMessage = + index === messages.length - 1 && + incomingTurnProgress.isTurnInProgress; + const answerPart = getAnswerPartFromAssistantMessage(message, isLatestIncompleteAssistantMessage); if (answerPart) { return { role: 'assistant', @@ -86,6 +110,29 @@ export const createMessageStream = async ({ } }).filter(message => message !== undefined); + // When the last assistant turn has approval responses (from the tool approval flow), + // the turn is incomplete — it has no answer text, only a pending tool call that was + // approved. We need to preserve the full tool call + approval so streamText can + // execute the approved tool and continue. + const lastMsg = messages[messages.length - 1]; + const hasApprovalContinuationReady = + lastMsg?.role === 'assistant' && + incomingTurnProgress.hasApprovalContinuationReady; + + // When continuing after tool approval, capture the prior turn's metadata + // so we can aggregate token counts and response times across phases. + const priorMetadata = hasApprovalContinuationReady + ? (lastMsg.metadata as SBChatMessageMetadata | undefined) + : undefined; + + if (hasApprovalContinuationReady) { + const fullLastTurn = await convertToModelMessages( + [lastMsg], + { ignoreIncompleteToolCalls: true } + ); + messageHistory = [...messageHistory, ...fullLastTurn]; + } + const stream = createUIMessageStream({ execute: async ({ writer }) => { writer.write({ @@ -101,17 +148,34 @@ export const createMessageStream = async ({ inputMessages: messageHistory, inputSources: sources, selectedRepos, + disabledMcpServerIds, onWriteSource: (source) => { writer.write({ type: 'data-source', data: source, }); }, + onMcpServerDiscovered: (sanitizedName, faviconUrl) => { + writer.write({ + type: 'data-mcp-server', + data: { sanitizedName, faviconUrl }, + }); + }, + onMcpServerFailed: (serverName) => { + writer.write({ + type: 'data-mcp-failed-server', + data: { serverName }, + }); + }, traceId, chatId, + prisma, + userId, + orgId, }); await mergeStreamAsync(researchStream, writer, { + originalMessages: messages, sendReasoning: true, sendStart: false, sendFinish: false, @@ -122,10 +186,10 @@ export const createMessageStream = async ({ writer.write({ type: 'message-metadata', messageMetadata: { - totalTokens: totalUsage.totalTokens, - totalInputTokens: totalUsage.inputTokens, - totalOutputTokens: totalUsage.outputTokens, - totalResponseTimeMs: new Date().getTime() - startTime.getTime(), + totalTokens: (priorMetadata?.totalTokens ?? 0) + (totalUsage.totalTokens ?? 0), + totalInputTokens: (priorMetadata?.totalInputTokens ?? 0) + (totalUsage.inputTokens ?? 0), + totalOutputTokens: (priorMetadata?.totalOutputTokens ?? 0) + (totalUsage.outputTokens ?? 0), + totalResponseTimeMs: (priorMetadata?.totalResponseTimeMs ?? 0) + (new Date().getTime() - startTime.getTime()), modelName, traceId, ...metadata, @@ -149,11 +213,17 @@ interface AgentOptions { providerOptions?: ProviderOptions; temperature?: number; selectedRepos: string[]; + disabledMcpServerIds?: string[]; inputMessages: ModelMessage[]; inputSources: Source[]; onWriteSource: (source: Source) => void; + onMcpServerDiscovered: (sanitizedName: string, faviconUrl: string) => void; + onMcpServerFailed: (serverName: string) => void; traceId: string; chatId: string; + prisma: PrismaClient; + userId?: string; + orgId?: number; } const createAgentStream = async ({ @@ -163,9 +233,15 @@ const createAgentStream = async ({ inputMessages, inputSources, selectedRepos, + disabledMcpServerIds, onWriteSource, + onMcpServerDiscovered, + onMcpServerFailed, traceId, chatId, + prisma, + userId, + orgId, }: AgentOptions) => { // For every file source, resolve the source code so that we can include it in the system prompt. const fileSources = inputSources.filter((source) => source.type === 'file'); @@ -192,48 +268,166 @@ const createAgentStream = async ({ })) ).filter((source) => source !== undefined); + let mcpToolSetsObj: McpToolsResult = { tools: {}, failedServers: [], serverFaviconUrls: {}, cleanup: async () => {} }; + if (userId && orgId && await hasEntitlement('oauth') && disabledMcpServerIds !== undefined) { + try { + const allMcpClients = await getConnectedMcpClients(prisma, userId, orgId); + const mcpClients = allMcpClients.filter((c) => !disabledMcpServerIds.includes(c.serverId)); + mcpToolSetsObj = await getMcpTools(mcpClients, { + chatId, + traceId, + source: 'sourcebot-ask-agent', + }); + + for (const [sanitizedName, faviconUrl] of Object.entries(mcpToolSetsObj.serverFaviconUrls)) { + onMcpServerDiscovered(sanitizedName, faviconUrl); + } + + if (mcpClients.length > 0) { + logger.info(`Connected to ${mcpClients.length} external MCP server(s): ${mcpClients.map(c => c.serverName).join(', ')}`); + } + } catch (error) { + logger.error('Failed to connect external MCP servers:', error); + } + } + + for (const serverName of mcpToolSetsObj.failedServers) { + onMcpServerFailed(serverName); + } + + const mcpRegistry = buildMcpToolRegistry(mcpToolSetsObj.tools); + const hasMcpTools = mcpRegistry.length > 0; + + const toolRequestActivation = tool({ + description: dedent` + Activate an MCP tool by name so it becomes callable on your next step. + You MUST pass an exact tool name from the tool registry in the system prompt. + Do NOT pass natural language descriptions or sentences. + If you need multiple tools, call this once per tool. + + Examples: + CORRECT: tool_to_activate_name="mcp_linear__save_comment" + CORRECT: tool_to_activate_name="mcp_linear__create_attachment" + INCORRECT: tool_to_activate_name="create a linear issue and update status" + INCORRECT: tool_to_activate_name="find tools for commenting on issues" + `, + inputSchema: z.object({ + tool_to_activate_name: z.string().describe('Exact tool name from the registry, e.g. "mcp_linear__save_comment"'), + }), + execute: async ({ tool_to_activate_name }) => { + const results = searchMcpTools(tool_to_activate_name, mcpRegistry); + return { + results: results.map(e => ({ name: e.name, description: e.description })), + }; + }, + }); + const systemPrompt = createPrompt({ repos: selectedRepos, files: resolvedFileSources, + mcpToolRegistry: mcpRegistry, }); - const stream = streamText({ - model, - providerOptions, - messages: inputMessages, - system: systemPrompt, - tools: createTools({ source: 'sourcebot-ask-agent', selectedRepos }), - temperature: temperature ?? env.SOURCEBOT_CHAT_MODEL_TEMPERATURE, - stopWhen: [ - stepCountIsGTE(env.SOURCEBOT_CHAT_MAX_STEP_COUNT), - ], - toolChoice: "auto", - onStepFinish: ({ toolResults }) => { - toolResults.forEach(({ output, dynamic }) => { - if (dynamic || isServiceError(output)) { - return; + const builtinTools = createTools({ source: 'sourcebot-ask-agent', selectedRepos }); + const builtinToolNames = Object.keys(builtinTools); + const allTools: Record = { + ...builtinTools, + ...(hasMcpTools ? { tool_request_activation: toolRequestActivation, ...mcpToolSetsObj.tools } : {}), + }; + + try { + const stream = streamText({ + model, + providerOptions, + messages: inputMessages, + system: systemPrompt, + tools: allTools, + activeTools: [ + ...builtinToolNames, + ...(hasMcpTools ? ['tool_request_activation'] : []), + ], + prepareStep: hasMcpTools ? ({ steps }) => { + const activated = new Set(); + for (const step of steps) { + for (const result of step.toolResults) { + if (!result || result.toolName !== 'tool_request_activation') { + continue; + } + const output = result.output as { results?: Array<{ name: string }> }; + for (const { name } of output?.results ?? []) { + if (name in mcpToolSetsObj.tools) { + activated.add(name); + } + } + } + } + return { + activeTools: [ + ...builtinToolNames, + 'tool_request_activation', + ...Array.from(activated), + ], + }; + } : undefined, + temperature: temperature ?? env.SOURCEBOT_CHAT_MODEL_TEMPERATURE, + stopWhen: [ + stepCountIsGTE(env.SOURCEBOT_CHAT_MAX_STEP_COUNT), + ], + toolChoice: "auto", + experimental_repairToolCall: async ({ toolCall, tools, error }) => { + // Fix case mismatches (e.g. model outputs "Mcp_Linear__Save_Comment" instead of "mcp_linear__save_comment") + if (NoSuchToolError.isInstance(error)) { + const lower = toolCall.toolName.toLowerCase(); + if (lower !== toolCall.toolName && lower in tools) { + return { ...toolCall, toolName: lower }; + } } - output.sources?.forEach(onWriteSource); - }); - }, - experimental_telemetry: { - isEnabled: env.SOURCEBOT_TELEMETRY_PII_COLLECTION_ENABLED === 'true', - metadata: { - langfuseTraceId: traceId, + // For anything we can't fix, return null. + // The AI SDK will mark the call as invalid and pass the error + // back to the model so it can retry with correct parameters. + logger.warn(`Tool call repair failed for "${toolCall.toolName}": ${error.message}`); + return null; }, - }, - onError: (error) => { - logger.error(error); - }, - }); + onStepFinish: ({ toolResults }) => { + toolResults.forEach(({ output, dynamic }) => { + if (dynamic || isServiceError(output)) { + return; + } - return stream; + output.sources?.forEach(onWriteSource); + }); + }, + experimental_telemetry: { + isEnabled: env.SOURCEBOT_TELEMETRY_PII_COLLECTION_ENABLED === 'true', + metadata: { + langfuseTraceId: traceId, + }, + }, + onError: (error) => { + logger.error(error); + }, + }); + + // Clean up MCP transport connections once the stream completes (success or failure). + stream.response.then( + () => mcpToolSetsObj.cleanup(), + () => mcpToolSetsObj.cleanup() + ); + return stream; + } catch (error) { + // If anything between MCP setup and stream return throws, ensure we + // still close the MCP transport connections to avoid leaking them. + await mcpToolSetsObj.cleanup(); + throw error; + } } + const createPrompt = ({ files, repos, + mcpToolRegistry, }: { files?: { path: string; @@ -243,6 +437,7 @@ const createPrompt = ({ revision: string; }[], repos: string[], + mcpToolRegistry: McpToolRegistryEntry[], }) => { return dedent` You are a powerful agentic AI code assistant built into Sourcebot, the world's best code-intelligence platform. Your job is to help developers understand and navigate their large codebases. @@ -287,6 +482,18 @@ const createPrompt = ({ `: ''} + ${(mcpToolRegistry.length > 0) ? dedent` + + External MCP tools are available but must first be activated via \`tool_request_activation\`. + + **CRITICAL**: The list below is the complete and authoritative inventory of all tools available to you: + ${mcpToolRegistry.map(e => `- ${e.name}: ${e.description}`).join('\n')} + + **How to use tool_request_activation**: Pass the exact tool name from the list above as the \`tool_to_activate_name\` parameter. Do NOT pass natural language descriptions or sentences. If you need multiple tools, call \`tool_request_activation\` once per tool. + Example: to activate the comment tool, call \`tool_request_activation\` with tool_to_activate_name="mcp_linear__save_comment", NOT tool_to_activate_name="save a comment on an issue". + + ` : ''} + When you have sufficient context, output your answer as a structured markdown response. diff --git a/packages/web/src/features/chat/askMcpAnalytics.server.ts b/packages/web/src/features/chat/askMcpAnalytics.server.ts new file mode 100644 index 000000000..67e1ab6df --- /dev/null +++ b/packages/web/src/features/chat/askMcpAnalytics.server.ts @@ -0,0 +1,136 @@ +import { getStoredMcpConnectionStatus } from "@/ee/features/chat/mcp/connectionStatus"; +import { hasEntitlement } from "@/lib/entitlements"; +import type { PrismaClient } from "@sourcebot/db"; +import type { DynamicToolUIPart } from "ai"; +import type { SBChatMessage, SBChatMessagePart } from "./types"; +import { getTurnProgressState } from "./utils"; + +export type AskMcpAvailabilityAnalytics = { + hasAskMcpServersAvailable: boolean; + askMcpConnectedServerCount: number; + askMcpEnabledServerCount: number; + askMcpDisabledServerCount: number; +}; + +export type AskMcpTurnCompletedAnalytics = { + traceId?: string; + askMcpUsed: boolean; + askMcpToolCallCount: number; + askMcpToolSuccessCount: number; + askMcpToolFailureCount: number; + askMcpApprovalRequestedCount: number; + askMcpApprovalDeniedCount: number; + askMcpFailedServerCount: number; + durationMs: number; +}; + +const emptyAskMcpAvailability: AskMcpAvailabilityAnalytics = { + hasAskMcpServersAvailable: false, + askMcpConnectedServerCount: 0, + askMcpEnabledServerCount: 0, + askMcpDisabledServerCount: 0, +}; + +type AskMcpAvailabilityPrismaClient = Pick; + +export async function getAskMcpAvailabilityAnalytics({ + prisma, + userId, + orgId, + disabledMcpServerIds, +}: { + prisma: AskMcpAvailabilityPrismaClient; + userId: string | undefined; + orgId: number; + disabledMcpServerIds: string[]; +}): Promise { + if (!userId || !(await hasEntitlement("oauth"))) { + return emptyAskMcpAvailability; + } + + const userServers = await prisma.userMcpServer.findMany({ + where: { + userId, + tokens: { not: null }, + server: { + orgId, + clientInfo: { not: null }, + }, + }, + select: { + serverId: true, + tokens: true, + tokensExpiresAt: true, + }, + }); + + const connectedServerIds = userServers + .filter((userServer) => + getStoredMcpConnectionStatus(userServer.tokens, userServer.tokensExpiresAt).state === "connected" + ) + .map((userServer) => userServer.serverId); + const disabledServerIds = new Set(disabledMcpServerIds); + const askMcpDisabledServerCount = connectedServerIds.filter((serverId) => disabledServerIds.has(serverId)).length; + const askMcpEnabledServerCount = connectedServerIds.length - askMcpDisabledServerCount; + + return { + hasAskMcpServersAvailable: askMcpEnabledServerCount > 0, + askMcpConnectedServerCount: connectedServerIds.length, + askMcpEnabledServerCount, + askMcpDisabledServerCount, + }; +} + +function isExternalMcpToolPart(part: SBChatMessagePart): part is SBChatMessagePart & DynamicToolUIPart { + return part.type === "dynamic-tool" && part.toolName.startsWith("mcp_"); +} + +function hasApproval(part: DynamicToolUIPart) { + return part.approval !== undefined; +} + +export function getAskMcpTurnCompletedAnalytics({ + messages, + availability, +}: { + messages: SBChatMessage[]; + availability: AskMcpAvailabilityAnalytics; +}): AskMcpTurnCompletedAnalytics | undefined { + const latestMessage = messages.at(-1); + const latestAssistantMessage = latestMessage?.role === "assistant" ? latestMessage : undefined; + if (!latestAssistantMessage) { + return undefined; + } + + const progressState = getTurnProgressState({ messages, status: "ready" }); + if (progressState.isTurnInProgress) { + return undefined; + } + + const externalMcpToolParts = latestAssistantMessage.parts.filter(isExternalMcpToolPart); + const askMcpToolSuccessCount = externalMcpToolParts.filter((part) => part.state === "output-available").length; + const askMcpToolFailureCount = externalMcpToolParts.filter((part) => part.state === "output-error").length; + const askMcpToolCallCount = askMcpToolSuccessCount + askMcpToolFailureCount; + const askMcpApprovalRequestedCount = externalMcpToolParts.filter(hasApproval).length; + const askMcpApprovalDeniedCount = externalMcpToolParts.filter((part) => part.state === "output-denied").length; + const askMcpFailedServerCount = latestAssistantMessage.parts.filter((part) => + part.type === "data-mcp-failed-server" + ).length; + + const hasMcpTurnActivity = externalMcpToolParts.length > 0 || askMcpFailedServerCount > 0; + if (!availability.hasAskMcpServersAvailable && !hasMcpTurnActivity) { + return undefined; + } + + return { + traceId: latestAssistantMessage.metadata?.traceId, + askMcpUsed: askMcpToolCallCount > 0, + askMcpToolCallCount, + askMcpToolSuccessCount, + askMcpToolFailureCount, + askMcpApprovalRequestedCount, + askMcpApprovalDeniedCount, + askMcpFailedServerCount, + durationMs: latestAssistantMessage.metadata?.totalResponseTimeMs ?? 0, + }; +} diff --git a/packages/web/src/features/chat/components/chatBox/chatBox.tsx b/packages/web/src/features/chat/components/chatBox/chatBox.tsx index 441caa220..51d9a3f45 100644 --- a/packages/web/src/features/chat/components/chatBox/chatBox.tsx +++ b/packages/web/src/features/chat/components/chatBox/chatBox.tsx @@ -34,7 +34,8 @@ interface ChatBoxProps { preferredSuggestionsBoxPlacement?: "top-start" | "bottom-start"; className?: string; isRedirecting?: boolean; - isGenerating?: boolean; + isTurnInProgress?: boolean; + isNetworkActive?: boolean; isDisabled?: boolean; languageModels: LanguageModelInfo[]; selectedSearchScopes: SearchScope[]; @@ -49,7 +50,8 @@ const ChatBoxComponent = ({ preferredSuggestionsBoxPlacement = "bottom-start", className, isRedirecting, - isGenerating, + isTurnInProgress, + isNetworkActive, isDisabled, isLoginWallEnabled, isAuthenticated, @@ -139,7 +141,7 @@ const ChatBoxComponent = ({ } } - if (isGenerating) { + if (isTurnInProgress) { return { isSubmitDisabled: true, isSubmitDisabledReason: "generating", @@ -159,7 +161,7 @@ const ChatBoxComponent = ({ isSubmitDisabledReason: undefined, } - }, [editor.children, isRedirecting, isGenerating, selectedLanguageModel]) + }, [editor.children, isRedirecting, isTurnInProgress, selectedLanguageModel]) const { requiresLogin, @@ -367,7 +369,7 @@ const ChatBoxComponent = ({ className={cn("flex flex-col justify-between gap-0.5 w-full px-3 py-2", className)} > ) : - isGenerating ? ( + isNetworkActive ? ( + + + + + + + e.preventDefault()}> + + + + Connectors + + + {isError && !hasServers ? ( + { + e.preventDefault(); + refetch(); + }} + className="gap-2 text-destructive" + > + + Failed to load. Retry? + + ) : isLoading ? ( + + Loading connectors... + + ) : !hasServers ? ( + + No connectors available + + ) : ( + <> + {connectedServers.map((server) => { + const isEnabled = !server.isAuthExpired && !disabledMcpServerIds.includes(server.id); + return ( + e.preventDefault()} + disabled={server.isAuthExpired} + className="flex items-center justify-between gap-2" + > +
+ {server.isAuthExpired ? ( + + ) : ( + + )} + {server.name} +
+ onToggle(server.id, checked)} + disabled={server.isAuthExpired} + className="scale-75" + /> +
+ ); + })} + {connectedServers.length > 0 && connectableServers.length > 0 && } + {connectableServers.map((server) => ( + { + e.preventDefault(); + void handleConnect(server.id); + }} + disabled={connectingServerId !== null} + className="group flex cursor-pointer items-center justify-between gap-2" + > +
+ + {server.name} +
+ {connectingServerId === server.id ? ( + + ) : ( + + )} +
+ ))} + + )} + + router.push(`/settings/accountAskAgent`)} + > + + My connectors + + router.push(`/settings/workspaceAskAgent`)} + > + + Workspace connectors + +
+
+
+ + ); +}; diff --git a/packages/web/src/features/chat/components/chatBox/chatBoxToolbar.tsx b/packages/web/src/features/chat/components/chatBox/chatBoxToolbar.tsx index a0aae38cf..dc905a768 100644 --- a/packages/web/src/features/chat/components/chatBox/chatBoxToolbar.tsx +++ b/packages/web/src/features/chat/components/chatBox/chatBoxToolbar.tsx @@ -5,6 +5,7 @@ import { LanguageModelInfo, SearchScope } from "@/features/chat/types"; import { RepositoryQuery, SearchContextQuery } from "@/lib/types"; import { useSelectedLanguageModel } from "../../useSelectedLanguageModel"; import { AtMentionButton } from "./atMentionButton"; +import { ChatBoxPlusButton } from "./chatBoxPlusButton"; import { LanguageModelSelector } from "./languageModelSelector"; import { SearchScopeSelector } from "./searchScopeSelector"; @@ -16,6 +17,8 @@ export interface ChatBoxToolbarProps { onSelectedSearchScopesChange: (items: SearchScope[]) => void; isContextSelectorOpen: boolean; onContextSelectorOpenChanged: (isOpen: boolean) => void; + disabledMcpServerIds?: string[]; + onDisabledMcpServerIdsChange?: (ids: string[]) => void; } export const ChatBoxToolbar = ({ @@ -26,6 +29,8 @@ export const ChatBoxToolbar = ({ onSelectedSearchScopesChange, isContextSelectorOpen, onContextSelectorOpenChanged, + disabledMcpServerIds, + onDisabledMcpServerIdsChange, }: ChatBoxToolbarProps) => { const { selectedLanguageModel, setSelectedLanguageModel } = useSelectedLanguageModel({ languageModels, @@ -33,6 +38,17 @@ export const ChatBoxToolbar = ({ return ( <> + {disabledMcpServerIds !== undefined && onDisabledMcpServerIdsChange !== undefined && ( + <> + + + + )} { + return ( +
+
+ +

Extra Features

+
+
+ Add connectors, include files and more. +
+
+ ); +}; diff --git a/packages/web/src/features/chat/components/chatThread/chatThread.tsx b/packages/web/src/features/chat/components/chatThread/chatThread.tsx index 9394e62d0..f1fbb26da 100644 --- a/packages/web/src/features/chat/components/chatThread/chatThread.tsx +++ b/packages/web/src/features/chat/components/chatThread/chatThread.tsx @@ -5,12 +5,12 @@ import { Button } from '@/components/ui/button'; import { Separator } from '@/components/ui/separator'; import { CustomSlateEditor } from '@/features/chat/customSlateEditor'; import { AdditionalChatRequestParams, CustomEditor, LanguageModelInfo, SBChatMessage, SearchScope, Source } from '@/features/chat/types'; -import { createUIMessage, getAllMentionElements, resetEditor, slateContentToString } from '@/features/chat/utils'; +import { createUIMessage, getAllMentionElements, getTurnProgressState, resetEditor, slateContentToString } from '@/features/chat/utils'; import { useChat } from '@ai-sdk/react'; -import { CreateUIMessage, DefaultChatTransport } from 'ai'; +import { CreateUIMessage, DefaultChatTransport, lastAssistantMessageIsCompleteWithApprovalResponses } from 'ai'; import { ArrowDownIcon, CopyIcon } from 'lucide-react'; import { useNavigationGuard } from 'next-navigation-guard'; -import { Fragment, useCallback, useEffect, useRef, useState } from 'react'; +import { Fragment, useCallback, useEffect, useMemo, useRef, useState } from 'react'; import { useStickToBottom } from 'use-stick-to-bottom'; import { Descendant } from 'slate'; import { useMessagePairs } from '../../useMessagePairs'; @@ -19,12 +19,15 @@ import { ChatBox } from '../chatBox'; import { ChatBoxToolbar } from '../chatBox/chatBoxToolbar'; import { ChatThreadListItem } from './chatThreadListItem'; import { ErrorBanner } from './errorBanner'; +import { McpFailedServersBanner } from './mcpFailedServersBanner'; import { useRouter } from 'next/navigation'; import { usePrevious } from '@uidotdev/usehooks'; import { RepositoryQuery, SearchContextQuery } from '@/lib/types'; import { duplicateChat, generateAndUpdateChatNameFromMessage } from '../../actions'; import { isServiceError } from '@/lib/utils'; import { NotConfiguredErrorBanner } from '../notConfiguredErrorBanner'; +import { McpServerIconContext, McpServerIconMap } from '../../mcpServerIconContext'; +import { ToolApprovalProvider } from '../../toolApprovalContext'; import useCaptureEvent from '@/hooks/useCaptureEvent'; import { SignInPromptBanner } from './signInPromptBanner'; import { DuplicateChatDialog } from '@/app/(app)/chat/components/duplicateChatDialog'; @@ -42,6 +45,8 @@ interface ChatThreadProps { searchContexts: SearchContextQuery[]; selectedSearchScopes: SearchScope[]; onSelectedSearchScopesChange: (items: SearchScope[]) => void; + disabledMcpServerIds: string[]; + onDisabledMcpServerIdsChange: (ids: string[]) => void; isOwner?: boolean; isAuthenticated: boolean; isLoginWallEnabled: boolean; @@ -57,6 +62,8 @@ export const ChatThread = ({ searchContexts, selectedSearchScopes, onSelectedSearchScopesChange, + disabledMcpServerIds, + onDisabledMcpServerIdsChange, isOwner = true, isAuthenticated, isLoginWallEnabled, @@ -80,13 +87,69 @@ export const ChatThread = ({ ) ?? [] ); + const [mcpServerIconMap, setMcpServerIconMap] = useState(() => { + const map: McpServerIconMap = {}; + initialMessages?.forEach((message) => { + message.parts + .filter((part) => part.type === 'data-mcp-server') + .forEach((part) => { + map[part.data.sanitizedName] = part.data.faviconUrl; + }); + }); + return map; + }); + + const [failedMcpServers, setFailedMcpServers] = useState(() => { + const names: string[] = []; + initialMessages?.forEach((message) => { + message.parts + .filter((part) => part.type === 'data-mcp-failed-server') + .forEach((part) => { + if (!names.includes(part.data.serverName)) { + names.push(part.data.serverName); + } + }); + }); + return names; + }); + const [isFailedMcpBannerVisible, setIsFailedMcpBannerVisible] = useState(false); + const { selectedLanguageModel } = useSelectedLanguageModel({ languageModels, }); + // Refs to capture the latest request params for the transport body. + // The transport is created once (useMemo) but params change over time, + // so refs ensure the dynamic body function always reads current values. + const searchScopesRef = useRef(selectedSearchScopes); + const modelRef = useRef(selectedLanguageModel); + const disabledMcpRef = useRef(disabledMcpServerIds); + + useEffect(() => { searchScopesRef.current = selectedSearchScopes; }, [selectedSearchScopes]); + useEffect(() => { modelRef.current = selectedLanguageModel; }, [selectedLanguageModel]); + useEffect(() => { disabledMcpRef.current = disabledMcpServerIds; }, [disabledMcpServerIds]); + + const getTransportBody = useCallback(() => ({ + selectedSearchScopes: searchScopesRef.current, + languageModel: modelRef.current, + disabledMcpServerIds: disabledMcpRef.current, + }), []); + + // Transport with dynamic body, resolved on every request, including auto-resends + // triggered by sendAutomaticallyWhen after tool approval. + // eslint-disable-next-line react-hooks/refs -- DefaultChatTransport stores the body callback and invokes it during requests, not during render. + const transport = useMemo(() => new DefaultChatTransport({ + api: '/api/chat', + headers: { + 'X-Sourcebot-Client-Source': 'sourcebot-web-client', + }, + body: getTransportBody, + }), [getTransportBody]); + const { messages, sendMessage: _sendMessage, + addToolApprovalResponse, error, status, stop, @@ -94,17 +157,28 @@ export const ChatThread = ({ } = useChat({ id: defaultChatId, messages: initialMessages, - transport: new DefaultChatTransport({ - api: '/api/chat', - headers: { - 'X-Sourcebot-Client-Source': 'sourcebot-web-client', - }, - }), + transport, + sendAutomaticallyWhen: lastAssistantMessageIsCompleteWithApprovalResponses, onData: (dataPart) => { // Keeps sources added by the assistant in sync. if (dataPart.type === 'data-source') { setSources((prev) => [...prev, dataPart.data]); } + if (dataPart.type === 'data-mcp-server') { + setMcpServerIconMap((prev) => ({ + ...prev, + [dataPart.data.sanitizedName]: dataPart.data.faviconUrl, + })); + } + if (dataPart.type === 'data-mcp-failed-server') { + setFailedMcpServers((prev) => { + if (prev.includes(dataPart.data.serverName)) { + return prev; + } + return [...prev, dataPart.data.serverName]; + }); + setIsFailedMcpBannerVisible(true); + } } }); @@ -127,6 +201,7 @@ export const ChatThread = ({ body: { selectedSearchScopes, languageModel: selectedLanguageModel, + disabledMcpServerIds, } satisfies AdditionalChatRequestParams, }); @@ -156,6 +231,7 @@ export const ChatThread = ({ selectedLanguageModel, _sendMessage, selectedSearchScopes, + disabledMcpServerIds, messages.length, toast, chatId, @@ -164,6 +240,12 @@ export const ChatThread = ({ const messagePairs = useMessagePairs(messages); + const { + isTurnInProgress, + isNetworkActive, + isAwaitingToolApproval, + shouldGuardNavigation, + } = useMemo(() => getTurnProgressState({ messages, status }), [messages, status]); useNavigationGuard({ enabled: ({ type }) => { @@ -175,7 +257,7 @@ export const ChatThread = ({ return false; } - return status === "streaming" || status === "submitted"; + return shouldGuardNavigation; }, confirm: () => window.confirm("You have unsaved changes that will be lost."), }); @@ -270,13 +352,13 @@ export const ChatThread = ({ const text = slateContentToString(children); const mentions = getAllMentionElements(children); - const message = createUIMessage(text, mentions.map(({ data }) => data), selectedSearchScopes); + const message = createUIMessage(text, mentions.map(({ data }) => data), selectedSearchScopes, disabledMcpServerIds); sendMessage(message); scrollToBottom(); resetEditor(editor); - }, [sendMessage, selectedSearchScopes, scrollToBottom]); + }, [sendMessage, selectedSearchScopes, disabledMcpServerIds, scrollToBottom]); const onDuplicate = useCallback(async (newName: string): Promise => { if (!defaultChatId) { @@ -298,7 +380,8 @@ export const ChatThread = ({ }, [defaultChatId, toast, router, captureEvent]); return ( - <> + + {error && ( setIsErrorBannerVisible(false)} /> )} + setIsFailedMcpBannerVisible(false)} + />
{messagePairs.map(([userMessage, assistantMessage], index) => { const isLastPair = index === messagePairs.length - 1; - const isStreaming = isLastPair && (status === "streaming" || status === "submitted"); + const isPairTurnInProgress = isLastPair && isTurnInProgress; + const isPairNetworkActive = isLastPair && isNetworkActive; + const isPairAwaitingToolApproval = isLastPair && isAwaitingToolApproval; // Use a stable key based on user message ID const key = userMessage.id; @@ -333,7 +423,9 @@ export const ChatThread = ({ chatId={chatId} userMessage={userMessage} assistantMessage={assistantMessage} - isStreaming={isStreaming} + isTurnInProgress={isPairTurnInProgress} + isNetworkActive={isPairNetworkActive} + isAwaitingToolApproval={isPairAwaitingToolApproval} sources={sources} /> {index !== messagePairs.length - 1 && ( @@ -348,7 +440,7 @@ export const ChatThread = ({
{ - (!isAtBottom && status === "streaming") && ( + (!isAtBottom && isNetworkActive) && (
@@ -426,6 +521,7 @@ export const ChatThread = ({
)}
- + + ); } diff --git a/packages/web/src/features/chat/components/chatThread/chatThreadListItem.tsx b/packages/web/src/features/chat/components/chatThread/chatThreadListItem.tsx index 0cbd4b264..d05508081 100644 --- a/packages/web/src/features/chat/components/chatThread/chatThreadListItem.tsx +++ b/packages/web/src/features/chat/components/chatThread/chatThreadListItem.tsx @@ -8,9 +8,10 @@ import { CSSProperties, forwardRef, memo, useCallback, useEffect, useMemo, useRe import scrollIntoView from 'scroll-into-view-if-needed'; import { Reference, referenceSchema, SBChatMessage, Source } from "../../types"; import { useExtractReferences } from '../../useExtractReferences'; -import { getAnswerPartFromAssistantMessage, groupMessageIntoSteps, repairReferences, tryResolveFileReference } from '../../utils'; +import { getAnswerPartFromAssistantMessage, getLastStepParts, groupMessageIntoSteps, isSBChatToolPart, repairReferences, tryResolveFileReference } from '../../utils'; import { AnswerCard } from './answerCard'; import { DetailsCard } from './detailsCard'; +import { ApprovalRequestedToolPart, ToolApprovalBanner } from './toolApprovalBanner'; import { MarkdownRenderer, REFERENCE_PAYLOAD_ATTRIBUTE } from './markdownRenderer'; import { ReferencedSourcesListView } from './referencedSourcesListView'; import isEqual from "fast-deep-equal/react"; @@ -19,7 +20,9 @@ import { ANSWER_TAG } from '../../constants'; interface ChatThreadListItemProps { userMessage: SBChatMessage; assistantMessage?: SBChatMessage; - isStreaming: boolean; + isTurnInProgress: boolean; + isNetworkActive: boolean; + isAwaitingToolApproval: boolean; sources: Source[]; chatId: string; index: number; @@ -28,7 +31,9 @@ interface ChatThreadListItemProps { const ChatThreadListItemComponent = forwardRef(({ userMessage, assistantMessage: _assistantMessage, - isStreaming, + isTurnInProgress, + isNetworkActive, + isAwaitingToolApproval, sources, chatId, index, @@ -39,7 +44,7 @@ const ChatThreadListItemComponent = forwardRef(undefined); const [selectedReference, setSelectedReference] = useState(undefined); - const [isDetailsPanelExpanded, _setIsDetailsPanelExpanded] = useState(isStreaming); + const [isDetailsPanelExpanded, _setIsDetailsPanelExpanded] = useState(isNetworkActive); const hasAutoCollapsed = useRef(false); const userHasManuallyExpanded = useRef(false); @@ -78,8 +83,8 @@ const ChatThreadListItemComponent = forwardRef { - return isStreaming && !answerPart - }, [answerPart, isStreaming]); + return isNetworkActive && !answerPart + }, [answerPart, isNetworkActive]); + + // Extract MCP tool parts that are waiting for user approval. + const approvalRequestedParts = useMemo((): ApprovalRequestedToolPart[] => { + if (!assistantMessage) { + return []; + } + return getLastStepParts(assistantMessage.parts) + .filter(isSBChatToolPart) + .filter((part): part is ApprovalRequestedToolPart => part.state === 'approval-requested'); + }, [assistantMessage]); // Auto-collapse when answer first appears, but only once and respect user preference @@ -331,7 +347,7 @@ const ChatThreadListItemComponent = forwardRef
- {isStreaming ? ( + {isTurnInProgress ? ( ) : ( @@ -359,11 +375,17 @@ const ChatThreadListItemComponent = forwardRef + {approvalRequestedParts.length > 0 && ( + + )} + {(answerPart && assistantMessage) ? ( - ) : !isStreaming && ( + ) : !isTurnInProgress && approvalRequestedParts.length === 0 && (

Error: No answer response was provided

)}
@@ -404,7 +426,7 @@ const ChatThreadListItemComponent = forwardRef - ) : isStreaming ? ( + ) : (isTurnInProgress) ? (
{Array.from({ length: 3 }).map((_, index) => ( @@ -432,15 +454,19 @@ const arePropsEqual = ( prevProps: ChatThreadListItemProps, nextProps: ChatThreadListItemProps ): boolean => { - // Always re-render if streaming status changes - if (prevProps.isStreaming !== nextProps.isStreaming) { + // Always re-render if turn/network/approval status changes + if ( + prevProps.isTurnInProgress !== nextProps.isTurnInProgress || + prevProps.isNetworkActive !== nextProps.isNetworkActive || + prevProps.isAwaitingToolApproval !== nextProps.isAwaitingToolApproval + ) { return false; } - // If currently streaming, always allow re-render + // If currently network-active, always allow re-render // This bypasses the fast-deep-equal reference check issue when useChat // mutates message objects in place during token streaming - if (nextProps.isStreaming) { + if (nextProps.isNetworkActive) { return false; } @@ -466,4 +492,4 @@ const getNearestReferenceElement = (referenceElements: Element[]) => { return currentDistance < nearestDistance ? current : nearest; }); -} \ No newline at end of file +} diff --git a/packages/web/src/features/chat/components/chatThread/detailsCard.test.tsx b/packages/web/src/features/chat/components/chatThread/detailsCard.test.tsx new file mode 100644 index 000000000..6f9c924cc --- /dev/null +++ b/packages/web/src/features/chat/components/chatThread/detailsCard.test.tsx @@ -0,0 +1,122 @@ +import { cleanup, render, screen } from '@testing-library/react'; +import { afterEach, describe, expect, test, vi } from 'vitest'; +import { TooltipProvider } from '@/components/ui/tooltip'; +import { DetailsCard } from './detailsCard'; +import type { SBChatMessagePart } from '../../types'; + +vi.mock('@/hooks/useCaptureEvent', () => ({ + default: () => vi.fn(), +})); + +afterEach(() => { + cleanup(); +}); + +describe('DetailsCard', () => { + test('shows an approval waiting state without final metadata while awaiting permission', () => { + const { container } = render( + + + + ); + + expect(screen.queryByText('Awaiting permission...')).toBeTruthy(); + expect(screen.queryByText('Thinking...')).toBeNull(); + expect(container.querySelector('.lucide-shield-question-mark')).toBeTruthy(); + expect(container.querySelector('.lucide-loader-circle')).toBeNull(); + expect(container.querySelector('.animate-spin')).toBeNull(); + expect(screen.queryByText('Claude Sonnet')).toBeNull(); + expect(screen.queryByText('41k tokens')).toBeNull(); + }); + + test('shows a spinner while thinking instead of the approval waiting icon', () => { + const { container } = render( + + + + ); + + expect(screen.queryByText('Thinking...')).toBeTruthy(); + expect(screen.queryByText('Awaiting permission...')).toBeNull(); + expect(container.querySelector('.lucide-loader-circle')).toBeTruthy(); + expect(container.querySelector('.animate-spin')).toBeTruthy(); + expect(container.querySelector('.lucide-shield-question-mark')).toBeNull(); + }); + + test('shows final details metadata only after the turn is complete', () => { + render( + + + + ); + + expect(screen.queryByText('Details')).toBeTruthy(); + expect(screen.queryByText('Claude Sonnet')).toBeTruthy(); + expect(screen.queryByText('41k tokens')).toBeTruthy(); + }); + + test('shows terminal tool activation failures instead of a loading state', () => { + const failedActivationPart = { + type: 'tool-tool_request_activation', + toolCallId: 'tool-call-1', + state: 'output-error', + input: { tool_to_activate_name: 'mcp_linear__search_issues' }, + errorText: 'Activation failed', + } satisfies SBChatMessagePart; + + render( + + + + ); + + expect(screen.queryByText('Tool activation failed: Activation failed')).toBeTruthy(); + expect(screen.queryByText('Activating tool...')).toBeNull(); + }); +}); diff --git a/packages/web/src/features/chat/components/chatThread/detailsCard.tsx b/packages/web/src/features/chat/components/chatThread/detailsCard.tsx index 0e2365ea6..cd6d8228d 100644 --- a/packages/web/src/features/chat/components/chatThread/detailsCard.tsx +++ b/packages/web/src/features/chat/components/chatThread/detailsCard.tsx @@ -9,7 +9,7 @@ import useCaptureEvent from '@/hooks/useCaptureEvent'; import { cn, getShortenedNumberDisplayString } from '@/lib/utils'; import isEqual from "fast-deep-equal/react"; import { useStickToBottom } from 'use-stick-to-bottom'; -import { Brain, ChevronDown, ChevronRight, Clock, InfoIcon, Loader2, ScanSearchIcon, Wrench, Zap } from 'lucide-react'; +import { Brain, ChevronDown, ChevronRight, Clock, InfoIcon, Loader2, ScanSearchIcon, ShieldQuestion, Wrench, Zap } from 'lucide-react'; import { memo, useCallback, useEffect, useMemo, useState } from 'react'; import { usePrevious } from '@uidotdev/usehooks'; import { SBChatMessageMetadata, SBChatMessagePart } from '../../types'; @@ -25,6 +25,8 @@ import { ListReposToolComponent } from './tools/listReposToolComponent'; import { ListTreeToolComponent } from './tools/listTreeToolComponent'; import { ReadFileToolComponent } from './tools/readFileToolComponent'; import { ToolOutputGuard } from './tools/toolOutputGuard'; +import { McpToolComponent } from './tools/mcpToolComponent'; +import { ToolSearchToolComponent } from './tools/toolSearchToolComponent'; interface DetailsCardProps { @@ -32,7 +34,9 @@ interface DetailsCardProps { isExpanded: boolean; onExpandedChanged: (isExpanded: boolean) => void; isThinking: boolean; - isStreaming: boolean; + isTurnInProgress: boolean; + isNetworkActive: boolean; + isAwaitingToolApproval: boolean; thinkingSteps: SBChatMessagePart[][]; metadata?: SBChatMessageMetadata; } @@ -42,13 +46,18 @@ const DetailsCardComponent = ({ isExpanded, onExpandedChanged, isThinking, - isStreaming, + isTurnInProgress, + isNetworkActive, + isAwaitingToolApproval, metadata, thinkingSteps, }: DetailsCardProps) => { const captureEvent = useCaptureEvent(); - const toolCallCount = useMemo(() => thinkingSteps.flat().filter(part => part.type.startsWith('tool-')).length, [thinkingSteps]); + const toolCallCount = useMemo(() => thinkingSteps.flat().filter(part => + part.type.startsWith('tool-') || + (part.type === 'dynamic-tool' && part.toolName.startsWith('mcp_')) + ).length, [thinkingSteps]); const handleExpandedChanged = useCallback((next: boolean) => { captureEvent('wa_chat_details_card_toggled', { chatId, isExpanded: next }); @@ -74,6 +83,11 @@ const DetailsCardComponent = ({ Thinking... + ) : isAwaitingToolApproval ? ( + <> + + Awaiting permission... + ) : ( <> @@ -81,7 +95,7 @@ const DetailsCardComponent = ({ )}

- {!isStreaming && ( + {!isTurnInProgress && ( <> {(metadata?.selectedSearchScopes && metadata.selectedSearchScopes.length > 0) && ( @@ -166,7 +180,7 @@ const DetailsCardComponent = ({ @@ -179,7 +193,7 @@ const DetailsCardComponent = ({ export const DetailsCard = memo(DetailsCardComponent, isEqual); -const ThinkingSteps = ({ thinkingSteps, isStreaming, isThinking }: { thinkingSteps: SBChatMessagePart[][], isStreaming: boolean, isThinking: boolean }) => { +const ThinkingSteps = ({ thinkingSteps, isNetworkActive, isThinking }: { thinkingSteps: SBChatMessagePart[][], isNetworkActive: boolean, isThinking: boolean }) => { const { scrollRef, contentRef, scrollToBottom } = useStickToBottom(); const [shouldStick, setShouldStick] = useState(isThinking); const prevIsThinking = usePrevious(isThinking); @@ -197,7 +211,7 @@ const ThinkingSteps = ({ thinkingSteps, isStreaming, isThinking }: { thinkingSte
{thinkingSteps.length === 0 ? ( - isStreaming ? ( + isNetworkActive ? ( ) : (

No thinking steps

@@ -308,8 +322,22 @@ export const StepPartRenderer = ({ part }: { part: SBChatMessagePart }) => { {(output) => } ) - case 'data-source': + case 'tool-tool_request_activation': + if (part.state === 'output-error') { + return Tool activation failed: {part.errorText}; + } + if (part.state !== 'output-available') { + return Activating tool...; + } + return ; case 'dynamic-tool': + if (part.toolName.startsWith('mcp_')) { + return ; + } + return null; + case 'data-source': + case 'data-mcp-server': + case 'data-mcp-failed-server': case 'file': case 'source-document': case 'source-url': @@ -320,4 +348,4 @@ export const StepPartRenderer = ({ part }: { part: SBChatMessagePart }) => { part satisfies never; return null; } -} \ No newline at end of file +} diff --git a/packages/web/src/features/chat/components/chatThread/mcpFailedServersBanner.tsx b/packages/web/src/features/chat/components/chatThread/mcpFailedServersBanner.tsx new file mode 100644 index 000000000..0c37b2717 --- /dev/null +++ b/packages/web/src/features/chat/components/chatThread/mcpFailedServersBanner.tsx @@ -0,0 +1,44 @@ +'use client'; + +import { Button } from '@/components/ui/button'; +import { AlertTriangle, X } from 'lucide-react'; + +interface McpFailedServersBannerProps { + serverNames: string[]; + isVisible: boolean; + onClose: () => void; +} + +export const McpFailedServersBanner = ({ serverNames, isVisible, onClose }: McpFailedServersBannerProps) => { + if (!isVisible || serverNames.length === 0) { + return null; + } + + const message = serverNames.length === 1 + ? `Connector "${serverNames[0]}" failed to load tools` + : `${serverNames.length} connectors failed to load tools`; + + return ( +
+
+
+
+ + + {message} + +
+ +
+
+
+ ); +}; diff --git a/packages/web/src/features/chat/components/chatThread/signInPromptBanner.tsx b/packages/web/src/features/chat/components/chatThread/signInPromptBanner.tsx index 1d1ba36fb..ec4d690e1 100644 --- a/packages/web/src/features/chat/components/chatThread/signInPromptBanner.tsx +++ b/packages/web/src/features/chat/components/chatThread/signInPromptBanner.tsx @@ -14,7 +14,7 @@ interface SignInPromptBannerProps { isAuthenticated: boolean; isOwner: boolean; hasMessages: boolean; - isStreaming: boolean; + isTurnInProgress: boolean; } export const SignInPromptBanner = ({ @@ -22,7 +22,7 @@ export const SignInPromptBanner = ({ isAuthenticated, isOwner, hasMessages, - isStreaming, + isTurnInProgress, }: SignInPromptBannerProps) => { const pathname = usePathname(); const [isDismissed, setIsDismissed] = useState(true); // Start as true to avoid flash @@ -39,7 +39,7 @@ export const SignInPromptBanner = ({ !isAuthenticated && isOwner && hasMessages && - !isStreaming; + !isTurnInProgress; // Show the banner after first response completes and track display useEffect(() => { diff --git a/packages/web/src/features/chat/components/chatThread/toolApprovalBanner.tsx b/packages/web/src/features/chat/components/chatThread/toolApprovalBanner.tsx new file mode 100644 index 000000000..636c951f9 --- /dev/null +++ b/packages/web/src/features/chat/components/chatThread/toolApprovalBanner.tsx @@ -0,0 +1,114 @@ +'use client'; + +import { Button } from "@/components/ui/button"; +import { McpFavicon } from "@/ee/features/chat/mcp/components/mcpFavicon"; +import { useMcpServerIconMap } from "@/features/chat/mcpServerIconContext"; +import { useToolApproval } from "@/features/chat/toolApprovalContext"; +import { SBChatToolPart } from "@/features/chat/utils"; +import { cn } from "@/lib/utils"; +import { getToolName } from "ai"; +import { ChevronRight } from "lucide-react"; +import { useCallback, useState } from "react"; +import { parseMcpToolName } from "./tools/mcpToolComponent"; +import { JsonHighlighter } from "./tools/jsonHighlighter"; + +export type ApprovalRequestedToolPart = SBChatToolPart & { + state: 'approval-requested'; +}; + +interface ToolApprovalBannerProps { + parts: ApprovalRequestedToolPart[]; +} + +export const ToolApprovalBanner = ({ parts }: ToolApprovalBannerProps) => { + const addToolApprovalResponse = useToolApproval(); + const iconMap = useMcpServerIconMap(); + + if (parts.length === 0) { + return null; + } + + return ( +
+ {parts.map((part) => ( + + ))} +
+ ); +}; + +const ToolApprovalItem = ({ + part, + addToolApprovalResponse, + iconMap, +}: { + part: ApprovalRequestedToolPart; + addToolApprovalResponse: ReturnType; + iconMap: Record; +}) => { + const [isExpanded, setIsExpanded] = useState(false); + const partToolName = getToolName(part); + const parsed = parseMcpToolName(partToolName); + const serverName = parsed?.serverName ?? partToolName; + const toolName = parsed?.toolName ?? partToolName; + const faviconUrl = parsed ? iconMap[parsed.serverName] : undefined; + + const requestText = JSON.stringify(part.input, null, 2); + + const onToggle = useCallback(() => setIsExpanded(v => !v), []); + + const onApprove = useCallback(() => { + if (part.state === 'approval-requested' && addToolApprovalResponse) { + addToolApprovalResponse({ id: part.approval.id, approved: true }); + } + }, [part, addToolApprovalResponse]); + + const onDeny = useCallback(() => { + if (part.state === 'approval-requested' && addToolApprovalResponse) { + addToolApprovalResponse({ id: part.approval.id, approved: false, reason: 'User denied' }); + } + }, [part, addToolApprovalResponse]); + + return ( +
+
+ +
+ + +
+
+ {isExpanded && ( +
+ +
+ )} +
+ ); +}; diff --git a/packages/web/src/features/chat/components/chatThread/tools/jsonHighlighter.tsx b/packages/web/src/features/chat/components/chatThread/tools/jsonHighlighter.tsx new file mode 100644 index 000000000..18203a9de --- /dev/null +++ b/packages/web/src/features/chat/components/chatThread/tools/jsonHighlighter.tsx @@ -0,0 +1,151 @@ +'use client'; + +export function unescapeJsonStrings(value: unknown): unknown { + if (typeof value === 'string') { + try { + const parsed: unknown = JSON.parse(value); + if (typeof parsed === 'object' && parsed !== null) { + return unescapeJsonStrings(parsed); + } + } catch { + // not JSON — leave as-is + } + return value; + } + if (Array.isArray(value)) { + return value.map(unescapeJsonStrings); + } + if (typeof value === 'object' && value !== null) { + return Object.fromEntries( + Object.entries(value).map(([k, v]) => [k, unescapeJsonStrings(v)]) + ); + } + return value; +} + +type TokenType = 'key' | 'string' | 'number' | 'boolean' | 'null' | 'structural' | 'whitespace' | 'other'; + +interface Token { + type: TokenType; + value: string; +} + +function tokenizeJson(text: string): Token[] { + const tokens: Token[] = []; + let i = 0; + + while (i < text.length) { + const ch = text[i]; + + // Whitespace + if (/\s/.test(ch)) { + let j = i + 1; + while (j < text.length && /\s/.test(text[j])) { + j++; + } + tokens.push({ type: 'whitespace', value: text.slice(i, j) }); + i = j; + continue; + } + + // String + if (ch === '"') { + let j = i + 1; + while (j < text.length) { + if (text[j] === '\\') { + j += 2; + } else if (text[j] === '"') { + j++; + break; + } else { + j++; + } + } + const str = text.slice(i, j); + + // Lookahead past whitespace for a colon → this is a key + let k = j; + while (k < text.length && /\s/.test(text[k])) { + k++; + } + const isKey = text[k] === ':'; + + tokens.push({ type: isKey ? 'key' : 'string', value: str }); + i = j; + continue; + } + + // Number + if (ch === '-' || /\d/.test(ch)) { + const match = text.slice(i).match(/^-?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?/); + if (match) { + tokens.push({ type: 'number', value: match[0] }); + i += match[0].length; + continue; + } + } + + // Boolean / null keywords + if (text.slice(i, i + 4) === 'true') { + tokens.push({ type: 'boolean', value: 'true' }); + i += 4; + continue; + } + if (text.slice(i, i + 5) === 'false') { + tokens.push({ type: 'boolean', value: 'false' }); + i += 5; + continue; + } + if (text.slice(i, i + 4) === 'null') { + tokens.push({ type: 'null', value: 'null' }); + i += 4; + continue; + } + + // Structural characters + if ('{}[]:,'.includes(ch)) { + tokens.push({ type: 'structural', value: ch }); + i++; + continue; + } + + // Fallback + tokens.push({ type: 'other', value: ch }); + i++; + } + + return tokens; +} + +const TOKEN_CLASSES: Record = { + key: 'text-editor-tag-name', + string: 'text-editor-tag-string', + number: 'text-editor-tag-number', + boolean: 'text-editor-tag-atom', + null: 'text-editor-tag-atom', + structural: 'text-muted-foreground', + whitespace: '', + other: '', +}; + +import { useMemo } from "react"; + +export const JsonHighlighter = ({ text }: { text: string }) => { + const tokens = useMemo(() => tokenizeJson(text), [text]); + + return ( +
+            {tokens.map((token, i) => {
+                const cls = TOKEN_CLASSES[token.type];
+                if (!cls) {
+                    return token.value;
+                }
+                return (
+                    
+                        {token.value}
+                    
+                );
+            })}
+        
+ ); +}; diff --git a/packages/web/src/features/chat/components/chatThread/tools/mcpToolComponent.tsx b/packages/web/src/features/chat/components/chatThread/tools/mcpToolComponent.tsx new file mode 100644 index 000000000..aeca09156 --- /dev/null +++ b/packages/web/src/features/chat/components/chatThread/tools/mcpToolComponent.tsx @@ -0,0 +1,173 @@ +'use client'; + +import { CopyIconButton } from "@/app/(app)/components/copyIconButton"; +import { McpFavicon } from "@/ee/features/chat/mcp/components/mcpFavicon"; +import { useMcpServerIconMap } from "@/features/chat/mcpServerIconContext"; +import { cn } from "@/lib/utils"; +import { DynamicToolUIPart } from "ai"; +import { CheckCircle, ChevronDown, XCircle } from "lucide-react"; +import { useCallback, useMemo, useState } from "react"; +import { JsonHighlighter, unescapeJsonStrings } from "./jsonHighlighter"; + +export function parseMcpToolName(toolName: string): { serverName: string; toolName: string } | null { + if (!toolName.startsWith('mcp_')) { + return null; + } + const withoutPrefix = toolName.slice(4); + const doubleUnderscoreIdx = withoutPrefix.indexOf('__'); + if (doubleUnderscoreIdx === -1) { + return null; + } + return { + serverName: withoutPrefix.slice(0, doubleUnderscoreIdx), + toolName: withoutPrefix.slice(doubleUnderscoreIdx + 2), + }; +} + +export const McpToolComponent = ({ part }: { part: DynamicToolUIPart }) => { + const needsApproval = part.state === 'approval-requested'; + const [isExpanded, setIsExpanded] = useState(needsApproval); + const onToggle = useCallback(() => setIsExpanded(v => !v), []); + + const iconMap = useMcpServerIconMap(); + const parsed = parseMcpToolName(part.toolName); + const displayName = parsed + ? `${parsed.serverName}: ${parsed.toolName}` + : part.toolName; + const faviconUrl = parsed ? iconMap[parsed.serverName] : undefined; + + const hasInput = part.state !== 'input-streaming'; + + const requestText = useMemo( + () => hasInput ? JSON.stringify(part.input, null, 2) : '', + [hasInput, part.input] + ); + const responseText = useMemo(() => { + if (part.state === 'output-available') { + try { + return JSON.stringify(unescapeJsonStrings(part.output), null, 2); + } catch { + return String(part.output); + } + } + if (part.state === 'output-error') { + return part.errorText ?? ''; + } + return undefined; + }, [part.state, part.output, part.errorText]); + + const onCopyRequest = useCallback(() => { + navigator.clipboard.writeText(requestText); + return true; + }, [requestText]); + + const onCopyResponse = useCallback(() => { + if (!responseText) { + return false; + } + navigator.clipboard.writeText(responseText); + return true; + }, [responseText]); + + const renderStatus = () => { + if (part.state === 'output-error') { + return ( + + + {displayName} failed: {part.errorText} + + ); + } + if (part.state === 'output-denied') { + return ( + + + + {displayName} — denied + + ); + } + if (part.state === 'approval-requested') { + return ( + + + {displayName} + + ); + } + if (part.state === 'approval-responded') { + const approved = part.approval.approved; + return ( + + + {approved ? : } + {displayName}{approved ? '...' : ' — denied'} + + ); + } + if (part.state === 'output-available') { + return ( + + + {displayName} + + ); + } + // input-streaming, input-available, or other in-progress states + return ( + + + {displayName}... + + ); + }; + + return ( +
+
+
+ {renderStatus()} +
+ {hasInput && ( + + )} +
+ {hasInput && isExpanded && ( +
+ + + + {responseText !== undefined && ( + <> +
+ +
+ +
+
+ + )} +
+ )} +
+ ); +}; + + +const ResultSection = ({ label, onCopy, children }: { label: string; onCopy: () => boolean; children: React.ReactNode }) => ( +
+
+ {label} + +
+
+ {children} +
+
+); diff --git a/packages/web/src/features/chat/components/chatThread/tools/toolOutputGuard.tsx b/packages/web/src/features/chat/components/chatThread/tools/toolOutputGuard.tsx index aac756f4a..43ce2021d 100644 --- a/packages/web/src/features/chat/components/chatThread/tools/toolOutputGuard.tsx +++ b/packages/web/src/features/chat/components/chatThread/tools/toolOutputGuard.tsx @@ -6,6 +6,7 @@ import { ToolUIPart } from "ai"; import { ChevronDown } from "lucide-react"; import { cn } from "@/lib/utils"; import { useCallback, useState } from "react"; +import { JsonHighlighter, unescapeJsonStrings } from "./jsonHighlighter"; export const ToolOutputGuard = >({ part, @@ -27,7 +28,7 @@ export const ToolOutputGuard = { const raw = (part.output as { output: string }).output; try { - return JSON.stringify(JSON.parse(raw), null, 2); + return JSON.stringify(unescapeJsonStrings(JSON.parse(raw)), null, 2); } catch { return raw; } @@ -70,17 +71,15 @@ export const ToolOutputGuard = -
-                            {requestText}
-                        
+
{responseText !== undefined && ( <>
-
-                                    {responseText}
-                                
+
+ +
)} diff --git a/packages/web/src/features/chat/components/chatThread/tools/toolSearchToolComponent.tsx b/packages/web/src/features/chat/components/chatThread/tools/toolSearchToolComponent.tsx new file mode 100644 index 000000000..545ed9b7f --- /dev/null +++ b/packages/web/src/features/chat/components/chatThread/tools/toolSearchToolComponent.tsx @@ -0,0 +1,52 @@ +'use client'; + +import { Separator } from "@/components/ui/separator"; +import { Collapsible, CollapsibleContent, CollapsibleTrigger } from "@/components/ui/collapsible"; +import { ChevronRight } from "lucide-react"; +import { useState } from "react"; +import { cn } from "@/lib/utils"; + +interface ToolSearchResult { + name: string; + description: string; +} + +interface ToolSearchToolComponentProps { + query: string; + results: ToolSearchResult[]; +} + +export const ToolSearchToolComponent = ({ query, results }: ToolSearchToolComponentProps) => { + const [isOpen, setIsOpen] = useState(false); + + return ( + + +
+ + Searched connector tools: {query} + + {results.length} result{results.length === 1 ? '' : 's'} +
+
+ +
+ {results.map((result) => ( +
+ {result.name} + {result.description && ( + <> + - + {result.description} + + )} +
+ ))} + {results.length === 0 && ( + No tools found + )} +
+
+
+ ); +}; diff --git a/packages/web/src/features/chat/constants.ts b/packages/web/src/features/chat/constants.ts index 1038852c6..db518d2aa 100644 --- a/packages/web/src/features/chat/constants.ts +++ b/packages/web/src/features/chat/constants.ts @@ -10,3 +10,5 @@ export const ANSWER_TAG = ''; export const SELECTED_SEARCH_SCOPES_LOCAL_STORAGE_KEY = 'selectedSearchScopes'; export const SET_CHAT_STATE_SESSION_STORAGE_KEY = 'setChatState'; export const PENDING_CHAT_SUBMISSION_SESSION_STORAGE_KEY = 'pendingChatSubmission'; +export const DISABLED_MCP_SERVER_IDS_LOCAL_STORAGE_KEY = 'disabledMcpServerIds'; +export const MCP_OAUTH_DRAFT_SESSION_STORAGE_KEY = 'mcpOAuthDraft'; diff --git a/packages/web/src/features/chat/mcpOAuthDraft.test.ts b/packages/web/src/features/chat/mcpOAuthDraft.test.ts new file mode 100644 index 000000000..6f81f644e --- /dev/null +++ b/packages/web/src/features/chat/mcpOAuthDraft.test.ts @@ -0,0 +1,84 @@ +import { beforeEach, describe, expect, test } from 'vitest'; +import { MCP_OAUTH_DRAFT_SESSION_STORAGE_KEY } from './constants'; +import { + consumeMcpOAuthDraftForPath, + normalizeMcpOAuthDraftPath, + resolveMcpOAuthDraftForPath, + saveMcpOAuthDraft, +} from './mcpOAuthDraft'; +import type { Descendant } from 'slate'; +import type { SearchScope } from './types'; + +const children = [{ + type: 'paragraph', + children: [{ text: 'check the Linear ticket' }], +}] satisfies Descendant[]; + +const selectedSearchScopes = [{ + type: 'repo', + value: 'sourcebot/sourcebot', + name: 'sourcebot/sourcebot', + codeHostType: 'github', +}] satisfies SearchScope[]; + +const draft = { + returnTo: '/chat/thread-1?scope=sourcebot', + children, + selectedSearchScopes, + disabledMcpServerIds: ['server-disabled'], + createdAt: 100, +}; + +describe('MCP OAuth draft persistence', () => { + beforeEach(() => { + sessionStorage.clear(); + }); + + test('normalizes chat paths and strips OAuth status params', () => { + expect(normalizeMcpOAuthDraftPath('/chat/thread-1?scope=sourcebot&status=connected&server=Linear')).toBe('/chat/thread-1?scope=sourcebot'); + expect(normalizeMcpOAuthDraftPath('/settings/accountAskAgent')).toBeUndefined(); + expect(normalizeMcpOAuthDraftPath('https://evil.example.com/chat')).toBeUndefined(); + expect(normalizeMcpOAuthDraftPath('//evil.example.com/chat')).toBeUndefined(); + }); + + test('resolves a draft for the same chat path after the OAuth callback adds status params', () => { + const result = resolveMcpOAuthDraftForPath( + JSON.stringify(draft), + '/chat/thread-1?scope=sourcebot&status=connected&server=Linear', + 200, + ); + + expect(result.shouldClear).toBe(true); + expect(result.draft).toEqual(draft); + }); + + test('keeps a draft when the current chat path does not match', () => { + const result = resolveMcpOAuthDraftForPath(JSON.stringify(draft), '/chat/thread-2', 200); + + expect(result.shouldClear).toBe(false); + expect(result.draft).toBeUndefined(); + }); + + test('clears invalid and stale drafts', () => { + expect(resolveMcpOAuthDraftForPath('{', '/chat/thread-1').shouldClear).toBe(true); + expect(resolveMcpOAuthDraftForPath(JSON.stringify({ ...draft, children: [1] }), '/chat/thread-1?scope=sourcebot', 200).shouldClear).toBe(true); + expect(resolveMcpOAuthDraftForPath(JSON.stringify(draft), '/chat/thread-1?scope=sourcebot', 30 * 60 * 1000 + 101).shouldClear).toBe(true); + }); + + test('saves and consumes the composer draft from sessionStorage', () => { + saveMcpOAuthDraft({ + returnTo: '/chat/thread-1?scope=sourcebot&status=error', + children, + selectedSearchScopes, + disabledMcpServerIds: ['server-disabled'], + }); + + const restoredDraft = consumeMcpOAuthDraftForPath('/chat/thread-1?scope=sourcebot&status=connected&server=Linear'); + + expect(restoredDraft?.returnTo).toBe('/chat/thread-1?scope=sourcebot'); + expect(restoredDraft?.children).toEqual(children); + expect(restoredDraft?.selectedSearchScopes).toEqual(selectedSearchScopes); + expect(restoredDraft?.disabledMcpServerIds).toEqual(['server-disabled']); + expect(sessionStorage.getItem(MCP_OAUTH_DRAFT_SESSION_STORAGE_KEY)).toBeNull(); + }); +}); diff --git a/packages/web/src/features/chat/mcpOAuthDraft.ts b/packages/web/src/features/chat/mcpOAuthDraft.ts new file mode 100644 index 000000000..bbbf2a146 --- /dev/null +++ b/packages/web/src/features/chat/mcpOAuthDraft.ts @@ -0,0 +1,217 @@ +import type { Descendant } from "slate"; +import { MCP_OAUTH_DRAFT_SESSION_STORAGE_KEY } from "./constants"; +import type { CustomText, MentionElement, ParagraphElement, SearchScope } from "./types"; + +const MCP_OAUTH_DRAFT_BASE_URL = 'https://sourcebot.invalid'; +const MCP_OAUTH_DRAFT_MAX_AGE_MS = 30 * 60 * 1000; +const MCP_OAUTH_STATUS_PARAMS = ['status', 'server', 'message']; + +export interface McpOAuthDraft { + returnTo: string; + children: Descendant[]; + selectedSearchScopes: SearchScope[]; + disabledMcpServerIds: string[]; + createdAt: number; +} + +type McpOAuthDraftInput = Omit; + +interface ResolveMcpOAuthDraftResult { + draft?: McpOAuthDraft; + shouldClear: boolean; +} + +function isAllowedMcpOAuthDraftPath(pathname: string): boolean { + return pathname === '/chat' || pathname.startsWith('/chat/'); +} + +function isRecord(value: unknown): value is Record { + return typeof value === 'object' && value !== null; +} + +function isCustomText(value: unknown): value is CustomText { + return isRecord(value) && typeof value.text === 'string'; +} + +function isMentionElement(value: unknown): value is MentionElement { + return ( + isRecord(value) && + value.type === 'mention' && + isRecord(value.data) && + value.data.type === 'file' && + typeof value.data.repo === 'string' && + typeof value.data.path === 'string' && + typeof value.data.name === 'string' && + typeof value.data.language === 'string' && + typeof value.data.revision === 'string' && + Array.isArray(value.children) && + value.children.every(isCustomText) + ); +} + +function isParagraphElement(value: unknown): value is ParagraphElement { + return ( + isRecord(value) && + value.type === 'paragraph' && + (value.align === undefined || typeof value.align === 'string') && + Array.isArray(value.children) && + value.children.length > 0 && + value.children.every((child) => isCustomText(child) || isMentionElement(child)) + ); +} + +function isMcpOAuthDraftChildren(value: unknown): value is Descendant[] { + return Array.isArray(value) && value.length > 0 && value.every(isParagraphElement); +} + +export function normalizeMcpOAuthDraftPath(path: string): string | undefined { + const trimmedPath = path.trim(); + if (!trimmedPath || !trimmedPath.startsWith('/') || trimmedPath.startsWith('//') || trimmedPath.includes('\\')) { + return undefined; + } + + try { + const url = new URL(trimmedPath, MCP_OAUTH_DRAFT_BASE_URL); + if (url.origin !== MCP_OAUTH_DRAFT_BASE_URL || !isAllowedMcpOAuthDraftPath(url.pathname)) { + return undefined; + } + + for (const param of MCP_OAUTH_STATUS_PARAMS) { + url.searchParams.delete(param); + } + + const query = url.searchParams.toString(); + return `${url.pathname}${query ? `?${query}` : ''}`; + } catch { + return undefined; + } +} + +export function createMcpOAuthDraftPath(pathname: string, search: string): string | undefined { + return normalizeMcpOAuthDraftPath(`${pathname}${search}`); +} + +function isMcpOAuthDraft(value: unknown): value is McpOAuthDraft { + return ( + isRecord(value) && + 'returnTo' in value && + typeof value.returnTo === 'string' && + 'children' in value && + isMcpOAuthDraftChildren(value.children) && + 'selectedSearchScopes' in value && + Array.isArray(value.selectedSearchScopes) && + 'disabledMcpServerIds' in value && + Array.isArray(value.disabledMcpServerIds) && + value.disabledMcpServerIds.every((id) => typeof id === 'string') && + 'createdAt' in value && + typeof value.createdAt === 'number' + ); +} + +export function resolveMcpOAuthDraftForPath( + storedDraft: string | null, + currentPath: string, + now = Date.now(), +): ResolveMcpOAuthDraftResult { + if (!storedDraft) { + return { shouldClear: false }; + } + + let parsedDraft: unknown; + try { + parsedDraft = JSON.parse(storedDraft); + } catch { + return { shouldClear: true }; + } + + if (!isMcpOAuthDraft(parsedDraft)) { + return { shouldClear: true }; + } + + if (now - parsedDraft.createdAt > MCP_OAUTH_DRAFT_MAX_AGE_MS) { + return { shouldClear: true }; + } + + const storedPath = normalizeMcpOAuthDraftPath(parsedDraft.returnTo); + if (!storedPath) { + return { shouldClear: true }; + } + + const normalizedCurrentPath = normalizeMcpOAuthDraftPath(currentPath); + if (!normalizedCurrentPath) { + return { shouldClear: false }; + } + + if (storedPath !== normalizedCurrentPath) { + return { shouldClear: false }; + } + + return { + draft: { + ...parsedDraft, + returnTo: storedPath, + }, + shouldClear: true, + }; +} + +function getSessionStorage(): Storage | undefined { + if (typeof window === 'undefined') { + return undefined; + } + + try { + return window.sessionStorage; + } catch { + return undefined; + } +} + +export function saveMcpOAuthDraft(draft: McpOAuthDraftInput): void { + const storage = getSessionStorage(); + const returnTo = normalizeMcpOAuthDraftPath(draft.returnTo); + if (!storage || !returnTo) { + return; + } + + try { + storage.setItem(MCP_OAUTH_DRAFT_SESSION_STORAGE_KEY, JSON.stringify({ + ...draft, + returnTo, + createdAt: Date.now(), + } satisfies McpOAuthDraft)); + } catch { + // If sessionStorage is unavailable or full, OAuth should still proceed. + } +} + +export function clearMcpOAuthDraft(): void { + const storage = getSessionStorage(); + if (!storage) { + return; + } + + try { + storage.removeItem(MCP_OAUTH_DRAFT_SESSION_STORAGE_KEY); + } catch { + // Ignore storage cleanup failures. + } +} + +export function consumeMcpOAuthDraftForPath(currentPath: string): McpOAuthDraft | undefined { + const storage = getSessionStorage(); + if (!storage) { + return undefined; + } + + const result = resolveMcpOAuthDraftForPath( + storage.getItem(MCP_OAUTH_DRAFT_SESSION_STORAGE_KEY), + currentPath, + ); + + if (result.shouldClear) { + clearMcpOAuthDraft(); + } + + return result.draft; +} diff --git a/packages/web/src/features/chat/mcpServerIconContext.tsx b/packages/web/src/features/chat/mcpServerIconContext.tsx new file mode 100644 index 000000000..94628f4a5 --- /dev/null +++ b/packages/web/src/features/chat/mcpServerIconContext.tsx @@ -0,0 +1,10 @@ +'use client'; + +import { createContext, useContext } from 'react'; + +// Maps sanitized server name (e.g. "linear") to a favicon URL. +export type McpServerIconMap = Record; + +export const McpServerIconContext = createContext({}); + +export const useMcpServerIconMap = () => useContext(McpServerIconContext); diff --git a/packages/web/src/features/chat/toolApprovalContext.tsx b/packages/web/src/features/chat/toolApprovalContext.tsx new file mode 100644 index 000000000..d4379c394 --- /dev/null +++ b/packages/web/src/features/chat/toolApprovalContext.tsx @@ -0,0 +1,9 @@ +'use client'; + +import { createContext, useContext } from 'react'; +import type { ChatAddToolApproveResponseFunction } from 'ai'; + +const ToolApprovalContext = createContext(null); + +export const ToolApprovalProvider = ToolApprovalContext.Provider; +export const useToolApproval = () => useContext(ToolApprovalContext); \ No newline at end of file diff --git a/packages/web/src/features/chat/types.test.ts b/packages/web/src/features/chat/types.test.ts new file mode 100644 index 000000000..a9f41df7c --- /dev/null +++ b/packages/web/src/features/chat/types.test.ts @@ -0,0 +1,72 @@ +import { expect, test, describe } from 'vitest'; +import { sbChatMessageMetadataSchema, additionalChatRequestParamsSchema } from './types'; + +describe('sbChatMessageMetadataSchema', () => { + test('accepts disabledMcpServerIds as array of strings', () => { + const result = sbChatMessageMetadataSchema.safeParse({ + disabledMcpServerIds: ['id1', 'id2'], + }); + + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.disabledMcpServerIds).toEqual(['id1', 'id2']); + } + }); + + test('accepts missing disabledMcpServerIds (optional)', () => { + const result = sbChatMessageMetadataSchema.safeParse({}); + + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.disabledMcpServerIds).toBeUndefined(); + } + }); + + test('rejects non-string array values', () => { + const result = sbChatMessageMetadataSchema.safeParse({ + disabledMcpServerIds: [123, 456], + }); + + expect(result.success).toBe(false); + }); +}); + +describe('additionalChatRequestParamsSchema', () => { + const validBase = { + languageModel: { + provider: 'anthropic', + model: 'claude-sonnet-4-20250514', + }, + selectedSearchScopes: [], + }; + + test('defaults disabledMcpServerIds to empty array', () => { + const result = additionalChatRequestParamsSchema.safeParse(validBase); + + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.disabledMcpServerIds).toEqual([]); + } + }); + + test('accepts explicit disabledMcpServerIds array', () => { + const result = additionalChatRequestParamsSchema.safeParse({ + ...validBase, + disabledMcpServerIds: ['abc'], + }); + + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.disabledMcpServerIds).toEqual(['abc']); + } + }); + + test('rejects non-array value for disabledMcpServerIds', () => { + const result = additionalChatRequestParamsSchema.safeParse({ + ...validBase, + disabledMcpServerIds: 'not-an-array', + }); + + expect(result.success).toBe(false); + }); +}); diff --git a/packages/web/src/features/chat/types.ts b/packages/web/src/features/chat/types.ts index 6e990f5c2..3c2619f14 100644 --- a/packages/web/src/features/chat/types.ts +++ b/packages/web/src/features/chat/types.ts @@ -60,6 +60,7 @@ export const sbChatMessageMetadataSchema = z.object({ userId: z.string().optional(), })).optional(), selectedSearchScopes: z.array(searchScopeSchema).optional(), + disabledMcpServerIds: z.array(z.string()).optional(), traceId: z.string().optional(), }); @@ -67,12 +68,22 @@ export type SBChatMessageMetadata = z.infer; export type SBChatMessageToolTypes = { [K in keyof ReturnType]: InferUITool[K]>; +} & { + tool_request_activation: { + input: { tool_to_activate_name: string }; + output: { results: Array<{ name: string; description: string }> }; + }; }; export type SBChatMessageDataParts = { // The `source` data type allows us to know what sources the LLM saw // during retrieval. "source": Source, + // The `mcp-server` data type carries favicon metadata for connected MCP servers, + // keyed by sanitized server name (e.g. "linear"). + "mcp-server": { sanitizedName: string; faviconUrl: string }, + // The `mcp-failed-server` data type surfaces MCP servers that failed to load their tools. + "mcp-failed-server": { serverName: string }, } export type SBChatMessage = UIMessage< @@ -143,6 +154,7 @@ declare module 'slate' { export type SetChatStatePayload = { inputMessage: CreateUIMessage; selectedSearchScopes: SearchScope[]; + disabledMcpServerIds: string[]; } @@ -188,5 +200,6 @@ export type LanguageModelInfo = { export const additionalChatRequestParamsSchema = z.object({ languageModel: languageModelInfoSchema, selectedSearchScopes: z.array(searchScopeSchema), + disabledMcpServerIds: z.array(z.string()).default([]), }); -export type AdditionalChatRequestParams = z.infer; \ No newline at end of file +export type AdditionalChatRequestParams = z.infer; diff --git a/packages/web/src/features/chat/useCreateNewChatThread.ts b/packages/web/src/features/chat/useCreateNewChatThread.ts index 18a5a58b9..f030f186d 100644 --- a/packages/web/src/features/chat/useCreateNewChatThread.ts +++ b/packages/web/src/features/chat/useCreateNewChatThread.ts @@ -10,7 +10,7 @@ import { createChat } from "./actions"; import { isServiceError } from "@/lib/utils"; import { createPathWithQueryParams } from "@/lib/utils"; import { SearchScope, SetChatStatePayload } from "./types"; -import { SELECTED_SEARCH_SCOPES_LOCAL_STORAGE_KEY, SET_CHAT_STATE_SESSION_STORAGE_KEY } from "./constants"; +import { DISABLED_MCP_SERVER_IDS_LOCAL_STORAGE_KEY, SELECTED_SEARCH_SCOPES_LOCAL_STORAGE_KEY, SET_CHAT_STATE_SESSION_STORAGE_KEY } from "./constants"; import { useSessionStorage } from "usehooks-ts"; export const useCreateNewChatThread = () => { @@ -19,19 +19,29 @@ export const useCreateNewChatThread = () => { const router = useRouter(); const [, setChatState] = useSessionStorage(SET_CHAT_STATE_SESSION_STORAGE_KEY, null); - const createNewChatThread = useCallback(async (children: Descendant[], overrideSearchScopes?: SearchScope[]) => { + const createNewChatThread = useCallback(async (children: Descendant[], overrideSearchScopes?: SearchScope[], overrideDisabledMcpServerIds?: string[]) => { const text = slateContentToString(children); const mentions = getAllMentionElements(children); let storedScopes: SearchScope[] = []; try { const stored = window.localStorage.getItem(SELECTED_SEARCH_SCOPES_LOCAL_STORAGE_KEY); - if (stored) storedScopes = JSON.parse(stored) as SearchScope[]; + if (stored) { + storedScopes = JSON.parse(stored) as SearchScope[]; + } } catch { /* fall through to [] */ } - const selectedSearchScopes = overrideSearchScopes ?? storedScopes; + let storedDisabledMcpServerIds: string[] = []; + try { + const stored = window.localStorage.getItem(DISABLED_MCP_SERVER_IDS_LOCAL_STORAGE_KEY); + if (stored) { + storedDisabledMcpServerIds = JSON.parse(stored) as string[]; + } + } catch { /* fall through to [] */ } - const inputMessage = createUIMessage(text, mentions.map((mention) => mention.data), selectedSearchScopes); + const selectedSearchScopes = overrideSearchScopes ?? storedScopes; + const disabledMcpServerIds = overrideDisabledMcpServerIds ?? storedDisabledMcpServerIds; + const inputMessage = createUIMessage(text, mentions.map((mention) => mention.data), selectedSearchScopes, disabledMcpServerIds); setIsLoading(true); const response = await createChat({ source: 'sourcebot-web-client' }); @@ -46,6 +56,7 @@ export const useCreateNewChatThread = () => { setChatState({ inputMessage, selectedSearchScopes, + disabledMcpServerIds, }); const url = createPathWithQueryParams(`/chat/${response.id}`); diff --git a/packages/web/src/features/chat/utils.test.ts b/packages/web/src/features/chat/utils.test.ts index 26359d2a9..8f0a77b82 100644 --- a/packages/web/src/features/chat/utils.test.ts +++ b/packages/web/src/features/chat/utils.test.ts @@ -1,5 +1,5 @@ -import { expect, test, vi } from 'vitest' -import { fileReferenceToString, getAnswerPartFromAssistantMessage, groupMessageIntoSteps, repairReferences } from './utils' +import { expect, test, describe, vi } from 'vitest' +import { createUIMessage, fileReferenceToString, getAnswerPartFromAssistantMessage, getLastStepParts, getTurnProgressState, groupMessageIntoSteps, repairReferences } from './utils' import { FILE_REFERENCE_REGEX, ANSWER_TAG } from './constants'; import { SBChatMessage, SBChatMessagePart } from './types'; @@ -10,6 +10,95 @@ vi.mock('@sourcebot/shared', () => ({ } })); +const createAssistantMessage = (parts: SBChatMessagePart[]): SBChatMessage => ({ + id: 'assistant-message', + role: 'assistant', + parts, +}); + +const createUserMessage = (): SBChatMessage => ({ + id: 'user-message', + role: 'user', + parts: [ + { + type: 'text', + text: 'Hello', + }, + ], +}); + +const dynamicApprovalRequestedPart = { + type: 'dynamic-tool', + toolName: 'mcp_linear__save_issue', + toolCallId: 'tool-call-1', + state: 'approval-requested', + input: { title: 'Issue' }, + approval: { id: 'approval-1' }, +} satisfies SBChatMessagePart; + +const dynamicApprovalRespondedPart = { + type: 'dynamic-tool', + toolName: 'mcp_linear__save_issue', + toolCallId: 'tool-call-1', + state: 'approval-responded', + input: { title: 'Issue' }, + approval: { id: 'approval-1', approved: true }, +} satisfies SBChatMessagePart; + +const listReposInput = { + sort: 'name', + page: 1, + perPage: 30, + direction: 'asc', +} as const; + +const listReposOutput = { + output: 'Done', + metadata: { + repos: [], + totalCount: 0, + }, +}; + +const staticApprovalRequestedPart = { + type: 'tool-list_repos', + toolCallId: 'tool-call-2', + state: 'approval-requested', + input: listReposInput, + approval: { id: 'approval-2' }, +} satisfies SBChatMessagePart; + +const staticApprovalRespondedPart = { + type: 'tool-list_repos', + toolCallId: 'tool-call-2', + state: 'approval-responded', + input: listReposInput, + approval: { id: 'approval-2', approved: true }, +} satisfies SBChatMessagePart; + +const outputAvailablePart = { + type: 'tool-list_repos', + toolCallId: 'tool-call-3', + state: 'output-available', + input: listReposInput, + output: listReposOutput, +} satisfies SBChatMessagePart; + +const outputErrorPart = { + type: 'tool-list_repos', + toolCallId: 'tool-call-5', + state: 'output-error', + input: listReposInput, + errorText: 'Tool failed', +} satisfies SBChatMessagePart; + +const inputAvailablePart = { + type: 'tool-list_repos', + toolCallId: 'tool-call-4', + state: 'input-available', + input: listReposInput, +} satisfies SBChatMessagePart; + test('fileReferenceToString formats file references correctly', () => { expect(fileReferenceToString({ @@ -148,7 +237,212 @@ test('groupMessageIntoSteps returns a single group when there is no step-start p ]); }); -test('getAnswerPartFromAssistantMessage returns text part when it starts with ANSWER_TAG while not streaming', () => { +test('getLastStepParts returns the last grouped step', () => { + const parts: SBChatMessagePart[] = [ + { + type: 'step-start', + }, + { + type: 'text', + text: 'First step', + }, + { + type: 'step-start', + }, + { + type: 'text', + text: 'Last step', + }, + ]; + + const lastStep = getLastStepParts(parts); + + expect(lastStep).toEqual([ + { + type: 'step-start', + }, + { + type: 'text', + text: 'Last step', + }, + ]); +}); + +test('getTurnProgressState treats submitted and streaming as in progress and navigation guarded', () => { + expect(getTurnProgressState({ messages: [createUserMessage()], status: 'submitted' })).toMatchObject({ + isNetworkActive: true, + isTurnInProgress: true, + shouldGuardNavigation: true, + }); + expect(getTurnProgressState({ messages: [createUserMessage()], status: 'streaming' })).toMatchObject({ + isNetworkActive: true, + isTurnInProgress: true, + shouldGuardNavigation: true, + }); +}); + +test('getTurnProgressState returns idle for no messages and latest user message when ready', () => { + expect(getTurnProgressState({ messages: [], status: 'ready' })).toMatchObject({ + isNetworkActive: false, + isTurnInProgress: false, + shouldGuardNavigation: false, + }); + expect(getTurnProgressState({ messages: [createUserMessage()], status: 'ready' })).toMatchObject({ + isNetworkActive: false, + isTurnInProgress: false, + shouldGuardNavigation: false, + }); +}); + +test('getTurnProgressState treats latest-step approval-requested as awaiting approval but not navigation guarded', () => { + const message = createAssistantMessage([ + { + type: 'step-start', + }, + dynamicApprovalRequestedPart, + ]); + + expect(getTurnProgressState({ messages: [createUserMessage(), message], status: 'ready' })).toMatchObject({ + hasPendingToolApproval: true, + isAwaitingToolApproval: true, + isTurnInProgress: true, + shouldGuardNavigation: false, + }); +}); + +test('getTurnProgressState treats approval continuation readiness as in progress and navigation guarded', () => { + const message = createAssistantMessage([ + { + type: 'step-start', + }, + dynamicApprovalRespondedPart, + outputAvailablePart, + ]); + + expect(getTurnProgressState({ messages: [createUserMessage(), message], status: 'ready' })).toMatchObject({ + hasApprovalContinuationReady: true, + isAwaitingToolApproval: false, + isTurnInProgress: true, + shouldGuardNavigation: true, + }); +}); + +test('getTurnProgressState treats approval-responded and output-error as continuation-ready', () => { + const message = createAssistantMessage([ + { + type: 'step-start', + }, + dynamicApprovalRespondedPart, + outputErrorPart, + ]); + + expect(getTurnProgressState({ messages: [createUserMessage(), message], status: 'ready' })).toMatchObject({ + hasApprovalContinuationReady: true, + isTurnInProgress: true, + shouldGuardNavigation: true, + }); +}); + +test('getTurnProgressState does not treat a responded approval with non-terminal tools as continuation-ready', () => { + const message = createAssistantMessage([ + { + type: 'step-start', + }, + dynamicApprovalRespondedPart, + inputAvailablePart, + ]); + + expect(getTurnProgressState({ messages: [createUserMessage(), message], status: 'ready' })).toMatchObject({ + hasApprovalContinuationReady: false, + isTurnInProgress: false, + shouldGuardNavigation: false, + }); +}); + +test('getTurnProgressState does not keep terminal tool states in progress', () => { + const message = createAssistantMessage([ + { + type: 'step-start', + }, + outputAvailablePart, + ]); + + expect(getTurnProgressState({ messages: [createUserMessage(), message], status: 'ready' })).toMatchObject({ + hasPendingToolApproval: false, + hasApprovalContinuationReady: false, + isTurnInProgress: false, + }); +}); + +test('getTurnProgressState treats error as not in progress even with pending approval', () => { + const message = createAssistantMessage([ + { + type: 'step-start', + }, + dynamicApprovalRequestedPart, + ]); + + expect(getTurnProgressState({ messages: [createUserMessage(), message], status: 'error' })).toMatchObject({ + hasPendingToolApproval: true, + isTurnInProgress: false, + shouldGuardNavigation: false, + }); +}); + +test('getTurnProgressState ignores approvals in older messages and older steps', () => { + const olderMessage = createAssistantMessage([ + { + type: 'step-start', + }, + dynamicApprovalRequestedPart, + ]); + const latestMessage = createAssistantMessage([ + { + type: 'step-start', + }, + dynamicApprovalRequestedPart, + { + type: 'step-start', + }, + { + type: 'text', + text: 'Later step', + }, + ]); + + expect(getTurnProgressState({ messages: [createUserMessage(), olderMessage, createUserMessage()], status: 'ready' })).toMatchObject({ + isTurnInProgress: false, + }); + expect(getTurnProgressState({ messages: [createUserMessage(), latestMessage], status: 'ready' })).toMatchObject({ + isTurnInProgress: false, + }); +}); + +test('getTurnProgressState classifies dynamic and static tool approvals', () => { + expect(getTurnProgressState({ + messages: [createAssistantMessage([dynamicApprovalRequestedPart])], + status: 'ready', + })).toMatchObject({ + hasPendingToolApproval: true, + isAwaitingToolApproval: true, + }); + expect(getTurnProgressState({ + messages: [createAssistantMessage([staticApprovalRequestedPart])], + status: 'ready', + })).toMatchObject({ + hasPendingToolApproval: true, + isAwaitingToolApproval: true, + }); + expect(getTurnProgressState({ + messages: [createAssistantMessage([staticApprovalRespondedPart])], + status: 'ready', + })).toMatchObject({ + hasApprovalContinuationReady: true, + shouldGuardNavigation: true, + }); +}); + +test('getAnswerPartFromAssistantMessage returns text part when it starts with ANSWER_TAG while turn is complete', () => { const message: SBChatMessage = { role: 'assistant', parts: [ @@ -171,7 +465,7 @@ test('getAnswerPartFromAssistantMessage returns text part when it starts with AN }); }); -test('getAnswerPartFromAssistantMessage returns text part when it starts with ANSWER_TAG while streaming', () => { +test('getAnswerPartFromAssistantMessage returns text part when it starts with ANSWER_TAG while turn is in progress', () => { const message: SBChatMessage = { role: 'assistant', parts: [ @@ -194,7 +488,7 @@ test('getAnswerPartFromAssistantMessage returns text part when it starts with AN }); }); -test('getAnswerPartFromAssistantMessage returns last text part as fallback when not streaming and no ANSWER_TAG', () => { +test('getAnswerPartFromAssistantMessage returns last text part as fallback when turn is complete and no ANSWER_TAG', () => { const message: SBChatMessage = { role: 'assistant', parts: [ @@ -223,7 +517,7 @@ test('getAnswerPartFromAssistantMessage returns last text part as fallback when }); }); -test('getAnswerPartFromAssistantMessage returns undefined when streaming and no ANSWER_TAG', () => { +test('getAnswerPartFromAssistantMessage returns undefined when turn is in progress and no ANSWER_TAG', () => { const message: SBChatMessage = { role: 'assistant', parts: [ @@ -351,3 +645,31 @@ test('repairReferences handles malformed inline code blocks', () => { const expected = 'See @file:{github.com/sourcebot-dev/sourcebot::packages/web/src/auth.ts} for details.'; expect(repairReferences(input)).toBe(expected); }); + +describe('createUIMessage', () => { + test('includes disabledMcpServerIds in metadata when provided', () => { + const result = createUIMessage('hello', [], [], ['server1', 'server2']); + + expect(result.metadata?.disabledMcpServerIds).toEqual(['server1', 'server2']); + }); + + test('defaults disabledMcpServerIds to empty array when omitted', () => { + const result = createUIMessage('hello', [], []); + + expect(result.metadata?.disabledMcpServerIds).toEqual([]); + }); + + test('passes through empty array', () => { + const result = createUIMessage('hello', [], [], []); + + expect(result.metadata?.disabledMcpServerIds).toEqual([]); + }); + + test('includes both selectedSearchScopes and disabledMcpServerIds in metadata', () => { + const scopes = [{ type: 'repo' as const, value: 'org/repo', name: 'repo', codeHostType: 'github' }]; + const result = createUIMessage('hello', [], scopes, ['disabled1']); + + expect(result.metadata?.selectedSearchScopes).toEqual(scopes); + expect(result.metadata?.disabledMcpServerIds).toEqual(['disabled1']); + }); +}); diff --git a/packages/web/src/features/chat/utils.ts b/packages/web/src/features/chat/utils.ts index 38dd784fd..c7f409ac7 100644 --- a/packages/web/src/features/chat/utils.ts +++ b/packages/web/src/features/chat/utils.ts @@ -1,5 +1,6 @@ import { BrowseHighlightRange, getBrowsePath } from "@/app/(app)/browse/hooks/utils"; -import { CreateUIMessage, TextUIPart, UIMessagePart } from "ai"; +import { CreateUIMessage, isToolUIPart, TextUIPart, UIMessagePart } from "ai"; +import type { ChatStatus, DynamicToolUIPart, ToolUIPart } from "ai"; import { Descendant, Editor, Point, Range, Transforms } from "slate"; import { ANSWER_TAG, FILE_REFERENCE_PREFIX, FILE_REFERENCE_REGEX } from "./constants"; import { @@ -18,6 +19,11 @@ import { Source, } from "./types"; +export type SBChatToolPart = (ToolUIPart | DynamicToolUIPart) & SBChatMessagePart; + +export const isSBChatToolPart = (part: SBChatMessagePart): part is SBChatToolPart => { + return isToolUIPart(part); +}; export const insertMention = (editor: CustomEditor, data: MentionData, target?: Range | null) => { const mention: MentionElement = { @@ -161,11 +167,16 @@ export const getAllMentionElements = (children: Descendant[]): MentionElement[] }); } +export const clearEditorHistory = (editor: CustomEditor) => { + // slate-history exposes `history` publicly, but does not provide a clear API. + editor.history = { redos: [], undos: [] }; +} + // @see: https://stackoverflow.com/a/74102147 export const resetEditor = (editor: CustomEditor) => { const point = { path: [0, 0], offset: 0 } editor.selection = { anchor: point, focus: point }; - editor.history = { redos: [], undos: [] }; + clearEditorHistory(editor); editor.children = [{ type: "paragraph", children: [{ text: "" }] @@ -176,7 +187,7 @@ export const addLineNumbers = (source: string, lineOffset = 1) => { return source.split('\n').map((line, index) => `${index + lineOffset}: ${line}`).join('\n'); } -export const createUIMessage = (text: string, mentions: MentionData[], selectedSearchScopes: SearchScope[]): CreateUIMessage => { +export const createUIMessage = (text: string, mentions: MentionData[], selectedSearchScopes: SearchScope[], disabledMcpServerIds: string[] = []): CreateUIMessage => { // Converts applicable mentions into sources. const sources: Source[] = mentions .map((mention) => { @@ -209,6 +220,7 @@ export const createUIMessage = (text: string, mentions: MentionData[], selectedS ], metadata: { selectedSearchScopes, + disabledMcpServerIds, }, } } @@ -319,6 +331,51 @@ export const groupMessageIntoSteps = (parts: SBChatMessagePart[]) => { return steps; } +export const getLastStepParts = (parts: SBChatMessagePart[]): SBChatMessagePart[] => { + return groupMessageIntoSteps(parts).at(-1) ?? parts; +} + +export const getTurnProgressState = ({ + messages, + status, +}: { + messages: SBChatMessage[]; + status: ChatStatus; +}) => { + const isNetworkActive = status === 'submitted' || status === 'streaming'; + const latestMessage = messages.at(-1); + const latestAssistantMessage = latestMessage?.role === 'assistant' ? latestMessage : undefined; + const latestStepToolParts = getLastStepParts(latestAssistantMessage?.parts ?? []) + .filter(isSBChatToolPart); + + const hasPendingToolApproval = latestStepToolParts.some( + (part) => part.state === 'approval-requested' + ); + const hasApprovalContinuationReady = + latestStepToolParts.some((part) => part.state === 'approval-responded') && + latestStepToolParts.every((part) => + part.state === 'output-available' || + part.state === 'output-error' || + part.state === 'approval-responded' + ); + + const isReady = status === 'ready'; + const isTurnInProgress = + isNetworkActive || + (isReady && (hasPendingToolApproval || hasApprovalContinuationReady)); + const isAwaitingToolApproval = isReady && hasPendingToolApproval; + const shouldGuardNavigation = isNetworkActive || (isReady && hasApprovalContinuationReady); + + return { + isNetworkActive, + hasPendingToolApproval, + hasApprovalContinuationReady, + isAwaitingToolApproval, + isTurnInProgress, + shouldGuardNavigation, + }; +} + // LLMs like to not follow instructions... this takes care of some common mistakes they tend to make. export const repairReferences = (text: string): string => { return text @@ -342,7 +399,7 @@ export const repairReferences = (text: string): string => { // Attempts to find the part of the assistant's message // that contains the answer. -export const getAnswerPartFromAssistantMessage = (message: SBChatMessage, isStreaming: boolean): TextUIPart | undefined => { +export const getAnswerPartFromAssistantMessage = (message: SBChatMessage, isTurnInProgress: boolean): TextUIPart | undefined => { const lastTextPart = message.parts .findLast((part) => part.type === 'text') @@ -356,8 +413,8 @@ export const getAnswerPartFromAssistantMessage = (message: SBChatMessage, isStre } // If the agent did not include the answer tag, then fallback to using the last text part. - // Only do this when we are no longer streaming since the agent may still be thinking. - if (!isStreaming && lastTextPart) { + // Only do this when the turn is complete since the agent may still be thinking or waiting. + if (!isTurnInProgress && lastTextPart) { return lastTextPart; } diff --git a/packages/web/src/features/mcp/askCodebase.ts b/packages/web/src/features/mcp/askCodebase.ts index bc3a030c2..94bf4a3f1 100644 --- a/packages/web/src/features/mcp/askCodebase.ts +++ b/packages/web/src/features/mcp/askCodebase.ts @@ -143,6 +143,10 @@ export const askCodebase = (params: AskCodebaseParams): Promise r.value) } : {}), @@ -155,6 +159,7 @@ export const askCodebase = (params: AskCodebaseParams): Promise r.value), + prisma, model, modelName, modelProviderOptions: providerOptions, diff --git a/packages/web/src/features/mcp/mcpOAuthReturnTo.test.ts b/packages/web/src/features/mcp/mcpOAuthReturnTo.test.ts new file mode 100644 index 000000000..8b3b8a0fe --- /dev/null +++ b/packages/web/src/features/mcp/mcpOAuthReturnTo.test.ts @@ -0,0 +1,36 @@ +import { describe, expect, test } from 'vitest'; +import { + createMcpOAuthState, + getMcpOAuthReturnToFromState, + normalizeMcpOAuthReturnTo, +} from './mcpOAuthReturnTo'; + +describe('MCP OAuth return paths', () => { + test('allows chat return paths', () => { + expect(normalizeMcpOAuthReturnTo('/chat')).toBe('/chat'); + expect(normalizeMcpOAuthReturnTo('/chat/thread-1?foo=bar')).toBe('/chat/thread-1?foo=bar'); + }); + + test('allows connector settings return paths', () => { + expect(normalizeMcpOAuthReturnTo('/settings/accountAskAgent?status=connected')).toBe('/settings/accountAskAgent?status=connected'); + }); + + test('rejects external and unrelated return paths', () => { + expect(normalizeMcpOAuthReturnTo('https://evil.example.com/chat')).toBeUndefined(); + expect(normalizeMcpOAuthReturnTo('//evil.example.com/chat')).toBeUndefined(); + expect(normalizeMcpOAuthReturnTo('/settings')).toBeUndefined(); + }); + + test('encodes and decodes return paths inside OAuth state', () => { + const state = createMcpOAuthState('nonce-1', '/chat'); + + expect(state).not.toBe('nonce-1'); + expect(getMcpOAuthReturnToFromState(state)).toBe('/chat'); + }); + + test('leaves state unchanged when no valid return path exists', () => { + expect(createMcpOAuthState('nonce-1')).toBe('nonce-1'); + expect(createMcpOAuthState('nonce-1', '/settings')).toBe('nonce-1'); + expect(getMcpOAuthReturnToFromState('nonce-1')).toBeUndefined(); + }); +}); diff --git a/packages/web/src/features/mcp/mcpOAuthReturnTo.ts b/packages/web/src/features/mcp/mcpOAuthReturnTo.ts new file mode 100644 index 000000000..8127abdbc --- /dev/null +++ b/packages/web/src/features/mcp/mcpOAuthReturnTo.ts @@ -0,0 +1,63 @@ +const MCP_OAUTH_STATE_PREFIX = 'sourcebot_mcp.'; +const MCP_OAUTH_STATE_BASE_URL = 'https://sourcebot.invalid'; + +function isAllowedMcpOAuthReturnPath(pathname: string): boolean { + return pathname === '/chat' || pathname.startsWith('/chat/') || pathname === '/settings/accountAskAgent'; +} + +export function normalizeMcpOAuthReturnTo(returnTo: unknown): string | undefined { + if (typeof returnTo !== 'string') { + return undefined; + } + + const trimmedReturnTo = returnTo.trim(); + if (!trimmedReturnTo || !trimmedReturnTo.startsWith('/') || trimmedReturnTo.startsWith('//') || trimmedReturnTo.includes('\\')) { + return undefined; + } + + try { + const url = new URL(trimmedReturnTo, MCP_OAUTH_STATE_BASE_URL); + if (url.origin !== MCP_OAUTH_STATE_BASE_URL || !isAllowedMcpOAuthReturnPath(url.pathname)) { + return undefined; + } + + return `${url.pathname}${url.search}`; + } catch { + return undefined; + } +} + +export function createMcpOAuthState(nonce: string, returnTo?: string): string { + const normalizedReturnTo = normalizeMcpOAuthReturnTo(returnTo); + if (!normalizedReturnTo) { + return nonce; + } + + const encoded = Buffer.from(JSON.stringify({ + nonce, + returnTo: normalizedReturnTo, + })).toString('base64url'); + return `${MCP_OAUTH_STATE_PREFIX}${encoded}`; +} + +export function getMcpOAuthReturnToFromState(state: string | null | undefined): string | undefined { + if (!state?.startsWith(MCP_OAUTH_STATE_PREFIX)) { + return undefined; + } + + try { + const encoded = state.slice(MCP_OAUTH_STATE_PREFIX.length); + const payload = JSON.parse(Buffer.from(encoded, 'base64url').toString('utf8')) as unknown; + if ( + typeof payload === 'object' && + payload !== null && + 'returnTo' in payload + ) { + return normalizeMcpOAuthReturnTo(payload.returnTo); + } + } catch { + return undefined; + } + + return undefined; +} diff --git a/packages/web/src/features/mcp/prismaOAuthClientProvider.test.ts b/packages/web/src/features/mcp/prismaOAuthClientProvider.test.ts new file mode 100644 index 000000000..998181c7b --- /dev/null +++ b/packages/web/src/features/mcp/prismaOAuthClientProvider.test.ts @@ -0,0 +1,268 @@ +import { describe, expect, test, vi, beforeEach } from 'vitest'; +import { McpServerClientInfoSource } from '@sourcebot/db'; + +const mocks = vi.hoisted(() => ({ + logger: { + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + debug: vi.fn(), + }, +})); + +vi.mock('server-only', () => ({})); +vi.mock('@/prisma', () => ({ + __unsafePrisma: { + mcpServer: {}, + userMcpServer: {}, + }, +})); +vi.mock('@sourcebot/shared', () => ({ + encryptOAuthToken: vi.fn((text: string | null | undefined) => text ? `encrypted:${text}` : undefined), + decryptOAuthToken: vi.fn((text: string | null | undefined) => text?.startsWith('encrypted:') ? text.slice('encrypted:'.length) : text), + createLogger: () => mocks.logger, +})); + +const { + PrismaOAuthClientProvider, + clearMcpServerClientCredentialsForObservedClient, +} = await import('./prismaOAuthClientProvider'); + +function createPrismaMock() { + return { + mcpServer: { + findFirst: vi.fn(), + updateMany: vi.fn(), + }, + userMcpServer: { + findUnique: vi.fn(), + update: vi.fn(), + updateMany: vi.fn(), + }, + }; +} + +function createProvider(prisma = createPrismaMock(), allowClientRegistration = false) { + return new PrismaOAuthClientProvider({ + prisma: prisma as never, + clientInvalidationPrisma: prisma as never, + serverId: 'server-1', + orgId: 1, + userId: 'user-1', + callbackUrl: 'https://sourcebot.example.com/api/ee/askmcp/callback', + allowClientRegistration, + }); +} + +beforeEach(() => { + vi.clearAllMocks(); +}); + +describe('PrismaOAuthClientProvider modes', () => { + test('connect-mode provider exposes saveClientInformation', () => { + const provider = createProvider(createPrismaMock(), true); + + expect('saveClientInformation' in provider).toBe(true); + expect(provider.saveClientInformation).toEqual(expect.any(Function)); + }); + + test('runtime and callback providers omit saveClientInformation', () => { + const provider = createProvider(); + + expect('saveClientInformation' in provider).toBe(false); + expect(provider.saveClientInformation).toBeUndefined(); + }); +}); + +describe('clearMcpServerClientCredentialsForObservedClient', () => { + test('matching observed clientInfo clears org clientInfo and all server tokens', async () => { + const prisma = createPrismaMock(); + prisma.mcpServer.updateMany.mockResolvedValue({ count: 1 }); + prisma.userMcpServer.updateMany.mockResolvedValue({ count: 2 }); + + const didClear = await clearMcpServerClientCredentialsForObservedClient({ + prisma: prisma as never, + serverId: 'server-1', + orgId: 1, + observedClientInfo: 'encrypted-client-info', + }); + + expect(didClear).toBe(true); + expect(prisma.mcpServer.updateMany).toHaveBeenCalledWith({ + where: { + id: 'server-1', + orgId: 1, + clientInfo: 'encrypted-client-info', + clientInfoSource: McpServerClientInfoSource.DYNAMIC, + }, + data: { clientInfo: null }, + }); + expect(prisma.userMcpServer.updateMany).toHaveBeenCalledWith({ + where: { + serverId: 'server-1', + server: { orgId: 1 }, + }, + data: { + tokens: null, + tokensExpiresAt: null, + }, + }); + }); + + test('stale observed clientInfo clears neither org clientInfo nor tokens', async () => { + const prisma = createPrismaMock(); + prisma.mcpServer.updateMany.mockResolvedValue({ count: 0 }); + + const didClear = await clearMcpServerClientCredentialsForObservedClient({ + prisma: prisma as never, + serverId: 'server-1', + orgId: 1, + observedClientInfo: 'stale-client-info', + }); + + expect(didClear).toBe(false); + expect(prisma.mcpServer.updateMany).toHaveBeenCalledOnce(); + expect(prisma.userMcpServer.updateMany).not.toHaveBeenCalled(); + }); +}); + +describe('PrismaOAuthClientProvider PKCE verifier storage', () => { + test('saveCodeVerifier encrypts the verifier before persisting it', async () => { + const prisma = createPrismaMock(); + prisma.userMcpServer.update.mockResolvedValue({ + userId: 'user-1', + serverId: 'server-1', + }); + const provider = createProvider(prisma); + + await provider.saveCodeVerifier('verifier-secret'); + + expect(prisma.userMcpServer.update).toHaveBeenCalledWith({ + where: { + userId_serverId: { userId: 'user-1', serverId: 'server-1' }, + }, + data: { + codeVerifier: 'encrypted:verifier-secret', + }, + }); + }); + + test('codeVerifier decrypts the stored verifier', async () => { + const prisma = createPrismaMock(); + prisma.userMcpServer.findUnique.mockResolvedValue({ + codeVerifier: 'encrypted:verifier-secret', + tokens: null, + state: null, + }); + const provider = createProvider(prisma); + + await expect(provider.codeVerifier()).resolves.toBe('verifier-secret'); + expect(mocks.logger.warn).not.toHaveBeenCalled(); + }); + + test('codeVerifier still accepts plaintext verifier values during migration and logs the fallback', async () => { + const prisma = createPrismaMock(); + prisma.userMcpServer.findUnique.mockResolvedValue({ + codeVerifier: 'plaintext-verifier', + tokens: null, + state: null, + }); + const provider = createProvider(prisma); + + await expect(provider.codeVerifier()).resolves.toBe('plaintext-verifier'); + expect(mocks.logger.warn).toHaveBeenCalledWith( + 'MCP OAuth code verifier was read without decryption.', + { + serverId: 'server-1', + orgId: 1, + userId: 'user-1', + }, + ); + }); +}); + +describe('PrismaOAuthClientProvider authorization redirect', () => { + test('overwrites existing prompt values with consent', async () => { + const prisma = createPrismaMock(); + prisma.userMcpServer.update.mockResolvedValue({ + userId: 'user-1', + serverId: 'server-1', + }); + const provider = createProvider(prisma); + + await provider.redirectToAuthorization(new URL('https://oauth.example.com/authorize?prompt=none&client_id=client-1')); + + expect(provider.authorizationUrl).toBe('https://oauth.example.com/authorize?prompt=consent&client_id=client-1'); + expect(prisma.userMcpServer.update).toHaveBeenCalledWith({ + where: { + userId_serverId: { userId: 'user-1', serverId: 'server-1' }, + }, + data: { + tokens: null, + tokensExpiresAt: null, + }, + }); + }); +}); + +describe('PrismaOAuthClientProvider static client information', () => { + test('clientInformation returns static OAuth client credentials', async () => { + const prisma = createPrismaMock(); + prisma.mcpServer.findFirst.mockResolvedValue({ + clientInfo: 'encrypted:{"client_id":"client-id","client_secret":"client-secret"}', + clientInfoSource: McpServerClientInfoSource.STATIC, + }); + const provider = createProvider(prisma); + + await expect(provider.clientInformation()).resolves.toEqual({ + client_id: 'client-id', + client_secret: 'client-secret', + }); + }); + + test('invalidate all preserves static client information and clears only the current user tokens and verifier', async () => { + const prisma = createPrismaMock(); + prisma.mcpServer.findFirst.mockResolvedValue({ + clientInfo: 'encrypted:{"client_id":"client-id","client_secret":"client-secret"}', + clientInfoSource: McpServerClientInfoSource.STATIC, + }); + prisma.mcpServer.updateMany.mockResolvedValue({ count: 0 }); + prisma.userMcpServer.update.mockResolvedValue({ + userId: 'user-1', + serverId: 'server-1', + }); + const provider = createProvider(prisma); + + await provider.clientInformation(); + await provider.invalidateCredentials('all'); + + expect(prisma.mcpServer.updateMany).toHaveBeenCalledWith({ + where: { + id: 'server-1', + orgId: 1, + clientInfo: 'encrypted:{"client_id":"client-id","client_secret":"client-secret"}', + clientInfoSource: McpServerClientInfoSource.DYNAMIC, + }, + data: { clientInfo: null }, + }); + expect(prisma.userMcpServer.updateMany).not.toHaveBeenCalled(); + expect(prisma.userMcpServer.update).toHaveBeenCalledWith({ + where: { + userId_serverId: { userId: 'user-1', serverId: 'server-1' }, + }, + data: { + tokens: null, + tokensExpiresAt: null, + }, + }); + expect(prisma.userMcpServer.update).toHaveBeenCalledWith({ + where: { + userId_serverId: { userId: 'user-1', serverId: 'server-1' }, + }, + data: { + codeVerifier: null, + state: null, + }, + }); + }); +}); diff --git a/packages/web/src/features/mcp/prismaOAuthClientProvider.ts b/packages/web/src/features/mcp/prismaOAuthClientProvider.ts new file mode 100644 index 000000000..263a0c66b --- /dev/null +++ b/packages/web/src/features/mcp/prismaOAuthClientProvider.ts @@ -0,0 +1,305 @@ +import 'server-only'; +import type { + OAuthClientProvider, + OAuthClientInformation, + OAuthClientMetadata, + OAuthTokens, +} from '@ai-sdk/mcp'; +import { McpServerClientInfoSource, type PrismaClient } from '@sourcebot/db'; +import { encryptOAuthToken, decryptOAuthToken, createLogger } from '@sourcebot/shared'; +import { __unsafePrisma } from '@/prisma'; +import { createMcpOAuthState } from './mcpOAuthReturnTo'; + +type McpOAuthPrismaClient = Pick; +const logger = createLogger('mcp-oauth-client-provider'); + +interface PrismaOAuthClientProviderOptions { + prisma: McpOAuthPrismaClient; + serverId: string; + orgId: number; + userId: string; + callbackUrl: string; + callbackReturnTo?: string; + allowClientRegistration?: boolean; + clientInvalidationPrisma?: McpOAuthPrismaClient; +} + +export interface ClearMcpServerClientCredentialsOptions { + prisma?: McpOAuthPrismaClient; + serverId: string; + orgId: number; + observedClientInfo: string | undefined; +} + +export async function clearMcpServerClientCredentialsForObservedClient({ + prisma = __unsafePrisma, + serverId, + orgId, + observedClientInfo, +}: ClearMcpServerClientCredentialsOptions): Promise { + if (!observedClientInfo) { + return false; + } + + const result = await prisma.mcpServer.updateMany({ + where: { + id: serverId, + orgId, + clientInfo: observedClientInfo, + clientInfoSource: McpServerClientInfoSource.DYNAMIC, + }, + data: { clientInfo: null }, + }); + + if (result.count === 0) { + return false; + } + + await prisma.userMcpServer.updateMany({ + where: { + serverId, + server: { orgId }, + }, + data: { + tokens: null, + tokensExpiresAt: null, + }, + }); + + return true; +} + +/** + * Prisma-backed OAuthClientProvider for connecting to external MCP servers. + * + * Stores dynamic client registration on McpServer (per-org), and per-user + * tokens + ephemeral PKCE state on UserMcpServer. + */ +export class PrismaOAuthClientProvider implements OAuthClientProvider { + private readonly prisma: McpOAuthPrismaClient; + private readonly clientInvalidationPrisma: McpOAuthPrismaClient; + private readonly serverId: string; + private readonly orgId: number; + private readonly userId: string; + private readonly callbackUrl: string; + private readonly callbackReturnTo: string | undefined; + private observedClientInfo: string | undefined; + private observedClientInfoSource: McpServerClientInfoSource | undefined; + + /** Populated by redirectToAuthorization — read after auth() returns 'REDIRECT'. */ + public authorizationUrl: string | undefined; + + /** Only present in connect mode. If absent, the SDK cannot perform DCR. */ + declare saveClientInformation?: (info: OAuthClientInformation) => Promise; + + constructor({ + prisma, + serverId, + orgId, + userId, + callbackUrl, + callbackReturnTo, + allowClientRegistration = false, + clientInvalidationPrisma = __unsafePrisma, + }: PrismaOAuthClientProviderOptions) { + this.prisma = prisma; + this.clientInvalidationPrisma = clientInvalidationPrisma; + this.serverId = serverId; + this.orgId = orgId; + this.userId = userId; + this.callbackUrl = callbackUrl; + this.callbackReturnTo = callbackReturnTo; + + if (allowClientRegistration) { + this.saveClientInformation = async (info: OAuthClientInformation) => { + const encrypted = encryptOAuthToken(JSON.stringify(info)); + if (!encrypted) { + throw new Error('Failed to encrypt OAuth client information'); + } + + const result = await this.prisma.mcpServer.updateMany({ + where: { id: this.serverId, orgId: this.orgId }, + data: { + clientInfo: encrypted, + clientInfoSource: McpServerClientInfoSource.DYNAMIC, + }, + }); + if (result.count === 0) { + throw new Error('MCP server not found'); + } + + this.observedClientInfo = encrypted; + this.observedClientInfoSource = McpServerClientInfoSource.DYNAMIC; + }; + } + } + + get redirectUrl(): string | URL { + return this.callbackUrl; + } + + get clientMetadata(): OAuthClientMetadata { + return { + redirect_uris: [this.callbackUrl], + client_name: 'Sourcebot', + grant_types: ['authorization_code', 'refresh_token'], + response_types: ['code'], + token_endpoint_auth_method: 'none', + }; + } + + async clientInformation(): Promise { + const server = await this.prisma.mcpServer.findFirst({ + where: { id: this.serverId, orgId: this.orgId }, + select: { + clientInfo: true, + clientInfoSource: true, + }, + }); + if (!server?.clientInfo) { + this.observedClientInfo = undefined; + this.observedClientInfoSource = undefined; + return undefined; + } + + this.observedClientInfo = server.clientInfo; + this.observedClientInfoSource = server.clientInfoSource; + const decrypted = decryptOAuthToken(server.clientInfo); + return decrypted ? JSON.parse(decrypted) : undefined; + } + + async tokens(): Promise { + const userServer = await this.getUserServer(); + if (!userServer?.tokens) { + return undefined; + } + + const decrypted = decryptOAuthToken(userServer.tokens); + return decrypted ? JSON.parse(decrypted) : undefined; + } + + async saveTokens(tokens: OAuthTokens): Promise { + const encrypted = encryptOAuthToken(JSON.stringify(tokens)); + if (!encrypted) { + throw new Error('Failed to encrypt OAuth tokens'); + } + + const tokensExpiresAt = tokens.expires_in + ? new Date(Date.now() + tokens.expires_in * 1000) + : null; + await this.updateUserServer({ tokens: encrypted, tokensExpiresAt }); + } + + async codeVerifier(): Promise { + const userServer = await this.getUserServer(); + if (!userServer?.codeVerifier) { + throw new Error('No code verifier found'); + } + + const decrypted = decryptOAuthToken(userServer.codeVerifier); + if (!decrypted) { + throw new Error('Failed to decrypt OAuth code verifier'); + } + + if (decrypted === userServer.codeVerifier) { + logger.warn('MCP OAuth code verifier was read without decryption.', { + serverId: this.serverId, + orgId: this.orgId, + userId: this.userId, + }); + } + + return decrypted; + } + + async saveCodeVerifier(codeVerifier: string): Promise { + const encrypted = encryptOAuthToken(codeVerifier); + if (!encrypted) { + throw new Error('Failed to encrypt OAuth code verifier'); + } + + await this.updateUserServer({ codeVerifier: encrypted }); + } + + async state(): Promise { + return createMcpOAuthState(crypto.randomUUID(), this.callbackReturnTo); + } + + async saveState(state: string): Promise { + await this.updateUserServer({ state }); + } + + async storedState(): Promise { + const userServer = await this.getUserServer(); + return userServer?.state ?? undefined; + } + + async redirectToAuthorization(url: URL): Promise { + // Force the OAuth provider to show a consent/login screen on every authorization. + // This prevents a stolen-session attack where an attacker signs into Sourcebot on + // a victim's machine and silently obtains the victim's provider tokens via an + // existing browser session. + url.searchParams.set('prompt', 'consent'); + + // Clear stale tokens before starting a new authorization flow so the UI reflects + // that the user needs to complete OAuth again. + await this.invalidateCredentials('tokens'); + + this.authorizationUrl = url.toString(); + } + + async invalidateCredentials( + scope: 'all' | 'client' | 'tokens' | 'verifier', + ): Promise { + if (scope === 'all' || scope === 'client') { + const didClearDynamicClient = await clearMcpServerClientCredentialsForObservedClient({ + prisma: this.clientInvalidationPrisma, + serverId: this.serverId, + orgId: this.orgId, + observedClientInfo: this.observedClientInfo, + }); + if ( + scope === 'all' && + !didClearDynamicClient && + this.observedClientInfoSource === McpServerClientInfoSource.STATIC + ) { + await this.updateUserServer({ tokens: null, tokensExpiresAt: null }); + } + } + + if (scope === 'tokens') { + await this.updateUserServer({ tokens: null, tokensExpiresAt: null }); + } + + if (scope === 'all' || scope === 'verifier') { + await this.updateUserServer({ codeVerifier: null, state: null }); + } + } + + private async getUserServer() { + return this.prisma.userMcpServer.findUnique({ + where: { + userId_serverId: { userId: this.userId, serverId: this.serverId }, + }, + select: { + tokens: true, + codeVerifier: true, + state: true, + }, + }); + } + + private async updateUserServer(data: { + tokens?: string | null; + tokensExpiresAt?: Date | null; + codeVerifier?: string | null; + state?: string | null; + }) { + await this.prisma.userMcpServer.update({ + where: { + userId_serverId: { userId: this.userId, serverId: this.serverId }, + }, + data, + }); + } +} diff --git a/packages/web/src/features/mcp/prismaScope.test.ts b/packages/web/src/features/mcp/prismaScope.test.ts new file mode 100644 index 000000000..4b86264db --- /dev/null +++ b/packages/web/src/features/mcp/prismaScope.test.ts @@ -0,0 +1,443 @@ +import { describe, expect, test, vi } from 'vitest'; +import type { UserWithAccounts } from '@sourcebot/db'; +import { getMcpPrismaQueryExtension, scopeUserMcpServerWhere } from './prismaScope'; + +const user = { + id: 'user-1', + name: 'Test User', + email: 'test@example.com', + hashedPassword: null, + emailVerified: null, + image: null, + sessionVersion: 0, + createdAt: new Date('2026-01-01T00:00:00Z'), + updatedAt: new Date('2026-01-01T00:00:00Z'), + accounts: [], +} satisfies UserWithAccounts; + +const callQuery = vi.fn(async (args: unknown) => args); + +const resetQuery = () => { + callQuery.mockClear(); + return callQuery; +}; + +const callAllOperations = ( + model: { + $allOperations: (params: { + operation: string; + args: unknown; + query: (args: unknown) => Promise; + }) => Promise; + }, + operation: string, + args: unknown, + query = resetQuery(), +) => model.$allOperations({ operation, args, query }); + +describe('scopeUserMcpServerWhere', () => { + test('merges existing filters with the authenticated user id', () => { + expect(scopeUserMcpServerWhere({ tokens: { not: null } }, user)).toEqual({ + AND: [ + { tokens: { not: null } }, + { userId: 'user-1' }, + ], + }); + }); + + test('fails closed for anonymous users', () => { + expect(scopeUserMcpServerWhere(undefined, undefined)).toEqual({ + AND: [ + { userId: '__sourcebot_anonymous_user__' }, + { userId: '__sourcebot_no_authenticated_user__' }, + ], + }); + }); +}); + +describe('getMcpPrismaQueryExtension', () => { + test('scopes list-style UserMcpServer reads', async () => { + const extension = getMcpPrismaQueryExtension(user); + const result = await extension.userMcpServer.findMany({ + args: { where: { tokens: { not: null } } }, + query: resetQuery(), + }); + + expect(result).toEqual({ + where: { + AND: [ + { tokens: { not: null } }, + { userId: 'user-1' }, + ], + }, + }); + }); + + test('returns null for anonymous or mismatched findUnique queries', async () => { + const anonymousExtension = getMcpPrismaQueryExtension(); + const mismatchedExtension = getMcpPrismaQueryExtension(user); + const query = resetQuery(); + + await expect(anonymousExtension.userMcpServer.findUnique({ + args: { where: { userId_serverId: { userId: 'user-1', serverId: 'server-1' } } }, + query, + })).resolves.toBeNull(); + await expect(mismatchedExtension.userMcpServer.findUnique({ + args: { where: { userId_serverId: { userId: 'user-2', serverId: 'server-1' } } }, + query, + })).resolves.toBeNull(); + + expect(query).not.toHaveBeenCalled(); + }); + + test('allows matching findUnique queries through', async () => { + const extension = getMcpPrismaQueryExtension(user); + const args = { where: { userId_serverId: { userId: 'user-1', serverId: 'server-1' } } }; + + await expect(extension.userMcpServer.findUnique({ + args, + query: resetQuery(), + })).resolves.toBe(args); + }); + + test('rejects creates for anonymous or mismatched users', async () => { + const anonymousExtension = getMcpPrismaQueryExtension(); + const extension = getMcpPrismaQueryExtension(user); + const query = resetQuery(); + + await expect(anonymousExtension.userMcpServer.create({ + args: { data: { userId: 'user-1', serverId: 'server-1' } }, + query, + })).rejects.toThrow('requires an authenticated user'); + await expect(extension.userMcpServer.create({ + args: { data: { userId: 'user-2', serverId: 'server-1' } }, + query, + })).rejects.toThrow('must create UserMcpServer rows for the authenticated user'); + + expect(query).not.toHaveBeenCalled(); + }); + + test('allows checked creates that connect the authenticated user', async () => { + const extension = getMcpPrismaQueryExtension(user); + const args = { + data: { + user: { connect: { id: 'user-1' } }, + server: { connect: { id: 'server-1' } }, + }, + }; + + await expect(extension.userMcpServer.create({ + args, + query: resetQuery(), + })).resolves.toBe(args); + }); + + test('rejects checked creates that do not connect the authenticated user', async () => { + const extension = getMcpPrismaQueryExtension(user); + const query = resetQuery(); + + await expect(extension.userMcpServer.create({ + args: { + data: { + user: { connect: { id: 'user-2' } }, + server: { connect: { id: 'server-1' } }, + }, + }, + query, + })).rejects.toThrow('must create UserMcpServer rows for the authenticated user'); + await expect(extension.userMcpServer.create({ + args: { + data: { + user: { create: { id: 'user-1', email: 'test@example.com' } }, + server: { connect: { id: 'server-1' } }, + }, + }, + query, + })).rejects.toThrow('must create UserMcpServer rows for the authenticated user'); + + expect(query).not.toHaveBeenCalled(); + }); + + test('rejects mismatched update/delete composite keys', async () => { + const extension = getMcpPrismaQueryExtension(user); + const query = resetQuery(); + + await expect(extension.userMcpServer.update({ + args: { + where: { userId_serverId: { userId: 'user-2', serverId: 'server-1' } }, + data: { state: null }, + }, + query, + })).rejects.toThrow('cannot access UserMcpServer rows for another user'); + await expect(extension.userMcpServer.delete({ + args: { where: { userId_serverId: { userId: 'user-2', serverId: 'server-1' } } }, + query, + })).rejects.toThrow('cannot access UserMcpServer rows for another user'); + + expect(query).not.toHaveBeenCalled(); + }); + + test('rejects attempts to mutate UserMcpServer ownership', async () => { + const extension = getMcpPrismaQueryExtension(user); + + await expect(extension.userMcpServer.update({ + args: { + where: { userId_serverId: { userId: 'user-1', serverId: 'server-1' } }, + data: { userId: 'user-2' }, + }, + query: resetQuery(), + })).rejects.toThrow('cannot change UserMcpServer identity'); + await expect(extension.userMcpServer.update({ + args: { + where: { userId_serverId: { userId: 'user-1', serverId: 'server-1' } }, + data: { server: { connect: { id: 'server-2' } } }, + }, + query: resetQuery(), + })).rejects.toThrow('cannot change UserMcpServer identity'); + await expect(extension.userMcpServer.upsert({ + args: { + where: { userId_serverId: { userId: 'user-1', serverId: 'server-1' } }, + create: { userId: 'user-1', serverId: 'server-1' }, + update: { user: { connect: { id: 'user-2' } } }, + }, + query: resetQuery(), + })).rejects.toThrow('cannot change UserMcpServer identity'); + }); + + test('scopes updateMany and deleteMany', async () => { + const extension = getMcpPrismaQueryExtension(user); + + await expect(extension.userMcpServer.updateMany({ + args: { where: { tokens: { not: null } }, data: { state: null } }, + query: resetQuery(), + })).resolves.toEqual({ + where: { + AND: [ + { tokens: { not: null } }, + { userId: 'user-1' }, + ], + }, + data: { state: null }, + }); + await expect(extension.userMcpServer.deleteMany({ + args: { where: { serverId: 'server-1' } }, + query: resetQuery(), + })).resolves.toEqual({ + where: { + AND: [ + { serverId: 'server-1' }, + { userId: 'user-1' }, + ], + }, + }); + }); + + test('scopes returning bulk UserMcpServer operations', async () => { + const extension = getMcpPrismaQueryExtension(user); + + await expect(extension.userMcpServer.createManyAndReturn({ + args: { data: { userId: 'user-2', serverId: 'server-1' } }, + query: resetQuery(), + })).rejects.toThrow('must create UserMcpServer rows for the authenticated user'); + await expect(extension.userMcpServer.updateManyAndReturn({ + args: { where: { serverId: 'server-1' }, data: { state: null } }, + query: resetQuery(), + })).resolves.toEqual({ + where: { + AND: [ + { serverId: 'server-1' }, + { userId: 'user-1' }, + ], + }, + data: { state: null }, + }); + }); + + test('rejects nested UserMcpServer relation access through direct UserMcpServer queries', async () => { + const extension = getMcpPrismaQueryExtension(user); + const query = resetQuery(); + + await expect(extension.userMcpServer.findMany({ + args: { + include: { + server: { + include: { + userMcpServers: true, + }, + }, + }, + }, + query, + })).rejects.toThrow('cannot access UserMcpServer rows through a parent relation'); + + expect(query).not.toHaveBeenCalled(); + }); + + test('rejects nested UserMcpServer writes through McpServer operations', async () => { + const extension = getMcpPrismaQueryExtension(user); + const query = resetQuery(); + + await expect(callAllOperations( + extension.mcpServer, + 'update', + { + where: { id: 'server-1' }, + data: { userMcpServers: { create: { userId: 'user-1' } } }, + }, + query, + )).rejects.toThrow('cannot access UserMcpServer rows through a parent relation'); + + expect(query).not.toHaveBeenCalled(); + }); + + test('rejects nested UserMcpServer reads and writes through parent models', async () => { + const extension = getMcpPrismaQueryExtension(user); + const query = resetQuery(); + + await expect(callAllOperations( + extension.mcpServer, + 'findUnique', + { + where: { id: 'server-1' }, + include: { userMcpServers: true }, + }, + query, + )).rejects.toThrow('cannot access UserMcpServer rows through a parent relation'); + await expect(callAllOperations( + extension.user, + 'findMany', + { + where: { userMcpServers: { some: { serverId: 'server-1' } } }, + }, + query, + )).rejects.toThrow('cannot access UserMcpServer rows through a parent relation'); + await expect(callAllOperations( + extension.user, + 'update', + { + where: { id: 'user-1' }, + data: { userMcpServers: { create: { serverId: 'server-1' } } }, + }, + query, + )).rejects.toThrow('cannot access UserMcpServer rows through a parent relation'); + + expect(query).not.toHaveBeenCalled(); + }); + + test('rejects transitive MCP relation access through Org and UserToOrg operations', async () => { + const extension = getMcpPrismaQueryExtension(user); + const query = resetQuery(); + + await expect(callAllOperations( + extension.org, + 'findUnique', + { + where: { id: 1 }, + include: { + mcpServers: { + include: { + userMcpServers: true, + }, + }, + }, + }, + query, + )).rejects.toThrow('cannot access MCP server relations through a parent relation'); + await expect(callAllOperations( + extension.org, + 'update', + { + where: { id: 1 }, + data: { + mcpServers: { + create: { + name: 'Linear', + sanitizedName: 'linear', + serverUrl: 'https://mcp.linear.app/mcp', + userMcpServers: { + create: { userId: 'user-1' }, + }, + }, + }, + }, + }, + query, + )).rejects.toThrow('cannot access MCP server relations through a parent relation'); + await expect(callAllOperations( + extension.userToOrg, + 'findMany', + { + include: { + org: { + include: { + mcpServers: { + include: { + userMcpServers: true, + }, + }, + }, + }, + }, + }, + query, + )).rejects.toThrow('cannot access MCP server relations through a parent relation'); + + expect(query).not.toHaveBeenCalled(); + }); + + test('allows JSON metadata payloads with relation-like keys', async () => { + const extension = getMcpPrismaQueryExtension(user); + const args = { + where: { id: 1 }, + data: { + metadata: { + mcpServers: 'display-state', + userMcpServers: { collapsed: true }, + }, + }, + }; + + await expect(callAllOperations(extension.org, 'update', args)).resolves.toBe(args); + }); + + test('passes safe parent-model operations through the compact hooks', async () => { + const extension = getMcpPrismaQueryExtension(user); + const args = { where: { orgId: 1 } }; + + await expect(callAllOperations(extension.userToOrg, 'findMany', args)).resolves.toBe(args); + }); + + test('allows single user deletes but blocks bulk user deletes', async () => { + const extension = getMcpPrismaQueryExtension(user); + const args = { where: { id: 'user-2' } }; + const query = resetQuery(); + + await expect(callAllOperations(extension.user, 'delete', args, query)).resolves.toBe(args); + expect(query).toHaveBeenCalledTimes(1); + query.mockClear(); + + await expect(callAllOperations(extension.user, 'deleteMany', { where: {} }, query)) + .rejects.toThrow('user.deleteMany cannot delete users through a user-scoped client'); + expect(query).not.toHaveBeenCalled(); + }); + + test('rejects shared McpServer deletes through the scoped client', async () => { + const extension = getMcpPrismaQueryExtension(user); + const query = resetQuery(); + + await expect(callAllOperations( + extension.mcpServer, + 'delete', + { where: { id: 'server-1' } }, + query, + )).rejects.toThrow('cannot delete shared McpServer rows through a user-scoped client'); + await expect(callAllOperations( + extension.mcpServer, + 'deleteMany', + { where: { orgId: 1 } }, + query, + )).rejects.toThrow('cannot delete shared McpServer rows through a user-scoped client'); + + expect(query).not.toHaveBeenCalled(); + }); +}); diff --git a/packages/web/src/features/mcp/prismaScope.ts b/packages/web/src/features/mcp/prismaScope.ts new file mode 100644 index 000000000..9e3089f24 --- /dev/null +++ b/packages/web/src/features/mcp/prismaScope.ts @@ -0,0 +1,366 @@ +import { Prisma, UserWithAccounts } from '@sourcebot/db'; + +type QueryHookParams = { + args: TArgs; + query: (args: TArgs) => Promise; +}; + +type AllOperationsHookParams = { + operation: string; + args: unknown; + query: (args: unknown) => Promise; +}; + +type UserMcpServerWhereArgs = { + where?: Prisma.UserMcpServerWhereInput; +}; + +type UserMcpServerWhereUniqueArgs = { + where: Prisma.UserMcpServerWhereUniqueInput; +}; + +type UserMcpServerCreateArgs = { + data: unknown; +}; + +type UserMcpServerUpdateArgs = UserMcpServerWhereUniqueArgs & { + data: unknown; +}; + +type UserMcpServerUpsertArgs = UserMcpServerWhereUniqueArgs & { + create: unknown; + update: unknown; +}; + +// Deliberately impossible filter — AND-ing two different userId values guarantees zero rows. +// Used as the fallback when no user is authenticated, so anonymous queries see nothing. +// Prisma doesn't expose a "match nothing" primitive, so this is the standard workaround. +const anonymousUserScope: Prisma.UserMcpServerWhereInput = { + AND: [ + { userId: '__sourcebot_anonymous_user__' }, + { userId: '__sourcebot_no_authenticated_user__' }, + ], +}; + +const isRecord = (value: unknown): value is Record => + typeof value === 'object' && value !== null && !Array.isArray(value); + +const userScopeWhere = (user?: UserWithAccounts): Prisma.UserMcpServerWhereInput => + user ? { userId: user.id } : anonymousUserScope; + +export const scopeUserMcpServerWhere = ( + where: Prisma.UserMcpServerWhereInput | undefined, + user?: UserWithAccounts, +): Prisma.UserMcpServerWhereInput => { + const scope = userScopeWhere(user); + return where ? { AND: [where, scope] } : scope; +}; + +const scopeUserMcpServerReadArgs = ( + args: TArgs, + user?: UserWithAccounts, +): TArgs => ({ + ...args, + where: scopeUserMcpServerWhere(args.where, user), +}); + +const requireAuthenticatedUser = ( + user: UserWithAccounts | undefined, + operation: string, +): UserWithAccounts => { + if (!user) { + throw new Error(`${operation} requires an authenticated user.`); + } + return user; +}; + +const uniqueWhereUserId = (where: Prisma.UserMcpServerWhereUniqueInput): string | undefined => { + const compositeKey = where.userId_serverId; + return isRecord(compositeKey) && typeof compositeKey.userId === 'string' + ? compositeKey.userId + : undefined; +}; + +export const isUserMcpServerUniqueWhereForUser = ( + where: Prisma.UserMcpServerWhereUniqueInput, + user?: UserWithAccounts, +) => !!user && uniqueWhereUserId(where) === user.id; + +const assertUserMcpServerUniqueWhereForUser = ( + where: Prisma.UserMcpServerWhereUniqueInput, + user: UserWithAccounts | undefined, + operation: string, +) => { + const authenticatedUser = requireAuthenticatedUser(user, operation); + if (!isUserMcpServerUniqueWhereForUser(where, authenticatedUser)) { + throw new Error(`${operation} cannot access UserMcpServer rows for another user.`); + } +}; + +const assertNoIdentityMutation = (data: unknown, operation: string) => { + if (!isRecord(data)) { + return; + } + + if ('userId' in data || 'user' in data || 'serverId' in data || 'server' in data) { + throw new Error(`${operation} cannot change UserMcpServer identity.`); + } +}; + +// Extracts the userId from a Prisma relation connect object. +// Prisma's connect syntax for a relation looks like: { connect: { id: "some-id" } } +const connectedUserId = (userRelation: unknown): string | undefined => { + if (!isRecord(userRelation) || !('connect' in userRelation)) { + return undefined; + } + + const connect = userRelation.connect; + if (!isRecord(connect) || !('id' in connect) || typeof connect.id !== 'string') { + return undefined; + } + + return connect.id; +}; + +const createDataUserId = (row: unknown): string | undefined => { + if (!isRecord(row)) { + return undefined; + } + const scalarUserId = typeof row.userId === 'string' ? row.userId : undefined; + const relationUserId = row.user === undefined ? undefined : connectedUserId(row.user); + + if (row.user !== undefined && relationUserId === undefined) { + return undefined; + } + if (scalarUserId !== undefined && relationUserId !== undefined && scalarUserId !== relationUserId) { + return undefined; + } + + return relationUserId ?? scalarUserId; +}; + +const assertCreateDataForUser = ( + data: unknown, + user: UserWithAccounts | undefined, + operation: string, +) => { + const authenticatedUser = requireAuthenticatedUser(user, operation); + + const rows = Array.isArray(data) ? data : [data]; + for (const row of rows) { + if (createDataUserId(row) !== authenticatedUser.id) { + throw new Error(`${operation} must create UserMcpServer rows for the authenticated user.`); + } + } +}; + +const scopeUserMcpServerWriteManyArgs = ( + args: TArgs, + user: UserWithAccounts | undefined, + operation: string, +): TArgs => { + const authenticatedUser = requireAuthenticatedUser(user, operation); + return scopeUserMcpServerReadArgs(args, authenticatedUser); +}; + +const PRISMA_SELECTION_KEYS = new Set(['include', 'select']); +const PRISMA_STRUCTURAL_KEYS = new Set([ + ...PRISMA_SELECTION_KEYS, + 'where', + 'orderBy', + 'data', + 'create', + 'connectOrCreate', + 'update', + 'updateMany', + 'upsert', + 'delete', + 'deleteMany', + 'AND', + 'OR', + 'NOT', + 'some', + 'none', + 'every', + 'is', + 'isNot', +]); +const MCP_RELATION_BRIDGE_KEYS = new Set([ + 'user', + 'server', + 'org', + 'orgs', + 'members', +]); + +const containsPrismaRelationAccess = ( + value: unknown, + relationNames: string[], + isSelectionObject = false, +): boolean => { + if (Array.isArray(value)) { + return value.some((item) => containsPrismaRelationAccess(item, relationNames, isSelectionObject)); + } + if (!isRecord(value)) { + return false; + } + if (relationNames.some((relationName) => relationName in value)) { + return true; + } + + return Object.entries(value).some(([key, nestedValue]) => { + if (PRISMA_SELECTION_KEYS.has(key)) { + return containsPrismaRelationAccess(nestedValue, relationNames, true); + } + + if (isSelectionObject || PRISMA_STRUCTURAL_KEYS.has(key) || MCP_RELATION_BRIDGE_KEYS.has(key)) { + return containsPrismaRelationAccess(nestedValue, relationNames); + } + + return false; + }); +}; + +const assertNoUserMcpServerRelationAccess = (args: unknown, operation: string) => { + if (containsPrismaRelationAccess(args, ['userMcpServers'])) { + throw new Error(`${operation} cannot access UserMcpServer rows through a parent relation.`); + } +}; + +const assertNoMcpServerRelationAccess = (args: unknown, operation: string) => { + if (containsPrismaRelationAccess(args, ['mcpServers', 'userMcpServers'])) { + throw new Error(`${operation} cannot access MCP server relations through a parent relation.`); + } +}; + +const rejectSharedMcpServerDelete = (operation: string) => { + throw new Error(`${operation} cannot delete shared McpServer rows through a user-scoped client.`); +}; + +const rejectUserDeleteMany = () => { + throw new Error('user.deleteMany cannot delete users through a user-scoped client.'); +}; + +const guardMcpParentOperation = ( + modelName: string, + guard: (args: unknown, operation: string) => void, +) => async ({ operation, args, query }: AllOperationsHookParams) => { + guard(args, `${modelName}.${operation}`); + return query(args); +}; + +export const getMcpPrismaQueryExtension = (user?: UserWithAccounts) => ({ + userMcpServer: { + async findMany({ args, query }: QueryHookParams) { + assertNoUserMcpServerRelationAccess(args, 'userMcpServer.findMany'); + return query(scopeUserMcpServerReadArgs(args, user)); + }, + async findFirst({ args, query }: QueryHookParams) { + assertNoUserMcpServerRelationAccess(args, 'userMcpServer.findFirst'); + return query(scopeUserMcpServerReadArgs(args, user)); + }, + async findFirstOrThrow({ args, query }: QueryHookParams) { + assertNoUserMcpServerRelationAccess(args, 'userMcpServer.findFirstOrThrow'); + return query(scopeUserMcpServerReadArgs(args, user)); + }, + async findUnique({ args, query }: QueryHookParams) { + assertNoUserMcpServerRelationAccess(args, 'userMcpServer.findUnique'); + // Preserve Prisma's nullable "not found" semantics for scoped reads. Callers that + // need a hard failure should use findUniqueOrThrow; write paths throw on mismatch. + return isUserMcpServerUniqueWhereForUser(args.where, user) ? query(args) : null; + }, + async findUniqueOrThrow({ args, query }: QueryHookParams) { + assertNoUserMcpServerRelationAccess(args, 'userMcpServer.findUniqueOrThrow'); + assertUserMcpServerUniqueWhereForUser(args.where, user, 'userMcpServer.findUniqueOrThrow'); + return query(args); + }, + async count({ args, query }: QueryHookParams) { + assertNoUserMcpServerRelationAccess(args, 'userMcpServer.count'); + return query(scopeUserMcpServerReadArgs(args, user)); + }, + async aggregate({ args, query }: QueryHookParams) { + assertNoUserMcpServerRelationAccess(args, 'userMcpServer.aggregate'); + return query(scopeUserMcpServerReadArgs(args, user)); + }, + async groupBy({ args, query }: QueryHookParams) { + assertNoUserMcpServerRelationAccess(args, 'userMcpServer.groupBy'); + return query(scopeUserMcpServerReadArgs(args, user)); + }, + async create({ args, query }: QueryHookParams) { + assertNoUserMcpServerRelationAccess(args, 'userMcpServer.create'); + assertCreateDataForUser((args as UserMcpServerCreateArgs).data, user, 'userMcpServer.create'); + return query(args); + }, + async createMany({ args, query }: QueryHookParams) { + assertNoUserMcpServerRelationAccess(args, 'userMcpServer.createMany'); + assertCreateDataForUser((args as UserMcpServerCreateArgs).data, user, 'userMcpServer.createMany'); + return query(args); + }, + async createManyAndReturn({ args, query }: QueryHookParams) { + assertNoUserMcpServerRelationAccess(args, 'userMcpServer.createManyAndReturn'); + assertCreateDataForUser((args as UserMcpServerCreateArgs).data, user, 'userMcpServer.createManyAndReturn'); + return query(args); + }, + async update({ args, query }: QueryHookParams) { + assertNoUserMcpServerRelationAccess(args, 'userMcpServer.update'); + assertUserMcpServerUniqueWhereForUser(args.where, user, 'userMcpServer.update'); + assertNoIdentityMutation((args as UserMcpServerUpdateArgs).data, 'userMcpServer.update'); + return query(args); + }, + async updateMany({ args, query }: QueryHookParams) { + assertNoUserMcpServerRelationAccess(args, 'userMcpServer.updateMany'); + requireAuthenticatedUser(user, 'userMcpServer.updateMany'); + assertNoIdentityMutation((args as UserMcpServerUpdateArgs).data, 'userMcpServer.updateMany'); + return query(scopeUserMcpServerWriteManyArgs(args, user, 'userMcpServer.updateMany')); + }, + async updateManyAndReturn({ args, query }: QueryHookParams) { + assertNoUserMcpServerRelationAccess(args, 'userMcpServer.updateManyAndReturn'); + requireAuthenticatedUser(user, 'userMcpServer.updateManyAndReturn'); + assertNoIdentityMutation((args as UserMcpServerUpdateArgs).data, 'userMcpServer.updateManyAndReturn'); + return query(scopeUserMcpServerWriteManyArgs(args, user, 'userMcpServer.updateManyAndReturn')); + }, + async delete({ args, query }: QueryHookParams) { + assertNoUserMcpServerRelationAccess(args, 'userMcpServer.delete'); + assertUserMcpServerUniqueWhereForUser(args.where, user, 'userMcpServer.delete'); + return query(args); + }, + async deleteMany({ args, query }: QueryHookParams) { + assertNoUserMcpServerRelationAccess(args, 'userMcpServer.deleteMany'); + return query(scopeUserMcpServerWriteManyArgs(args, user, 'userMcpServer.deleteMany')); + }, + async upsert({ args, query }: QueryHookParams) { + const upsertArgs = args as UserMcpServerUpsertArgs; + assertNoUserMcpServerRelationAccess(args, 'userMcpServer.upsert'); + assertUserMcpServerUniqueWhereForUser(args.where, user, 'userMcpServer.upsert'); + assertCreateDataForUser(upsertArgs.create, user, 'userMcpServer.upsert'); + assertNoIdentityMutation(upsertArgs.update, 'userMcpServer.upsert'); + return query(args); + }, + }, + user: { + async $allOperations({ operation, args, query }: AllOperationsHookParams) { + if (operation === 'deleteMany') { + rejectUserDeleteMany(); + } + // The owner-only user deletion API intentionally deletes one user and relies on + // cascade to remove that user's rows. Bulk deletes stay blocked above. + assertNoUserMcpServerRelationAccess(args, `user.${operation}`); + return query(args); + }, + }, + mcpServer: { + async $allOperations({ operation, args, query }: AllOperationsHookParams) { + if (operation === 'delete' || operation === 'deleteMany') { + rejectSharedMcpServerDelete(`mcpServer.${operation}`); + } + assertNoUserMcpServerRelationAccess(args, `mcpServer.${operation}`); + return query(args); + }, + }, + org: { + $allOperations: guardMcpParentOperation('org', assertNoMcpServerRelationAccess), + }, + userToOrg: { + $allOperations: guardMcpParentOperation('userToOrg', assertNoMcpServerRelationAccess), + }, +}); diff --git a/packages/web/src/lib/errorCodes.ts b/packages/web/src/lib/errorCodes.ts index 714932c30..fdb09d67d 100644 --- a/packages/web/src/lib/errorCodes.ts +++ b/packages/web/src/lib/errorCodes.ts @@ -35,4 +35,6 @@ export enum ErrorCode { LAST_OWNER_CANNOT_BE_DEMOTED = 'LAST_OWNER_CANNOT_BE_DEMOTED', LAST_OWNER_CANNOT_BE_REMOVED = 'LAST_OWNER_CANNOT_BE_REMOVED', API_KEY_USAGE_DISABLED = 'API_KEY_USAGE_DISABLED', + MCP_SERVER_ALREADY_EXISTS = 'MCP_SERVER_ALREADY_EXISTS', + MCP_SERVER_NOT_FOUND = 'MCP_SERVER_NOT_FOUND', } diff --git a/packages/web/src/lib/posthogEvents.ts b/packages/web/src/lib/posthogEvents.ts index b63014002..4326036d0 100644 --- a/packages/web/src/lib/posthogEvents.ts +++ b/packages/web/src/lib/posthogEvents.ts @@ -7,6 +7,10 @@ export type UpsellSource = 'onboard' | 'license_settings'; +export type SourcebotWebClientSource = 'sourcebot-web-client'; +export type AskMcpAnalyticsSource = SourcebotWebClientSource | 'sourcebot-ask-agent'; +export type McpConnectorEntryPoint = 'chat' | 'account_settings' | 'workspace_settings' | 'unknown'; +export type McpConnectorAuthMode = 'dynamic' | 'static'; export type PosthogEventMap = { search_finished: { @@ -178,6 +182,10 @@ export type PosthogEventMap = { messageCount: number, selectedReposCount: number, source?: string, + hasAskMcpServersAvailable: boolean, + askMcpConnectedServerCount: number, + askMcpEnabledServerCount: number, + askMcpDisabledServerCount: number, /** * @note this field will only be populated when * the EXPERIMENT_ASK_GH_ENABLED environment variable @@ -185,6 +193,79 @@ export type PosthogEventMap = { */ selectedRepos?: string[], }, + ask_mcp_turn_completed: { + chatId: string, + source?: SourcebotWebClientSource, + traceId?: string, + askMcpUsed: boolean, + askMcpToolCallCount: number, + askMcpToolSuccessCount: number, + askMcpToolFailureCount: number, + askMcpApprovalRequestedCount: number, + askMcpApprovalDeniedCount: number, + askMcpFailedServerCount: number, + durationMs: number, + }, + ask_mcp_tool_call_completed: { + chatId?: string, + traceId?: string, + source: AskMcpAnalyticsSource, + serverId: string, + serverName: string, + serverUrl: string, + toolName: string, + qualifiedToolName: string, + success: boolean, + durationMs: number, + failureReason?: string, + }, + ask_mcp_connector_added: { + source: SourcebotWebClientSource, + entryPoint: 'workspace_settings', + serverId: string, + serverName: string, + serverUrl: string, + sanitizedName: string, + authMode: McpConnectorAuthMode, + }, + ask_mcp_connector_connection_started: { + source: SourcebotWebClientSource, + entryPoint: McpConnectorEntryPoint, + serverId: string, + serverName: string, + serverUrl: string, + sanitizedName: string, + authMode: McpConnectorAuthMode, + }, + ask_mcp_connector_connection_completed: { + source: SourcebotWebClientSource, + entryPoint: McpConnectorEntryPoint, + serverId: string, + serverName: string, + serverUrl: string, + sanitizedName: string, + authMode: McpConnectorAuthMode, + alreadyAuthorized: boolean, + }, + ask_mcp_connector_connection_failed: { + source: SourcebotWebClientSource, + entryPoint: McpConnectorEntryPoint, + serverId?: string, + serverName?: string, + serverUrl?: string, + sanitizedName?: string, + authMode?: McpConnectorAuthMode, + failureReason: string, + }, + ask_mcp_connector_disconnected: { + source: SourcebotWebClientSource, + entryPoint: McpConnectorEntryPoint, + serverId: string, + serverName: string, + serverUrl: string, + sanitizedName: string, + authMode: McpConnectorAuthMode, + }, tool_used: { toolName: string, source: string, @@ -316,4 +397,4 @@ export type PosthogEventMap = { clientName: string, }, } -export type PosthogEvent = keyof PosthogEventMap; \ No newline at end of file +export type PosthogEvent = keyof PosthogEventMap; diff --git a/packages/web/src/middleware/withAuth.test.ts b/packages/web/src/middleware/withAuth.test.ts index 8a9978bff..4856483d3 100644 --- a/packages/web/src/middleware/withAuth.test.ts +++ b/packages/web/src/middleware/withAuth.test.ts @@ -6,6 +6,7 @@ import { MOCK_API_KEY, MOCK_OAUTH_TOKEN, MOCK_ORG, MOCK_USER_WITH_ACCOUNTS, pris import { OrgRole } from '@sourcebot/db'; import { ErrorCode } from '../lib/errorCodes'; import { StatusCodes } from 'http-status-codes'; +import { userScopedPrismaClientExtension } from '@/prisma'; const mocks = vi.hoisted(() => { return { @@ -80,6 +81,7 @@ const createMockSession = (overrides: Partial = {}): Session => ({ beforeEach(() => { vi.clearAllMocks(); + vi.mocked(userScopedPrismaClientExtension).mockReset(); mocks.auth.mockResolvedValue(null); mocks.headers.mockResolvedValue(new Headers()); mocks.hasEntitlement.mockReturnValue(false); @@ -474,6 +476,39 @@ describe('getAuthContext', () => { }); describe('withAuth', () => { + test('should pass the scoped prisma client from $extends to the callback', async () => { + const userId = 'test-user-id'; + const user = { + ...MOCK_USER_WITH_ACCOUNTS, + id: userId, + }; + const extension = { query: { userMcpServer: {} } }; + const scopedPrisma = { scoped: true }; + + prisma.user.findUnique.mockResolvedValue(user); + prisma.org.findUnique.mockResolvedValue({ + ...MOCK_ORG, + }); + prisma.userToOrg.findUnique.mockResolvedValue({ + joinedAt: new Date(), + userId, + orgId: MOCK_ORG.id, + role: OrgRole.MEMBER, + }); + vi.mocked(userScopedPrismaClientExtension).mockResolvedValue(extension as never); + prisma.$extends.mockReturnValue(scopedPrisma as never); + setMockSession(createMockSession({ user: { id: userId } })); + + const cb = vi.fn(); + await withAuth(cb); + + expect(userScopedPrismaClientExtension).toHaveBeenCalledWith(user); + expect(prisma.$extends).toHaveBeenCalledWith(extension); + expect(cb).toHaveBeenCalledWith(expect.objectContaining({ + prisma: scopedPrisma, + })); + }); + test('should call the callback with the auth context object if a valid session is present and the user is a member of the organization', async () => { const userId = 'test-user-id'; prisma.user.findUnique.mockResolvedValue({ diff --git a/packages/web/src/prisma.ts b/packages/web/src/prisma.ts index f863f5ef7..0496d8c9b 100644 --- a/packages/web/src/prisma.ts +++ b/packages/web/src/prisma.ts @@ -2,6 +2,7 @@ import 'server-only'; import { env, getDBConnectionString } from "@sourcebot/shared"; import { Prisma, PrismaClient, UserWithAccounts } from "@sourcebot/db"; import { hasEntitlement } from "@/lib/entitlements"; +import { getMcpPrismaQueryExtension } from "@/features/mcp/prismaScope"; // @see: https://authjs.dev/getting-started/adapters/prisma const globalForPrisma = globalThis as unknown as { prisma: PrismaClient } @@ -35,6 +36,7 @@ export const userScopedPrismaClientExtension = async (user?: UserWithAccounts) = (prisma) => { return prisma.$extends({ query: { + ...getMcpPrismaQueryExtension(user), ...(hasPermissionSyncing ? { repo: { async $allOperations({ args, query }) { diff --git a/yarn.lock b/yarn.lock index 2357fe36c..18e92f632 100644 --- a/yarn.lock +++ b/yarn.lock @@ -99,6 +99,19 @@ __metadata: languageName: node linkType: hard +"@ai-sdk/mcp@npm:^2.0.0-beta.11": + version: 2.0.0-beta.11 + resolution: "@ai-sdk/mcp@npm:2.0.0-beta.11" + dependencies: + "@ai-sdk/provider": "npm:4.0.0-beta.5" + "@ai-sdk/provider-utils": "npm:5.0.0-beta.7" + pkce-challenge: "npm:^5.0.0" + peerDependencies: + zod: ^3.25.76 || ^4.1.8 + checksum: 10c0/efcc9b9f5f8b20b78b2d0ee6d83b34466b2ec456c3b40b5b8b10af226e7d3f6144f964d87a20c5fc54c24b21f3610cb75cc246c30833b99fb501438a206c9933 + languageName: node + linkType: hard + "@ai-sdk/mistral@npm:^3.0.30": version: 3.0.30 resolution: "@ai-sdk/mistral@npm:3.0.30" @@ -148,6 +161,19 @@ __metadata: languageName: node linkType: hard +"@ai-sdk/provider-utils@npm:5.0.0-beta.7": + version: 5.0.0-beta.7 + resolution: "@ai-sdk/provider-utils@npm:5.0.0-beta.7" + dependencies: + "@ai-sdk/provider": "npm:4.0.0-beta.5" + "@standard-schema/spec": "npm:^1.1.0" + eventsource-parser: "npm:^3.0.6" + peerDependencies: + zod: ^3.25.76 || ^4.1.8 + checksum: 10c0/440825f7b599da6a0bd830c905f9ba4f21defcf7068bc98154ea38158c1134b049cb2815047013668f48b679a23de1d3c19eb072a65115dc860070168104c99e + languageName: node + linkType: hard + "@ai-sdk/provider@npm:3.0.8": version: 3.0.8 resolution: "@ai-sdk/provider@npm:3.0.8" @@ -157,6 +183,15 @@ __metadata: languageName: node linkType: hard +"@ai-sdk/provider@npm:4.0.0-beta.5": + version: 4.0.0-beta.5 + resolution: "@ai-sdk/provider@npm:4.0.0-beta.5" + dependencies: + json-schema: "npm:^0.4.0" + checksum: 10c0/886f5892268cc3425130c9b019a9eb1e2acdb5efd05d920b05d1ac1ab49603393d8e509e6e0a3c46dee533a411a51a2af2c6fa0a173b41130f5175a615add7fb + languageName: node + linkType: hard + "@ai-sdk/react@npm:^3.0.169": version: 3.0.169 resolution: "@ai-sdk/react@npm:3.0.169" @@ -9340,6 +9375,7 @@ __metadata: "@ai-sdk/deepseek": "npm:^2.0.29" "@ai-sdk/google": "npm:^3.0.64" "@ai-sdk/google-vertex": "npm:^4.0.111" + "@ai-sdk/mcp": "npm:^2.0.0-beta.11" "@ai-sdk/mistral": "npm:^3.0.30" "@ai-sdk/openai": "npm:^3.0.53" "@ai-sdk/openai-compatible": "npm:^2.0.41" @@ -9552,7 +9588,7 @@ __metadata: vitest: "npm:^4.1.4" vitest-mock-extended: "npm:^4.0.0" vscode-icons-js: "npm:^11.6.1" - zod: "npm:^3.25.74" + zod: "npm:^3.25.76" zod-to-json-schema: "npm:^3.24.5" languageName: unknown linkType: soft @@ -18957,13 +18993,20 @@ __metadata: languageName: node linkType: hard -"picomatch@npm:^4.0.2, picomatch@npm:^4.0.3, picomatch@npm:^4.0.4": +"picomatch@npm:^4.0.2, picomatch@npm:^4.0.4": version: 4.0.4 resolution: "picomatch@npm:4.0.4" checksum: 10c0/e2c6023372cc7b5764719a5ffb9da0f8e781212fa7ca4bd0562db929df8e117460f00dff3cb7509dacfc06b86de924b247f504d0ce1806a37fac4633081466b0 languageName: node linkType: hard +"picomatch@npm:^4.0.3": + version: 4.0.3 + resolution: "picomatch@npm:4.0.3" + checksum: 10c0/9582c951e95eebee5434f59e426cddd228a7b97a0161a375aed4be244bd3fe8e3a31b846808ea14ef2c8a2527a6eeab7b3946a67d5979e81694654f939473ae2 + languageName: node + linkType: hard + "picospinner@npm:^3.0.0": version: 3.0.0 resolution: "picospinner@npm:3.0.0" @@ -23537,7 +23580,7 @@ __metadata: languageName: node linkType: hard -"zod@npm:^3.25.0": +"zod@npm:^3.25.0, zod@npm:^3.25.76": version: 3.25.76 resolution: "zod@npm:3.25.76" checksum: 10c0/5718ec35e3c40b600316c5b4c5e4976f7fee68151bc8f8d90ec18a469be9571f072e1bbaace10f1e85cf8892ea12d90821b200e980ab46916a6166a4260a983c