diff --git a/packages/api/src/mcp/MCPConnectionFactory.ts b/packages/api/src/mcp/MCPConnectionFactory.ts index c517140b8a..d0648b7635 100644 --- a/packages/api/src/mcp/MCPConnectionFactory.ts +++ b/packages/api/src/mcp/MCPConnectionFactory.ts @@ -45,7 +45,7 @@ export class MCPConnectionFactory { /** Creates a new MCP connection with optional OAuth support */ static async create( basic: t.BasicConnectionOptions, - oauth?: t.OAuthConnectionOptions, + oauth?: t.OAuthConnectionOptions | t.UserConnectionContext, ): Promise { const factory = new this(basic, oauth); return factory.createConnection(); @@ -232,6 +232,17 @@ export class MCPConnectionFactory { let cleanupOAuthHandlers: (() => void) | null = null; if (this.useOAuth) { cleanupOAuthHandlers = this.handleOAuthEvents(connection); + } else { + const nonOAuthHandler = () => { + logger.info( + `${this.logPrefix} Server does not use OAuth — treating 401/403 as auth failure, not OAuth`, + ); + connection.emit('oauthFailed', new Error('Server does not use OAuth')); + }; + connection.on('oauthRequired', nonOAuthHandler); + cleanupOAuthHandlers = () => { + connection.removeListener('oauthRequired', nonOAuthHandler); + }; } try { diff --git a/packages/api/src/mcp/MCPManager.ts b/packages/api/src/mcp/MCPManager.ts index 12227de39f..d064f65c01 100644 --- a/packages/api/src/mcp/MCPManager.ts +++ b/packages/api/src/mcp/MCPManager.ts @@ -18,7 +18,7 @@ import { preProcessGraphTokens } from '~/utils/graph'; import { formatToolContent } from './parsers'; import { MCPConnection } from './connection'; import { processMCPEnv } from '~/utils/env'; -import { isUserSourced } from './utils'; +import { isUserSourced, isOAuthServer } from './utils'; /** * Centralized manager for MCP server connections and tool execution. @@ -102,7 +102,7 @@ export class MCPManager extends UserConnectionManager { return { tools: null, oauthRequired: false, oauthUrl: null }; } - const useOAuth = Boolean(serverConfig.requiresOAuth || serverConfig.oauthMetadata); + const useOAuth = isOAuthServer(serverConfig); const registry = MCPServersRegistry.getInstance(); const useSSRFProtection = registry.shouldEnableSSRFProtection(); diff --git a/packages/api/src/mcp/UserConnectionManager.ts b/packages/api/src/mcp/UserConnectionManager.ts index 760f84c75e..b31078983f 100644 --- a/packages/api/src/mcp/UserConnectionManager.ts +++ b/packages/api/src/mcp/UserConnectionManager.ts @@ -4,7 +4,7 @@ import type * as t from './types'; import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry'; import { ConnectionsRepository } from '~/mcp/ConnectionsRepository'; import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; -import { isUserSourced } from './utils'; +import { isUserSourced, isOAuthServer } from './utils'; import { MCPConnection } from './connection'; import { mcpConfig } from './mcpConfig'; @@ -35,14 +35,7 @@ export abstract class UserConnectionManager { } /** Gets or creates a connection for a specific user, coalescing concurrent attempts */ - public async getUserConnection( - opts: { - serverName: string; - forceNew?: boolean; - /** Pre-resolved config for config-source servers not in YAML/DB */ - serverConfig?: t.ParsedServerConfig; - } & Omit, - ): Promise { + public async getUserConnection(opts: t.UserMCPConnectionOptions): Promise { const { serverName, forceNew, user } = opts; const userId = user?.id; if (!userId) { @@ -89,11 +82,7 @@ export abstract class UserConnectionManager { returnOnOAuth = false, connectionTimeout, serverConfig: providedConfig, - }: { - serverName: string; - forceNew?: boolean; - serverConfig?: t.ParsedServerConfig; - } & Omit, + }: t.UserMCPConnectionOptions, userId: string, ): Promise { if (await this.appConnections!.has(serverName)) { @@ -161,28 +150,38 @@ export abstract class UserConnectionManager { try { const registry = MCPServersRegistry.getInstance(); - connection = await MCPConnectionFactory.create( - { - serverConfig: config, - serverName: serverName, - dbSourced: isUserSourced(config), - useSSRFProtection: registry.shouldEnableSSRFProtection(), - allowedDomains: registry.getAllowedDomains(), - }, - { - useOAuth: true, - user: user, - customUserVars: customUserVars, - flowManager: flowManager, - tokenMethods: tokenMethods, - signal: signal, - oauthStart: oauthStart, - oauthEnd: oauthEnd, - returnOnOAuth: returnOnOAuth, - requestBody: requestBody, - connectionTimeout: connectionTimeout, - }, - ); + const basic: t.BasicConnectionOptions = { + serverConfig: config, + serverName: serverName, + dbSourced: isUserSourced(config), + useSSRFProtection: registry.shouldEnableSSRFProtection(), + allowedDomains: registry.getAllowedDomains(), + }; + + const useOAuth = isOAuthServer(config); + if (useOAuth && !flowManager) { + throw new McpError( + ErrorCode.InvalidRequest, + `[MCP][User: ${userId}] OAuth server "${serverName}" requires a flowManager`, + ); + } + const oauthOptions: t.OAuthConnectionOptions | t.UserConnectionContext = useOAuth + ? { + useOAuth: true as const, + user, + customUserVars, + flowManager: flowManager, + tokenMethods, + signal, + oauthStart, + oauthEnd, + returnOnOAuth, + requestBody, + connectionTimeout, + } + : { user, customUserVars, requestBody, connectionTimeout }; + + connection = await MCPConnectionFactory.create(basic, oauthOptions); if (!(await connection?.isConnected())) { throw new Error('Failed to establish connection after initialization attempt.'); diff --git a/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts b/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts index d90ca5b345..0e10351559 100644 --- a/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts +++ b/packages/api/src/mcp/__tests__/MCPConnectionFactory.test.ts @@ -94,6 +94,36 @@ describe('MCPConnectionFactory', () => { expect(mockConnectionInstance.connect).toHaveBeenCalled(); }); + it('should register fallback oauthRequired handler for non-OAuth connections', async () => { + const basicOptions = { + serverName: 'test-server', + serverConfig: mockServerConfig, + }; + + mockConnectionInstance.isConnected.mockResolvedValue(true); + + await MCPConnectionFactory.create(basicOptions); + + expect(mockConnectionInstance.on).toHaveBeenCalledWith('oauthRequired', expect.any(Function)); + + const onCall = (mockConnectionInstance.on as jest.Mock).mock.calls.find( + ([event]: [string]) => event === 'oauthRequired', + ); + + const handler = onCall![1] as () => void; + handler(); + + expect(mockConnectionInstance.emit).toHaveBeenCalledWith( + 'oauthFailed', + expect.objectContaining({ message: 'Server does not use OAuth' }), + ); + + expect(mockConnectionInstance.removeListener).toHaveBeenCalledWith( + 'oauthRequired', + expect.any(Function), + ); + }); + it('should create a connection with OAuth', async () => { const basicOptions = { serverName: 'test-server', diff --git a/packages/api/src/mcp/__tests__/MCPManager.test.ts b/packages/api/src/mcp/__tests__/MCPManager.test.ts index ba5b0b3b8e..094d03215e 100644 --- a/packages/api/src/mcp/__tests__/MCPManager.test.ts +++ b/packages/api/src/mcp/__tests__/MCPManager.test.ts @@ -925,6 +925,44 @@ describe('MCPManager', () => { ); }); + it('should use isOAuthServer in discoverServerTools', async () => { + const mockUser = { id: 'user123', email: 'test@example.com' } as unknown as IUser; + const mockFlowManager = { + createFlow: jest.fn(), + getFlowState: jest.fn(), + deleteFlow: jest.fn(), + }; + + mockAppConnections({ + get: jest.fn().mockResolvedValue(null), + }); + + (mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue({ + type: 'streamable-http', + url: 'http://private-mcp.svc:5446/mcp', + requiresOAuth: false, + }); + + (MCPConnectionFactory.discoverTools as jest.Mock).mockResolvedValue({ + tools: mockTools, + connection: null, + oauthRequired: false, + oauthUrl: null, + }); + + const manager = await MCPManager.createInstance(newMCPServersConfig()); + await manager.discoverServerTools({ + serverName, + user: mockUser, + flowManager: mockFlowManager as unknown as t.ToolDiscoveryOptions['flowManager'], + }); + + expect(MCPConnectionFactory.discoverTools).toHaveBeenCalledWith( + expect.objectContaining({ serverName }), + expect.not.objectContaining({ useOAuth: true }), + ); + }); + it('should discover tools with OAuth when user and flowManager provided', async () => { const mockUser = { id: 'user123', email: 'test@example.com' } as unknown as IUser; const mockFlowManager = { @@ -966,4 +1004,89 @@ describe('MCPManager', () => { ); }); }); + + describe('getUserConnection - useOAuth derivation', () => { + const mockUser = { id: userId, email: 'test@example.com' } as unknown as IUser; + const mockFlowManager = { + createFlow: jest.fn(), + getFlowState: jest.fn(), + deleteFlow: jest.fn(), + }; + const mockConnection = { + isConnected: jest.fn().mockResolvedValue(true), + isStale: jest.fn().mockReturnValue(false), + disconnect: jest.fn(), + } as unknown as MCPConnection; + + it('should pass useOAuth: true for servers with requiresOAuth', async () => { + mockAppConnections({ + has: jest.fn().mockResolvedValue(false), + }); + + (mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue({ + type: 'sse', + url: 'https://oauth-mcp.example.com', + requiresOAuth: true, + }); + + (MCPConnectionFactory.create as jest.Mock).mockResolvedValue(mockConnection); + + const manager = await MCPManager.createInstance(newMCPServersConfig()); + await manager.getUserConnection({ + serverName, + user: mockUser, + flowManager: mockFlowManager as unknown as t.UserMCPConnectionOptions['flowManager'], + }); + + expect(MCPConnectionFactory.create).toHaveBeenCalledWith( + expect.objectContaining({ serverName }), + expect.objectContaining({ useOAuth: true }), + ); + }); + + it('should not pass useOAuth for servers with requiresOAuth: false', async () => { + mockAppConnections({ + has: jest.fn().mockResolvedValue(false), + }); + + (mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue({ + type: 'streamable-http', + url: 'http://private-mcp.svc:5446/mcp', + requiresOAuth: false, + }); + + (MCPConnectionFactory.create as jest.Mock).mockResolvedValue(mockConnection); + + const manager = await MCPManager.createInstance(newMCPServersConfig()); + await manager.getUserConnection({ + serverName, + user: mockUser, + }); + + expect(MCPConnectionFactory.create).toHaveBeenCalledWith( + expect.objectContaining({ serverName }), + expect.not.objectContaining({ useOAuth: true }), + ); + }); + + it('should throw when OAuth server lacks flowManager', async () => { + mockAppConnections({ + has: jest.fn().mockResolvedValue(false), + }); + + (mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue({ + type: 'sse', + url: 'https://oauth-mcp.example.com', + requiresOAuth: true, + }); + + const manager = await MCPManager.createInstance(newMCPServersConfig()); + await expect( + manager.getUserConnection({ + serverName, + user: mockUser, + }), + ).rejects.toThrow('requires a flowManager'); + }); + }); }); diff --git a/packages/api/src/mcp/types/index.ts b/packages/api/src/mcp/types/index.ts index 32c2787165..3c1570e2a9 100644 --- a/packages/api/src/mcp/types/index.ts +++ b/packages/api/src/mcp/types/index.ts @@ -202,6 +202,19 @@ export interface OAuthConnectionOptions extends UserConnectionContext { returnOnOAuth?: boolean; } +/** Options accepted by UserConnectionManager.getUserConnection — OAuth fields are optional. */ +export interface UserMCPConnectionOptions extends UserConnectionContext { + serverName: string; + forceNew?: boolean; + serverConfig?: ParsedServerConfig; + flowManager?: FlowStateManager; + tokenMethods?: TokenMethods; + signal?: AbortSignal; + oauthStart?: (authURL: string) => Promise; + oauthEnd?: () => Promise; + returnOnOAuth?: boolean; +} + export interface ToolDiscoveryOptions { serverName: string; user?: IUser; diff --git a/packages/api/src/mcp/utils.ts b/packages/api/src/mcp/utils.ts index 653a96d5bd..9a5e428207 100644 --- a/packages/api/src/mcp/utils.ts +++ b/packages/api/src/mcp/utils.ts @@ -3,6 +3,13 @@ import type { ParsedServerConfig } from '~/mcp/types'; export const mcpToolPattern = new RegExp(`^.+${Constants.mcp_delimiter}.+$`); +/** Whether a server requires OAuth (has `requiresOAuth` or `oauthMetadata`). */ +export function isOAuthServer( + config: Pick, +): boolean { + return Boolean(config.requiresOAuth || config.oauthMetadata); +} + /** Checks that `customUserVars` is present AND non-empty (guards against truthy `{}`) */ export function hasCustomUserVars(config: Pick): boolean { return !!config.customUserVars && Object.keys(config.customUserVars).length > 0;