diff --git a/api/server/services/MCP.js b/api/server/services/MCP.js index b7975b12fa..162e02d91e 100644 --- a/api/server/services/MCP.js +++ b/api/server/services/MCP.js @@ -450,7 +450,7 @@ async function getMCPSetupData(userId) { logger.error(`[MCP][User: ${userId}] Error getting app connections:`, error); } const userConnections = mcpManager.getUserConnections(userId) || new Map(); - const oauthServers = mcpManager.getOAuthServers(); + const oauthServers = mcpManager.getOAuthServers() || new Set(); return { mcpConfig, diff --git a/api/server/services/MCP.spec.js b/api/server/services/MCP.spec.js index 7b192995e3..8b9f7b675d 100644 --- a/api/server/services/MCP.spec.js +++ b/api/server/services/MCP.spec.js @@ -170,7 +170,7 @@ describe('tests for the new helper functions used by the MCP connection status e const mockMCPManager = { appConnections: { getAll: jest.fn(() => null) }, getUserConnections: jest.fn(() => null), - getOAuthServers: jest.fn(() => new Set()), + getOAuthServers: jest.fn(() => null), }; mockGetMCPManager.mockReturnValue(mockMCPManager); diff --git a/packages/api/src/mcp/MCPManager.ts b/packages/api/src/mcp/MCPManager.ts index c6bfe77b8f..a8768cf7b0 100644 --- a/packages/api/src/mcp/MCPManager.ts +++ b/packages/api/src/mcp/MCPManager.ts @@ -38,7 +38,7 @@ export class MCPManager extends UserConnectionManager { /** Initializes the MCPManager by setting up server registry and app connections */ public async initialize() { await this.serversRegistry.initialize(); - this.appConnections = new ConnectionsRepository(this.serversRegistry.appServerConfigs); + this.appConnections = new ConnectionsRepository(this.serversRegistry.appServerConfigs!); } /** Retrieves an app-level or user-specific connection based on provided arguments */ @@ -63,23 +63,22 @@ export class MCPManager extends UserConnectionManager { } /** Get servers that require OAuth */ - public getOAuthServers(): Set { - return this.serversRegistry.oauthServers; + public getOAuthServers(): Set | null { + return this.serversRegistry.oauthServers!; } /** Get all servers */ - public getAllServers(): t.MCPServers { - return this.serversRegistry.rawConfigs; + public getAllServers(): t.MCPServers | null { + return this.serversRegistry.rawConfigs!; } /** Returns all available tool functions from app-level connections */ - public getAppToolFunctions(): t.LCAvailableTools { - return this.serversRegistry.toolFunctions; + public getAppToolFunctions(): t.LCAvailableTools | null { + return this.serversRegistry.toolFunctions!; } - /** Returns all available tool functions from all connections available to user */ public async getAllToolFunctions(userId: string): Promise { - const allToolFunctions: t.LCAvailableTools = this.getAppToolFunctions(); + const allToolFunctions: t.LCAvailableTools = this.getAppToolFunctions() ?? {}; const userConnections = this.getUserConnections(userId); if (!userConnections || userConnections.size === 0) { return allToolFunctions; @@ -129,7 +128,7 @@ export class MCPManager extends UserConnectionManager { * @returns Object mapping server names to their instructions */ public getInstructions(serverNames?: string[]): Record { - const instructions = this.serversRegistry.serverInstructions; + const instructions = this.serversRegistry.serverInstructions!; if (!serverNames) return instructions; return pick(instructions, serverNames); } diff --git a/packages/api/src/mcp/MCPServersRegistry.ts b/packages/api/src/mcp/MCPServersRegistry.ts index cf75ddbb94..905a62bef8 100644 --- a/packages/api/src/mcp/MCPServersRegistry.ts +++ b/packages/api/src/mcp/MCPServersRegistry.ts @@ -1,3 +1,5 @@ +import pick from 'lodash/pick'; +import pickBy from 'lodash/pickBy'; import mapValues from 'lodash/mapValues'; import { logger } from '@librechat/data-schemas'; import { Constants } from 'librechat-data-provider'; @@ -9,14 +11,6 @@ import { detectOAuthRequirement } from '~/mcp/oauth'; import { sanitizeUrlForLogging } from '~/mcp/utils'; import { processMCPEnv, isEnabled } from '~/utils'; -const DEFAULT_MCP_INIT_TIMEOUT_MS = 30_000; - -function getMCPInitTimeout(): number { - return process.env.MCP_INIT_TIMEOUT_MS != null - ? parseInt(process.env.MCP_INIT_TIMEOUT_MS) - : DEFAULT_MCP_INIT_TIMEOUT_MS; -} - /** * Manages MCP server configurations and metadata discovery. * Fetches server capabilities, OAuth requirements, and tool definitions for registry. @@ -26,21 +20,19 @@ function getMCPInitTimeout(): number { export class MCPServersRegistry { private initialized: boolean = false; private connections: ConnectionsRepository; - private initTimeoutMs: number; public readonly rawConfigs: t.MCPServers; public readonly parsedConfigs: Record; - public oauthServers: Set = new Set(); - public serverInstructions: Record = {}; - public toolFunctions: t.LCAvailableTools = {}; - public appServerConfigs: t.MCPServers = {}; + public oauthServers: Set | null = null; + public serverInstructions: Record | null = null; + public toolFunctions: t.LCAvailableTools | null = null; + public appServerConfigs: t.MCPServers | null = null; constructor(configs: t.MCPServers) { this.rawConfigs = configs; this.parsedConfigs = mapValues(configs, (con) => processMCPEnv({ options: con })); this.connections = new ConnectionsRepository(configs); - this.initTimeoutMs = getMCPInitTimeout(); } /** Initializes all startup-enabled servers by gathering their metadata asynchronously */ @@ -50,43 +42,21 @@ export class MCPServersRegistry { const serverNames = Object.keys(this.parsedConfigs); - await Promise.allSettled( - serverNames.map((serverName) => this.initializeServerWithTimeout(serverName)), - ); + await Promise.allSettled(serverNames.map((serverName) => this.gatherServerInfo(serverName))); + + this.setOAuthServers(); + this.setServerInstructions(); + this.setAppServerConfigs(); + await this.setAppToolFunctions(); + + this.connections.disconnectAll(); } - /** Wraps server initialization with a timeout to prevent hanging */ - private async initializeServerWithTimeout(serverName: string): Promise { - let timeoutId: NodeJS.Timeout | null = null; - - try { - await Promise.race([ - this.initializeServer(serverName), - new Promise((_, reject) => { - timeoutId = setTimeout(() => { - reject(new Error('Server initialization timed out')); - }, this.initTimeoutMs); - }), - ]); - } catch (error) { - logger.warn(`${this.prefix(serverName)} Server initialization failed:`, error); - throw error; - } finally { - if (timeoutId != null) { - clearTimeout(timeoutId); - } - } - } - - /** Initializes a single server with all its metadata and adds it to appropriate collections */ - private async initializeServer(serverName: string): Promise { - logger.info(`${this.prefix(serverName)} Initializing server`); - const start = Date.now(); - - const config = this.parsedConfigs[serverName]; - + /** Fetches all metadata for a single server in parallel */ + private async gatherServerInfo(serverName: string): Promise { try { await this.fetchOAuthRequirement(serverName); + const config = this.parsedConfigs[serverName]; if (config.startup !== false && !config.requiresOAuth) { await Promise.allSettled([ @@ -103,39 +73,49 @@ export class MCPServersRegistry { } catch (error) { logger.warn(`${this.prefix(serverName)} Failed to initialize server:`, error); } + } - // Add to OAuth servers if needed - if (config.requiresOAuth) { - this.oauthServers.add(serverName); + /** Sets app-level server configs (startup enabled, non-OAuth servers) */ + private setAppServerConfigs(): void { + const appServers = Object.keys( + pickBy( + this.parsedConfigs, + (config) => config.startup !== false && config.requiresOAuth === false, + ), + ); + this.appServerConfigs = pick(this.rawConfigs, appServers); + } + + /** Creates set of server names that require OAuth authentication */ + private setOAuthServers(): Set { + if (this.oauthServers) return this.oauthServers; + this.oauthServers = new Set( + Object.keys(pickBy(this.parsedConfigs, (config) => config.requiresOAuth)), + ); + return this.oauthServers; + } + + /** Collects server instructions from all configured servers */ + private setServerInstructions(): void { + this.serverInstructions = mapValues( + pickBy(this.parsedConfigs, (config) => config.serverInstructions), + (config) => config.serverInstructions as string, + ); + } + + /** Builds registry of all available tool functions from loaded connections */ + private async setAppToolFunctions(): Promise { + const connections = (await this.connections.getLoaded()).entries(); + const allToolFunctions: t.LCAvailableTools = {}; + for (const [serverName, conn] of connections) { + try { + const toolFunctions = await this.getToolFunctions(serverName, conn); + Object.assign(allToolFunctions, toolFunctions); + } catch (error) { + logger.warn(`${this.prefix(serverName)} Error fetching tool functions:`, error); + } } - - // Add server instructions if available - if (config.serverInstructions != null) { - this.serverInstructions[serverName] = config.serverInstructions as string; - } - - // Add to app server configs if eligible (startup enabled, non-OAuth servers) - if (config.startup !== false && config.requiresOAuth === false) { - this.appServerConfigs[serverName] = this.rawConfigs[serverName]; - } - - // Fetch tool functions for this server if a connection was established - try { - const conn = await this.connections.get(serverName); - const toolFunctions = await this.getToolFunctions(serverName, conn); - Object.assign(this.toolFunctions, toolFunctions); - } catch (error) { - logger.warn(`${this.prefix(serverName)} Error fetching tool functions:`, error); - } - - // Disconnect this server's connection after initialization - try { - await this.connections.disconnect(serverName); - } catch (disconnectError) { - logger.debug(`${this.prefix(serverName)} Failed to disconnect:`, disconnectError); - } - - logger.info(`${this.prefix(serverName)} Initialized server in ${Date.now() - start}ms`); + this.toolFunctions = allToolFunctions; } /** Converts server tools to LibreChat-compatible tool functions format */ diff --git a/packages/api/src/mcp/__tests__/MCPServersRegistry.test.ts b/packages/api/src/mcp/__tests__/MCPServersRegistry.test.ts index 7515a47673..9a276e3713 100644 --- a/packages/api/src/mcp/__tests__/MCPServersRegistry.test.ts +++ b/packages/api/src/mcp/__tests__/MCPServersRegistry.test.ts @@ -113,7 +113,6 @@ describe('MCPServersRegistry - Initialize Function', () => { get: jest.fn(), getLoaded: jest.fn(), disconnectAll: jest.fn(), - disconnect: jest.fn().mockResolvedValue(undefined), } as unknown as jest.Mocked; mockConnectionsRepo.get.mockImplementation((serverName: string) => { @@ -161,7 +160,6 @@ describe('MCPServersRegistry - Initialize Function', () => { }); afterEach(() => { - delete process.env.MCP_INIT_TIMEOUT_MS; jest.clearAllMocks(); }); @@ -181,14 +179,15 @@ describe('MCPServersRegistry - Initialize Function', () => { const registry = new MCPServersRegistry(rawConfigs); // Verify initial state - expect(registry.oauthServers.size).toBe(0); - expect(registry.serverInstructions).toEqual({}); - expect(registry.toolFunctions).toEqual({}); - expect(registry.appServerConfigs).toEqual({}); + expect(registry.oauthServers).toBeNull(); + expect(registry.serverInstructions).toBeNull(); + expect(registry.toolFunctions).toBeNull(); + expect(registry.appServerConfigs).toBeNull(); await registry.initialize(); // Test oauthServers Set + expect(registry.oauthServers).toBeInstanceOf(Set); expect(registry.oauthServers).toEqual( new Set(['oauth_server', 'oauth_predefined', 'oauth_startup_enabled']), ); @@ -229,49 +228,18 @@ describe('MCPServersRegistry - Initialize Function', () => { expect(registry.toolFunctions).toEqual(expectedToolFunctions); }); - it('should handle errors gracefully and continue initialization of other servers', async () => { + it('should handle errors gracefully and continue initialization', async () => { const registry = new MCPServersRegistry(rawConfigs); - // Make one specific server throw an error during OAuth detection - mockDetectOAuthRequirement.mockImplementation((url: string) => { - if (url === 'https://api.github.com/mcp') { - return Promise.reject(new Error('OAuth detection failed')); - } - // Return normal responses for other servers - const oauthResults: Record = { - 'https://api.disabled.com/mcp': { - requiresOAuth: false, - method: 'no-metadata-found', - metadata: null, - }, - 'https://api.public.com/mcp': { - requiresOAuth: false, - method: 'no-metadata-found', - metadata: null, - }, - }; - return Promise.resolve( - oauthResults[url] ?? { - requiresOAuth: false, - method: 'no-metadata-found', - metadata: null, - }, - ); - }); + // Make one server throw an error + mockDetectOAuthRequirement.mockRejectedValueOnce(new Error('OAuth detection failed')); await registry.initialize(); - // Should still initialize successfully for other servers + // Should still initialize successfully expect(registry.oauthServers).toBeInstanceOf(Set); expect(registry.toolFunctions).toBeDefined(); - // The failed server should not be in oauthServers (since it failed OAuth detection) - expect(registry.oauthServers.has('oauth_server')).toBe(false); - - // But other servers should still be processed successfully - expect(registry.appServerConfigs).toHaveProperty('stdio_server'); - expect(registry.appServerConfigs).toHaveProperty('non_oauth_server'); - // Error should be logged as a warning at the higher level expect(mockLogger.warn).toHaveBeenCalledWith( expect.stringContaining('[MCP][oauth_server] Failed to initialize server:'), @@ -279,15 +247,12 @@ describe('MCPServersRegistry - Initialize Function', () => { ); }); - it('should disconnect individual connections after each server initialization', async () => { + it('should disconnect all connections after initialization', async () => { const registry = new MCPServersRegistry(rawConfigs); await registry.initialize(); - // Verify disconnect was called for each server during initialization - // All servers attempt to connect during initialization for metadata gathering - const serverNames = Object.keys(rawConfigs); - expect(mockConnectionsRepo.disconnect).toHaveBeenCalledTimes(serverNames.length); + expect(mockConnectionsRepo.disconnectAll).toHaveBeenCalledTimes(1); }); it('should log configuration updates for each startup-enabled server', async () => { @@ -392,125 +357,5 @@ describe('MCPServersRegistry - Initialize Function', () => { // Verify getInstructions was called for both "true" cases expect(mockClient.getInstructions).toHaveBeenCalledTimes(2); }); - - it('should use Promise.allSettled for individual server initialization', async () => { - const registry = new MCPServersRegistry(rawConfigs); - - // Spy on Promise.allSettled to verify it's being used - const allSettledSpy = jest.spyOn(Promise, 'allSettled'); - - await registry.initialize(); - - // Verify Promise.allSettled was called with an array of server initialization promises - expect(allSettledSpy).toHaveBeenCalledWith(expect.arrayContaining([expect.any(Promise)])); - - // Verify it was called with the correct number of server promises - const serverNames = Object.keys(rawConfigs); - expect(allSettledSpy).toHaveBeenCalledWith( - expect.arrayContaining(new Array(serverNames.length).fill(expect.any(Promise))), - ); - - allSettledSpy.mockRestore(); - }); - - it('should isolate server failures and not affect other servers', async () => { - const registry = new MCPServersRegistry(rawConfigs); - - // Make multiple servers fail in different ways - mockConnectionsRepo.get.mockImplementation((serverName: string) => { - if (serverName === 'stdio_server') { - // First server fails - throw new Error('Connection failed for stdio_server'); - } - if (serverName === 'websocket_server') { - // Second server fails - throw new Error('Connection failed for websocket_server'); - } - // Other servers succeed - const connection = mockConnections.get(serverName); - if (!connection) { - throw new Error(`Connection not found for server: ${serverName}`); - } - return Promise.resolve(connection); - }); - - await registry.initialize(); - - // Despite failures, initialization should complete - expect(registry.oauthServers).toBeInstanceOf(Set); - expect(registry.toolFunctions).toBeDefined(); - - // Successful servers should still be processed - expect(registry.appServerConfigs).toHaveProperty('non_oauth_server'); - - // Failed servers should not crash the whole initialization - expect(mockLogger.warn).toHaveBeenCalledWith( - expect.stringContaining('[MCP][stdio_server] Error fetching tool functions:'), - expect.any(Error), - ); - expect(mockLogger.warn).toHaveBeenCalledWith( - expect.stringContaining('[MCP][websocket_server] Error fetching tool functions:'), - expect.any(Error), - ); - }); - - it('should properly clean up connections even when some servers fail', async () => { - const registry = new MCPServersRegistry(rawConfigs); - - // Make some connections fail during disconnect - mockConnectionsRepo.disconnect.mockImplementation((serverName: string) => { - if (serverName === 'stdio_server') { - return Promise.reject(new Error('Disconnect failed')); - } - return Promise.resolve(); - }); - - await registry.initialize(); - - // Should still attempt to disconnect all servers during initialization - const serverNames = Object.keys(rawConfigs); - expect(mockConnectionsRepo.disconnect).toHaveBeenCalledTimes(serverNames.length); - - // Failed disconnects should be logged but not crash initialization - expect(mockLogger.debug).toHaveBeenCalledWith( - expect.stringContaining('[MCP][stdio_server] Failed to disconnect:'), - expect.any(Error), - ); - }); - - it('should timeout individual server initialization after configured timeout', async () => { - const timeout = 2000; - // Create registry with a short timeout for testing - process.env.MCP_INIT_TIMEOUT_MS = `${timeout}`; - - const registry = new MCPServersRegistry(rawConfigs); - - // Make one server hang indefinitely during OAuth detection - mockDetectOAuthRequirement.mockImplementation((url: string) => { - if (url === 'https://api.github.com/mcp') { - // Slow init - return new Promise((res) => setTimeout(res, timeout * 2)); - } - // Return normal responses for other servers - return Promise.resolve({ - requiresOAuth: false, - method: 'no-metadata-found', - metadata: null, - }); - }); - - const start = Date.now(); - await registry.initialize(); - const duration = Date.now() - start; - - // Should complete within reasonable time despite one server hanging - // Allow some buffer for test execution overhead - expect(duration).toBeLessThan(timeout * 1.5); - - // The timeout should prevent the hanging server from blocking initialization - // Other servers should still be processed successfully - expect(registry.appServerConfigs).toHaveProperty('stdio_server'); - expect(registry.appServerConfigs).toHaveProperty('non_oauth_server'); - }, 10_000); // 10 second Jest timeout }); }); diff --git a/packages/api/src/mcp/oauth/OAuthReconnectionManager.ts b/packages/api/src/mcp/oauth/OAuthReconnectionManager.ts index 09abb2b048..9e84ef1483 100644 --- a/packages/api/src/mcp/oauth/OAuthReconnectionManager.ts +++ b/packages/api/src/mcp/oauth/OAuthReconnectionManager.ts @@ -72,7 +72,7 @@ export class OAuthReconnectionManager { // 1. derive the servers to reconnect const serversToReconnect = []; - for (const serverName of this.mcpManager.getOAuthServers()) { + for (const serverName of this.mcpManager.getOAuthServers() ?? []) { const canReconnect = await this.canReconnect(userId, serverName); if (canReconnect) { serversToReconnect.push(serverName);