diff --git a/api/server/services/MCP.js b/api/server/services/MCP.js index 162e02d91e..b7975b12fa 100644 --- a/api/server/services/MCP.js +++ b/api/server/services/MCP.js @@ -450,7 +450,7 @@ async function getMCPSetupData(userId) { logger.error(`[MCP][User: ${userId}] Error getting app connections:`, error); } const userConnections = mcpManager.getUserConnections(userId) || new Map(); - const oauthServers = mcpManager.getOAuthServers() || new Set(); + const oauthServers = mcpManager.getOAuthServers(); return { mcpConfig, diff --git a/api/server/services/MCP.spec.js b/api/server/services/MCP.spec.js index 8b9f7b675d..7b192995e3 100644 --- a/api/server/services/MCP.spec.js +++ b/api/server/services/MCP.spec.js @@ -170,7 +170,7 @@ describe('tests for the new helper functions used by the MCP connection status e const mockMCPManager = { appConnections: { getAll: jest.fn(() => null) }, getUserConnections: jest.fn(() => null), - getOAuthServers: jest.fn(() => null), + getOAuthServers: jest.fn(() => new Set()), }; mockGetMCPManager.mockReturnValue(mockMCPManager); diff --git a/packages/api/src/mcp/MCPManager.ts b/packages/api/src/mcp/MCPManager.ts index 9d3145c632..d3966ff008 100644 --- a/packages/api/src/mcp/MCPManager.ts +++ b/packages/api/src/mcp/MCPManager.ts @@ -38,7 +38,7 @@ export class MCPManager extends UserConnectionManager { /** Initializes the MCPManager by setting up server registry and app connections */ public async initialize() { await this.serversRegistry.initialize(); - this.appConnections = new ConnectionsRepository(this.serversRegistry.appServerConfigs!); + this.appConnections = new ConnectionsRepository(this.serversRegistry.appServerConfigs); } /** Retrieves an app-level or user-specific connection based on provided arguments */ @@ -63,22 +63,23 @@ export class MCPManager extends UserConnectionManager { } /** Get servers that require OAuth */ - public getOAuthServers(): Set | null { - return this.serversRegistry.oauthServers!; + public getOAuthServers(): Set { + return this.serversRegistry.oauthServers; } /** Get all servers */ - public getAllServers(): t.MCPServers | null { - return this.serversRegistry.rawConfigs!; + public getAllServers(): t.MCPServers { + return this.serversRegistry.rawConfigs; } /** Returns all available tool functions from app-level connections */ - public getAppToolFunctions(): t.LCAvailableTools | null { - return this.serversRegistry.toolFunctions!; + public getAppToolFunctions(): t.LCAvailableTools { + return this.serversRegistry.toolFunctions; } + /** Returns all available tool functions from all connections available to user */ public async getAllToolFunctions(userId: string): Promise { - const allToolFunctions: t.LCAvailableTools = this.getAppToolFunctions() ?? {}; + const allToolFunctions: t.LCAvailableTools = this.getAppToolFunctions(); const userConnections = this.getUserConnections(userId); if (!userConnections || userConnections.size === 0) { return allToolFunctions; @@ -120,7 +121,7 @@ export class MCPManager extends UserConnectionManager { * @returns Object mapping server names to their instructions */ public getInstructions(serverNames?: string[]): Record { - const instructions = this.serversRegistry.serverInstructions!; + const instructions = this.serversRegistry.serverInstructions; if (!serverNames) return instructions; return pick(instructions, serverNames); } diff --git a/packages/api/src/mcp/MCPServersRegistry.ts b/packages/api/src/mcp/MCPServersRegistry.ts index 905a62bef8..cf75ddbb94 100644 --- a/packages/api/src/mcp/MCPServersRegistry.ts +++ b/packages/api/src/mcp/MCPServersRegistry.ts @@ -1,5 +1,3 @@ -import pick from 'lodash/pick'; -import pickBy from 'lodash/pickBy'; import mapValues from 'lodash/mapValues'; import { logger } from '@librechat/data-schemas'; import { Constants } from 'librechat-data-provider'; @@ -11,6 +9,14 @@ import { detectOAuthRequirement } from '~/mcp/oauth'; import { sanitizeUrlForLogging } from '~/mcp/utils'; import { processMCPEnv, isEnabled } from '~/utils'; +const DEFAULT_MCP_INIT_TIMEOUT_MS = 30_000; + +function getMCPInitTimeout(): number { + return process.env.MCP_INIT_TIMEOUT_MS != null + ? parseInt(process.env.MCP_INIT_TIMEOUT_MS) + : DEFAULT_MCP_INIT_TIMEOUT_MS; +} + /** * Manages MCP server configurations and metadata discovery. * Fetches server capabilities, OAuth requirements, and tool definitions for registry. @@ -20,19 +26,21 @@ import { processMCPEnv, isEnabled } from '~/utils'; export class MCPServersRegistry { private initialized: boolean = false; private connections: ConnectionsRepository; + private initTimeoutMs: number; public readonly rawConfigs: t.MCPServers; public readonly parsedConfigs: Record; - public oauthServers: Set | null = null; - public serverInstructions: Record | null = null; - public toolFunctions: t.LCAvailableTools | null = null; - public appServerConfigs: t.MCPServers | null = null; + public oauthServers: Set = new Set(); + public serverInstructions: Record = {}; + public toolFunctions: t.LCAvailableTools = {}; + public appServerConfigs: t.MCPServers = {}; constructor(configs: t.MCPServers) { this.rawConfigs = configs; this.parsedConfigs = mapValues(configs, (con) => processMCPEnv({ options: con })); this.connections = new ConnectionsRepository(configs); + this.initTimeoutMs = getMCPInitTimeout(); } /** Initializes all startup-enabled servers by gathering their metadata asynchronously */ @@ -42,21 +50,43 @@ export class MCPServersRegistry { const serverNames = Object.keys(this.parsedConfigs); - await Promise.allSettled(serverNames.map((serverName) => this.gatherServerInfo(serverName))); - - this.setOAuthServers(); - this.setServerInstructions(); - this.setAppServerConfigs(); - await this.setAppToolFunctions(); - - this.connections.disconnectAll(); + await Promise.allSettled( + serverNames.map((serverName) => this.initializeServerWithTimeout(serverName)), + ); } - /** Fetches all metadata for a single server in parallel */ - private async gatherServerInfo(serverName: string): Promise { + /** Wraps server initialization with a timeout to prevent hanging */ + private async initializeServerWithTimeout(serverName: string): Promise { + let timeoutId: NodeJS.Timeout | null = null; + + try { + await Promise.race([ + this.initializeServer(serverName), + new Promise((_, reject) => { + timeoutId = setTimeout(() => { + reject(new Error('Server initialization timed out')); + }, this.initTimeoutMs); + }), + ]); + } catch (error) { + logger.warn(`${this.prefix(serverName)} Server initialization failed:`, error); + throw error; + } finally { + if (timeoutId != null) { + clearTimeout(timeoutId); + } + } + } + + /** Initializes a single server with all its metadata and adds it to appropriate collections */ + private async initializeServer(serverName: string): Promise { + logger.info(`${this.prefix(serverName)} Initializing server`); + const start = Date.now(); + + const config = this.parsedConfigs[serverName]; + try { await this.fetchOAuthRequirement(serverName); - const config = this.parsedConfigs[serverName]; if (config.startup !== false && !config.requiresOAuth) { await Promise.allSettled([ @@ -73,49 +103,39 @@ export class MCPServersRegistry { } catch (error) { logger.warn(`${this.prefix(serverName)} Failed to initialize server:`, error); } - } - /** Sets app-level server configs (startup enabled, non-OAuth servers) */ - private setAppServerConfigs(): void { - const appServers = Object.keys( - pickBy( - this.parsedConfigs, - (config) => config.startup !== false && config.requiresOAuth === false, - ), - ); - this.appServerConfigs = pick(this.rawConfigs, appServers); - } - - /** Creates set of server names that require OAuth authentication */ - private setOAuthServers(): Set { - if (this.oauthServers) return this.oauthServers; - this.oauthServers = new Set( - Object.keys(pickBy(this.parsedConfigs, (config) => config.requiresOAuth)), - ); - return this.oauthServers; - } - - /** Collects server instructions from all configured servers */ - private setServerInstructions(): void { - this.serverInstructions = mapValues( - pickBy(this.parsedConfigs, (config) => config.serverInstructions), - (config) => config.serverInstructions as string, - ); - } - - /** Builds registry of all available tool functions from loaded connections */ - private async setAppToolFunctions(): Promise { - const connections = (await this.connections.getLoaded()).entries(); - const allToolFunctions: t.LCAvailableTools = {}; - for (const [serverName, conn] of connections) { - try { - const toolFunctions = await this.getToolFunctions(serverName, conn); - Object.assign(allToolFunctions, toolFunctions); - } catch (error) { - logger.warn(`${this.prefix(serverName)} Error fetching tool functions:`, error); - } + // Add to OAuth servers if needed + if (config.requiresOAuth) { + this.oauthServers.add(serverName); } - this.toolFunctions = allToolFunctions; + + // Add server instructions if available + if (config.serverInstructions != null) { + this.serverInstructions[serverName] = config.serverInstructions as string; + } + + // Add to app server configs if eligible (startup enabled, non-OAuth servers) + if (config.startup !== false && config.requiresOAuth === false) { + this.appServerConfigs[serverName] = this.rawConfigs[serverName]; + } + + // Fetch tool functions for this server if a connection was established + try { + const conn = await this.connections.get(serverName); + const toolFunctions = await this.getToolFunctions(serverName, conn); + Object.assign(this.toolFunctions, toolFunctions); + } catch (error) { + logger.warn(`${this.prefix(serverName)} Error fetching tool functions:`, error); + } + + // Disconnect this server's connection after initialization + try { + await this.connections.disconnect(serverName); + } catch (disconnectError) { + logger.debug(`${this.prefix(serverName)} Failed to disconnect:`, disconnectError); + } + + logger.info(`${this.prefix(serverName)} Initialized server in ${Date.now() - start}ms`); } /** Converts server tools to LibreChat-compatible tool functions format */ diff --git a/packages/api/src/mcp/__tests__/MCPServersRegistry.test.ts b/packages/api/src/mcp/__tests__/MCPServersRegistry.test.ts index 9a276e3713..7515a47673 100644 --- a/packages/api/src/mcp/__tests__/MCPServersRegistry.test.ts +++ b/packages/api/src/mcp/__tests__/MCPServersRegistry.test.ts @@ -113,6 +113,7 @@ describe('MCPServersRegistry - Initialize Function', () => { get: jest.fn(), getLoaded: jest.fn(), disconnectAll: jest.fn(), + disconnect: jest.fn().mockResolvedValue(undefined), } as unknown as jest.Mocked; mockConnectionsRepo.get.mockImplementation((serverName: string) => { @@ -160,6 +161,7 @@ describe('MCPServersRegistry - Initialize Function', () => { }); afterEach(() => { + delete process.env.MCP_INIT_TIMEOUT_MS; jest.clearAllMocks(); }); @@ -179,15 +181,14 @@ describe('MCPServersRegistry - Initialize Function', () => { const registry = new MCPServersRegistry(rawConfigs); // Verify initial state - expect(registry.oauthServers).toBeNull(); - expect(registry.serverInstructions).toBeNull(); - expect(registry.toolFunctions).toBeNull(); - expect(registry.appServerConfigs).toBeNull(); + expect(registry.oauthServers.size).toBe(0); + expect(registry.serverInstructions).toEqual({}); + expect(registry.toolFunctions).toEqual({}); + expect(registry.appServerConfigs).toEqual({}); await registry.initialize(); // Test oauthServers Set - expect(registry.oauthServers).toBeInstanceOf(Set); expect(registry.oauthServers).toEqual( new Set(['oauth_server', 'oauth_predefined', 'oauth_startup_enabled']), ); @@ -228,18 +229,49 @@ describe('MCPServersRegistry - Initialize Function', () => { expect(registry.toolFunctions).toEqual(expectedToolFunctions); }); - it('should handle errors gracefully and continue initialization', async () => { + it('should handle errors gracefully and continue initialization of other servers', async () => { const registry = new MCPServersRegistry(rawConfigs); - // Make one server throw an error - mockDetectOAuthRequirement.mockRejectedValueOnce(new Error('OAuth detection failed')); + // Make one specific server throw an error during OAuth detection + mockDetectOAuthRequirement.mockImplementation((url: string) => { + if (url === 'https://api.github.com/mcp') { + return Promise.reject(new Error('OAuth detection failed')); + } + // Return normal responses for other servers + const oauthResults: Record = { + 'https://api.disabled.com/mcp': { + requiresOAuth: false, + method: 'no-metadata-found', + metadata: null, + }, + 'https://api.public.com/mcp': { + requiresOAuth: false, + method: 'no-metadata-found', + metadata: null, + }, + }; + return Promise.resolve( + oauthResults[url] ?? { + requiresOAuth: false, + method: 'no-metadata-found', + metadata: null, + }, + ); + }); await registry.initialize(); - // Should still initialize successfully + // Should still initialize successfully for other servers expect(registry.oauthServers).toBeInstanceOf(Set); expect(registry.toolFunctions).toBeDefined(); + // The failed server should not be in oauthServers (since it failed OAuth detection) + expect(registry.oauthServers.has('oauth_server')).toBe(false); + + // But other servers should still be processed successfully + expect(registry.appServerConfigs).toHaveProperty('stdio_server'); + expect(registry.appServerConfigs).toHaveProperty('non_oauth_server'); + // Error should be logged as a warning at the higher level expect(mockLogger.warn).toHaveBeenCalledWith( expect.stringContaining('[MCP][oauth_server] Failed to initialize server:'), @@ -247,12 +279,15 @@ describe('MCPServersRegistry - Initialize Function', () => { ); }); - it('should disconnect all connections after initialization', async () => { + it('should disconnect individual connections after each server initialization', async () => { const registry = new MCPServersRegistry(rawConfigs); await registry.initialize(); - expect(mockConnectionsRepo.disconnectAll).toHaveBeenCalledTimes(1); + // Verify disconnect was called for each server during initialization + // All servers attempt to connect during initialization for metadata gathering + const serverNames = Object.keys(rawConfigs); + expect(mockConnectionsRepo.disconnect).toHaveBeenCalledTimes(serverNames.length); }); it('should log configuration updates for each startup-enabled server', async () => { @@ -357,5 +392,125 @@ describe('MCPServersRegistry - Initialize Function', () => { // Verify getInstructions was called for both "true" cases expect(mockClient.getInstructions).toHaveBeenCalledTimes(2); }); + + it('should use Promise.allSettled for individual server initialization', async () => { + const registry = new MCPServersRegistry(rawConfigs); + + // Spy on Promise.allSettled to verify it's being used + const allSettledSpy = jest.spyOn(Promise, 'allSettled'); + + await registry.initialize(); + + // Verify Promise.allSettled was called with an array of server initialization promises + expect(allSettledSpy).toHaveBeenCalledWith(expect.arrayContaining([expect.any(Promise)])); + + // Verify it was called with the correct number of server promises + const serverNames = Object.keys(rawConfigs); + expect(allSettledSpy).toHaveBeenCalledWith( + expect.arrayContaining(new Array(serverNames.length).fill(expect.any(Promise))), + ); + + allSettledSpy.mockRestore(); + }); + + it('should isolate server failures and not affect other servers', async () => { + const registry = new MCPServersRegistry(rawConfigs); + + // Make multiple servers fail in different ways + mockConnectionsRepo.get.mockImplementation((serverName: string) => { + if (serverName === 'stdio_server') { + // First server fails + throw new Error('Connection failed for stdio_server'); + } + if (serverName === 'websocket_server') { + // Second server fails + throw new Error('Connection failed for websocket_server'); + } + // Other servers succeed + const connection = mockConnections.get(serverName); + if (!connection) { + throw new Error(`Connection not found for server: ${serverName}`); + } + return Promise.resolve(connection); + }); + + await registry.initialize(); + + // Despite failures, initialization should complete + expect(registry.oauthServers).toBeInstanceOf(Set); + expect(registry.toolFunctions).toBeDefined(); + + // Successful servers should still be processed + expect(registry.appServerConfigs).toHaveProperty('non_oauth_server'); + + // Failed servers should not crash the whole initialization + expect(mockLogger.warn).toHaveBeenCalledWith( + expect.stringContaining('[MCP][stdio_server] Error fetching tool functions:'), + expect.any(Error), + ); + expect(mockLogger.warn).toHaveBeenCalledWith( + expect.stringContaining('[MCP][websocket_server] Error fetching tool functions:'), + expect.any(Error), + ); + }); + + it('should properly clean up connections even when some servers fail', async () => { + const registry = new MCPServersRegistry(rawConfigs); + + // Make some connections fail during disconnect + mockConnectionsRepo.disconnect.mockImplementation((serverName: string) => { + if (serverName === 'stdio_server') { + return Promise.reject(new Error('Disconnect failed')); + } + return Promise.resolve(); + }); + + await registry.initialize(); + + // Should still attempt to disconnect all servers during initialization + const serverNames = Object.keys(rawConfigs); + expect(mockConnectionsRepo.disconnect).toHaveBeenCalledTimes(serverNames.length); + + // Failed disconnects should be logged but not crash initialization + expect(mockLogger.debug).toHaveBeenCalledWith( + expect.stringContaining('[MCP][stdio_server] Failed to disconnect:'), + expect.any(Error), + ); + }); + + it('should timeout individual server initialization after configured timeout', async () => { + const timeout = 2000; + // Create registry with a short timeout for testing + process.env.MCP_INIT_TIMEOUT_MS = `${timeout}`; + + const registry = new MCPServersRegistry(rawConfigs); + + // Make one server hang indefinitely during OAuth detection + mockDetectOAuthRequirement.mockImplementation((url: string) => { + if (url === 'https://api.github.com/mcp') { + // Slow init + return new Promise((res) => setTimeout(res, timeout * 2)); + } + // Return normal responses for other servers + return Promise.resolve({ + requiresOAuth: false, + method: 'no-metadata-found', + metadata: null, + }); + }); + + const start = Date.now(); + await registry.initialize(); + const duration = Date.now() - start; + + // Should complete within reasonable time despite one server hanging + // Allow some buffer for test execution overhead + expect(duration).toBeLessThan(timeout * 1.5); + + // The timeout should prevent the hanging server from blocking initialization + // Other servers should still be processed successfully + expect(registry.appServerConfigs).toHaveProperty('stdio_server'); + expect(registry.appServerConfigs).toHaveProperty('non_oauth_server'); + }, 10_000); // 10 second Jest timeout }); }); diff --git a/packages/api/src/mcp/oauth/OAuthReconnectionManager.ts b/packages/api/src/mcp/oauth/OAuthReconnectionManager.ts index 9e84ef1483..09abb2b048 100644 --- a/packages/api/src/mcp/oauth/OAuthReconnectionManager.ts +++ b/packages/api/src/mcp/oauth/OAuthReconnectionManager.ts @@ -72,7 +72,7 @@ export class OAuthReconnectionManager { // 1. derive the servers to reconnect const serversToReconnect = []; - for (const serverName of this.mcpManager.getOAuthServers() ?? []) { + for (const serverName of this.mcpManager.getOAuthServers()) { const canReconnect = await this.canReconnect(userId, serverName); if (canReconnect) { serversToReconnect.push(serverName);