diff --git a/api/server/services/MCP.js b/api/server/services/MCP.js index 791f824dbf..162e02d91e 100644 --- a/api/server/services/MCP.js +++ b/api/server/services/MCP.js @@ -442,10 +442,10 @@ async function getMCPSetupData(userId) { } const mcpManager = getMCPManager(userId); - /** @type {ReturnType} */ + /** @type {Map} */ let appConnections = new Map(); try { - appConnections = (await mcpManager.getAllConnections()) || new Map(); + appConnections = (await mcpManager.appConnections?.getAll()) || new Map(); } catch (error) { logger.error(`[MCP][User: ${userId}] Error getting app connections:`, error); } diff --git a/api/server/services/MCP.spec.js b/api/server/services/MCP.spec.js index 9773d58745..8b9f7b675d 100644 --- a/api/server/services/MCP.spec.js +++ b/api/server/services/MCP.spec.js @@ -123,7 +123,7 @@ describe('tests for the new helper functions used by the MCP connection status e beforeEach(() => { mockGetAppConfig = require('./Config').getAppConfig; mockGetMCPManager.mockReturnValue({ - getAllConnections: jest.fn(() => new Map()), + appConnections: { getAll: jest.fn(() => new Map()) }, getUserConnections: jest.fn(() => new Map()), getOAuthServers: jest.fn(() => new Set()), }); @@ -137,7 +137,7 @@ describe('tests for the new helper functions used by the MCP connection status e const mockOAuthServers = new Set(['server2']); const mockMCPManager = { - getAllConnections: jest.fn(() => mockAppConnections), + appConnections: { getAll: jest.fn(() => mockAppConnections) }, getUserConnections: jest.fn(() => mockUserConnections), getOAuthServers: jest.fn(() => mockOAuthServers), }; @@ -147,7 +147,7 @@ describe('tests for the new helper functions used by the MCP connection status e expect(mockGetAppConfig).toHaveBeenCalled(); expect(mockGetMCPManager).toHaveBeenCalledWith(mockUserId); - expect(mockMCPManager.getAllConnections).toHaveBeenCalled(); + expect(mockMCPManager.appConnections.getAll).toHaveBeenCalled(); expect(mockMCPManager.getUserConnections).toHaveBeenCalledWith(mockUserId); expect(mockMCPManager.getOAuthServers).toHaveBeenCalled(); @@ -168,7 +168,7 @@ describe('tests for the new helper functions used by the MCP connection status e mockGetAppConfig.mockResolvedValue({ mcpConfig: mockConfig.mcpServers }); const mockMCPManager = { - getAllConnections: jest.fn(() => null), + appConnections: { getAll: jest.fn(() => null) }, getUserConnections: jest.fn(() => null), getOAuthServers: jest.fn(() => null), }; diff --git a/api/server/services/Tools/mcp.js b/api/server/services/Tools/mcp.js index 2669ba4658..e6d293800d 100644 --- a/api/server/services/Tools/mcp.js +++ b/api/server/services/Tools/mcp.js @@ -29,7 +29,7 @@ async function reinitMCPServer({ flowManager: _flowManager, }) { /** @type {MCPConnection | null} */ - let userConnection = null; + let connection = null; /** @type {LCAvailableTools | null} */ let availableTools = null; /** @type {ReturnType | null} */ @@ -50,7 +50,7 @@ async function reinitMCPServer({ }); try { - userConnection = await mcpManager.getUserConnection({ + connection = await mcpManager.getConnection({ user, signal, forceNew, @@ -70,7 +70,7 @@ async function reinitMCPServer({ logger.info(`[MCP Reinitialize] Successfully established connection for ${serverName}`); } catch (err) { - logger.info(`[MCP Reinitialize] getUserConnection threw error: ${err.message}`); + logger.info(`[MCP Reinitialize] getConnection threw error: ${err.message}`); logger.info( `[MCP Reinitialize] OAuth state - oauthRequired: ${oauthRequired}, oauthUrl: ${oauthUrl ? 'present' : 'null'}`, ); @@ -95,8 +95,8 @@ async function reinitMCPServer({ } } - if (userConnection && !oauthRequired) { - tools = await userConnection.fetchTools(); + if (connection && !oauthRequired) { + tools = await connection.fetchTools(); availableTools = await updateMCPServerTools({ serverName, tools, @@ -111,7 +111,7 @@ async function reinitMCPServer({ if (oauthRequired) { return `MCP server '${serverName}' ready for OAuth authentication`; } - if (userConnection) { + if (connection) { return `MCP server '${serverName}' reinitialized successfully`; } return `Failed to reinitialize MCP server '${serverName}'`; @@ -119,7 +119,7 @@ async function reinitMCPServer({ const result = { availableTools, - success: Boolean((userConnection && !oauthRequired) || (oauthRequired && oauthUrl)), + success: Boolean((connection && !oauthRequired) || (oauthRequired && oauthUrl)), message: getResponseMessage(), oauthRequired, serverName, diff --git a/packages/api/src/mcp/MCPManager.ts b/packages/api/src/mcp/MCPManager.ts index d25f652b40..9d3145c632 100644 --- a/packages/api/src/mcp/MCPManager.ts +++ b/packages/api/src/mcp/MCPManager.ts @@ -20,8 +20,6 @@ import { processMCPEnv } from '~/utils/env'; */ export class MCPManager extends UserConnectionManager { private static instance: MCPManager | null; - // Connections shared by all users. - private appConnections: ConnectionsRepository | null = null; /** Creates and initializes the singleton MCPManager instance */ public static async createInstance(configs: t.MCPServers): Promise { @@ -43,9 +41,25 @@ export class MCPManager extends UserConnectionManager { this.appConnections = new ConnectionsRepository(this.serversRegistry.appServerConfigs!); } - /** Returns all app-level connections */ - public async getAllConnections(): Promise | null> { - return this.appConnections!.getAll(); + /** Retrieves an app-level or user-specific connection based on provided arguments */ + public async getConnection( + args: { + serverName: string; + user?: TUser; + forceNew?: boolean; + flowManager?: FlowStateManager; + } & Omit, + ): Promise { + if (this.appConnections!.has(args.serverName)) { + return this.appConnections!.get(args.serverName); + } else if (args.user?.id) { + return this.getUserConnection(args as Parameters[0]); + } else { + throw new McpError( + ErrorCode.InvalidRequest, + `No connection found for server ${args.serverName}`, + ); + } } /** Get servers that require OAuth */ @@ -180,30 +194,19 @@ Please follow these instructions when using tools from the respective MCP server const logPrefix = userId ? `[MCP][User: ${userId}][${serverName}]` : `[MCP][${serverName}]`; try { - if (!this.appConnections?.has(serverName) && userId && user) { - this.updateUserLastActivity(userId); - /** Get or create user-specific connection */ - connection = await this.getUserConnection({ - user, - serverName, - flowManager, - tokenMethods, - oauthStart, - oauthEnd, - signal: options?.signal, - customUserVars, - requestBody, - }); - } else { - /** App-level connection */ - connection = await this.appConnections!.get(serverName); - if (!connection) { - throw new McpError( - ErrorCode.InvalidRequest, - `${logPrefix} No app-level connection found. Cannot execute tool ${toolName}.`, - ); - } - } + if (userId && user) this.updateUserLastActivity(userId); + + connection = await this.getConnection({ + serverName, + user, + flowManager, + tokenMethods, + oauthStart, + oauthEnd, + signal: options?.signal, + customUserVars, + requestBody, + }); if (!(await connection.isConnected())) { /** May happen if getUserConnection failed silently or app connection dropped */ diff --git a/packages/api/src/mcp/UserConnectionManager.ts b/packages/api/src/mcp/UserConnectionManager.ts index 92d6e012e7..7f5862b2a8 100644 --- a/packages/api/src/mcp/UserConnectionManager.ts +++ b/packages/api/src/mcp/UserConnectionManager.ts @@ -4,6 +4,7 @@ import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; import { MCPServersRegistry } from '~/mcp/MCPServersRegistry'; import { MCPConnection } from './connection'; import type * as t from './types'; +import { ConnectionsRepository } from '~/mcp/ConnectionsRepository'; /** * Abstract base class for managing user-specific MCP connections with lifecycle management. @@ -14,6 +15,9 @@ import type * as t from './types'; */ export abstract class UserConnectionManager { protected readonly serversRegistry: MCPServersRegistry; + // Connections shared by all users. + public appConnections: ConnectionsRepository | null = null; + // Connections per userId -> serverName -> connection protected userConnections: Map> = new Map(); /** Last activity timestamp for users (not per server) */ protected userLastActivity: Map = new Map(); @@ -60,6 +64,13 @@ export abstract class UserConnectionManager { throw new McpError(ErrorCode.InvalidRequest, `[MCP] User object missing id property`); } + if (this.appConnections!.has(serverName)) { + throw new McpError( + ErrorCode.InvalidRequest, + `[MCP][User: ${userId}] Trying to create user-specific connection for app-level server "${serverName}"`, + ); + } + const userServerMap = this.userConnections.get(userId); let connection = forceNew ? undefined : userServerMap?.get(serverName); const now = Date.now();